
Predict from a brulee_linear_reg

Usage

# S3 method for brulee_linear_reg
predict(object, new_data, type = NULL, epoch = NULL, ...)

Arguments

object

A brulee_linear_reg object.

new_data

A data frame or matrix of new predictors.

type

A single character string giving the type of predictions to generate. Valid options are:

  • "numeric" for numeric predictions.

epoch

An integer specifying the epoch whose parameters are used to make predictions. If this value is larger than the maximum number of epochs that were fit, a warning is issued and the parameters from the last epoch are used. If left NULL, the epoch associated with the smallest loss is used.

...

Not used, but required for extensibility.
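The rule that the `epoch` argument follows can be written out in a few lines of base R. This is a minimal sketch that assumes the fitted losses are available as a plain numeric vector with one entry per epoch that was fit. `choose_epoch()` is a hypothetical helper written only for illustration. It is not a function exported by brulee, and it does not reflect how brulee stores its loss history internally.

```r
# Hypothetical helper, NOT brulee's internal code: a sketch of the documented
# rule for choosing which epoch's parameters predict() uses.
choose_epoch <- function(loss, epoch = NULL) {
  n_fit <- length(loss)  # one loss value per epoch that was actually fit
  if (is.null(epoch)) {
    # Default: the epoch with the smallest loss (the first one if tied)
    return(which.min(loss))
  }
  if (epoch > n_fit) {
    # Requested epoch was never fit: warn and fall back to the last one
    warning("epoch ", epoch, " is larger than the ", n_fit,
            " epochs that were fit; using epoch ", n_fit, ".", call. = FALSE)
    return(n_fit)
  }
  as.integer(epoch)
}

# Usage with a made-up loss trajectory for five epochs
loss <- c(0.90, 0.42, 0.35, 0.37, 0.39)
choose_epoch(loss)                                # smallest loss is at epoch 3
choose_epoch(loss, epoch = 2)                     # explicit request is honored
suppressWarnings(choose_epoch(loss, epoch = 10))  # falls back to epoch 5
```

Returning the last fitted epoch instead of throwing an error matches the documented behavior: an out-of-range request still yields predictions, and the warning is what tells the user about it.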

Value

A tibble of predictions. The number of rows in the tibble is guaranteed to be the same as the number of rows in new_data. For type = "numeric", the predictions are stored in a column named .pred.

Examples

# \donttest{
if (torch::torch_is_installed()) {

 data(ames, package = "modeldata")

 ames$Sale_Price <- log10(ames$Sale_Price)

 set.seed(1)
 in_train <- sample(1:nrow(ames), 2000)
 ames_train <- ames[ in_train, ]
 ames_test  <- ames[-in_train, ]

 # Preprocess the predictors with a recipe
 library(recipes)

 ames_rec <-
  recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) %>%
    step_normalize(all_numeric_predictors())

 set.seed(2)
 fit <- brulee_linear_reg(ames_rec, data = ames_train,
                           epochs = 50, batch_size = 32)

 predict(fit, ames_test)
}
#> # A tibble: 930 × 1
#>    .pred
#>    <dbl>
#>  1  5.27
#>  2  5.37
#>  3  5.37
#>  4  5.33
#>  5  5.33
#>  6  5.30
#>  7  5.30
#>  8  5.30
#>  9  5.31
#> 10  5.29
#> # … with 920 more rows
# }
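The `epoch` argument can also be passed directly to `predict()`. The sketch below repeats the data split and recipe from the example above so that it runs on its own. The epoch values (1 and 100) and the object names `pred_epoch_1` and `pred_too_far` are arbitrary choices made for illustration. The guard checks for brulee before calling into torch, which is an assumption about how you may want to protect the example rather than something the documentation requires.

```r
if (requireNamespace("brulee", quietly = TRUE) && torch::torch_is_installed()) {
  library(brulee)
  library(recipes)

  data(ames, package = "modeldata")
  ames$Sale_Price <- log10(ames$Sale_Price)

  set.seed(1)
  in_train <- sample(1:nrow(ames), 2000)
  ames_train <- ames[ in_train, ]
  ames_test  <- ames[-in_train, ]

  ames_rec <-
    recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) %>%
    step_normalize(all_numeric_predictors())

  set.seed(2)
  fit <- brulee_linear_reg(ames_rec, data = ames_train,
                           epochs = 50, batch_size = 32)

  # Use the parameters from the first epoch instead of the smallest-loss epoch
  pred_epoch_1 <- predict(fit, ames_test, type = "numeric", epoch = 1)

  # Epoch 100 was never fit (only 50 were requested), so per the documentation
  # a warning is issued and the last epoch's parameters are used instead
  pred_too_far <- predict(fit, ames_test, epoch = 100)

  # One prediction per row of new_data, so the columns line up side by side
  head(cbind(ames_test["Sale_Price"], pred_epoch_1))
}
```

Because the returned tibble always has one row per row of new_data, binding it next to the observed outcome is safe without any key-based join.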