
Predict from a brulee_mlp

Usage

# S3 method for brulee_mlp
predict(object, new_data, type = NULL, epoch = NULL, ...)

Arguments

object

A brulee_mlp object.

new_data

A data frame or matrix of new predictors.

type

A single character string giving the type of predictions to generate. Valid options are:

  • "numeric" for numeric predictions (regression models).

  • "class" for hard class predictions.

  • "prob" for soft class predictions (i.e., class probabilities).

epoch

An integer giving the epoch whose parameters are used to make predictions. If this value is larger than the number of epochs that were fit, a warning is issued and the parameters from the last epoch are used. If left NULL, the epoch associated with the smallest loss is used.

...

Not used, but required for extensibility.
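The rule that `epoch` follows can be sketched in base R. This is a minimal illustration, not brulee's internal code: `resolve_epoch()` is a hypothetical helper written for this page, it assumes the per-epoch losses are available as a plain numeric vector, and the loss values are made up for illustration.

```r
# Hypothetical helper (not part of brulee): mirrors the documented rule for
# choosing which epoch's parameters are used for prediction.
resolve_epoch <- function(losses, epoch = NULL) {
  n_fit <- length(losses)
  if (is.null(epoch)) {
    # Default: the epoch with the smallest loss (first one if tied)
    return(which.min(losses))
  }
  if (epoch > n_fit) {
    # Too large: warn and fall back to the last epoch that was fit
    warning("Requested epoch ", epoch, " exceeds the ", n_fit,
            " epochs that were fit; using epoch ", n_fit, ".")
    return(n_fit)
  }
  epoch
}

# Illustrative (made-up) losses for five fitted epochs
losses <- c(0.90, 0.55, 0.41, 0.44, 0.47)

resolve_epoch(losses)              # 3: smallest loss
resolve_epoch(losses, epoch = 2)   # 2: used as requested
resolve_epoch(losses, epoch = 10)  # 5, with a warning
```

The default favors the lowest-loss epoch rather than the final one, so predictions are not affected by any degradation in the later epochs of a run.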

Value

A tibble of predictions. The number of rows in the tibble is guaranteed to be the same as the number of rows in new_data.
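For a classification model, `type = "class"` and `type = "prob"` return different columns, and each result has one row per row of `new_data`. The following is a sketch, assuming the modeldata `two_class_dat` data (numeric predictors `A` and `B`, factor outcome `Class`) and a working torch installation; the small number of epochs only keeps it fast, and the column names in the comments follow the usual tidymodels `.pred_*` convention.

```r
# Classification sketch: needs brulee, modeldata, and the torch backend
if (requireNamespace("brulee", quietly = TRUE) &&
    requireNamespace("modeldata", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(brulee)
  data(two_class_dat, package = "modeldata")

  set.seed(3)
  cls_fit <- brulee_mlp(Class ~ A + B, data = two_class_dat, epochs = 20)

  # Only the predictor columns are needed at prediction time
  new_obs <- two_class_dat[1:5, c("A", "B")]

  # Hard class predictions: a single factor column (.pred_class)
  class_pred <- predict(cls_fit, new_obs, type = "class")

  # Class probabilities: one .pred_<level> column per outcome level
  prob_pred <- predict(cls_fit, new_obs, type = "prob")

  # Use the parameters from a specific epoch instead of the lowest-loss one
  early_pred <- predict(cls_fit, new_obs, type = "prob", epoch = 5)

  # Each tibble has exactly one row per row of new_data
  stopifnot(nrow(class_pred) == nrow(new_obs), nrow(prob_pred) == nrow(new_obs))
}
```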

Examples

# \donttest{
if (torch::torch_is_installed()) {
 # regression example:

 data(ames, package = "modeldata")

 ames$Sale_Price <- log10(ames$Sale_Price)

 set.seed(1)
 in_train <- sample(1:nrow(ames), 2000)
 ames_train <- ames[ in_train,]
 ames_test  <- ames[-in_train,]

 # Using recipe
 library(recipes)

 ames_rec <-
  recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) %>%
    step_normalize(all_numeric_predictors())

 set.seed(2)
 fit <- brulee_mlp(ames_rec, data = ames_train, epochs = 50)

 predict(fit, ames_test)
}
#> # A tibble: 930 × 1
#>    .pred
#>    <dbl>
#>  1  5.16
#>  2  5.28
#>  3  5.28
#>  4  5.22
#>  5  5.27
#>  6  5.25
#>  7  5.20
#>  8  5.20
#>  9  5.22
#> 10  5.21
#> # … with 920 more rows
# }