Predict from a brulee_rln
Usage
# S3 method for class 'brulee_rln'
predict(object, new_data, type = NULL, epoch = NULL, ...)
Arguments
- object
A brulee_rln object.
- new_data
A data frame or matrix of new predictors.
- type
A single character. The only valid option is "numeric" for numeric predictions.
- epoch
An integer for the epoch at which to make predictions. If it is larger than the number of epochs fit, a warning is issued and the last epoch is used. If NULL (default), the epoch with the smallest loss is used.
- ...
Not used, but required for extensibility.
Examples
# \donttest{
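if (torch::torch_is_installed() && rlang::is_installed(c("recipes", "modeldata"))) {
  # Load the Ames housing data and model the sale price on the log10 scale
  data(ames, package = "modeldata")
  ames$Sale_Price <- log10(ames$Sale_Price)

  # Hold out everything except 2000 randomly chosen rows for testing
  set.seed(1)
  in_train <- sample(1:nrow(ames), 2000)
  ames_train <- ames[ in_train, ]
  ames_test  <- ames[-in_train, ]

  # Center and scale the two numeric predictors before fitting
  library(recipes)
  ames_rec <-
    recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |>
    step_normalize(all_numeric_predictors())

  set.seed(2)
  fit <- brulee_rln(ames_rec, data = ames_train, hidden_units = 20L, epochs = 30L)

  # With epoch = NULL, predictions come from the epoch with the smallest loss
  predict(fit, ames_test)
}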
#> # A tibble: 930 × 1
#> .pred
#> <dbl>
#> 1 5.19
#> 2 5.36
#> 3 5.36
#> 4 5.27
#> 5 5.29
#> 6 5.24
#> 7 5.22
#> 8 5.22
#> 9 5.24
#> 10 5.21
#> # ℹ 920 more rows
# }
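
The epoch argument described above can be illustrated with a short sketch. It repeats the setup from the example so that it runs on its own, then compares predictions from the default epoch (the one with the smallest loss) with predictions from an explicitly requested epoch. The choice of epoch 10 and the helper rmse() are illustrative assumptions, not part of the package. If training had stopped before epoch 10, the documentation above indicates that a warning is issued and the last epoch is used.

```r
# A minimal sketch, assuming torch, brulee, recipes, and modeldata are installed.
if (torch::torch_is_installed() && rlang::is_installed(c("recipes", "modeldata"))) {
  library(brulee)
  library(recipes)

  # Same data preparation as in the example above
  data(ames, package = "modeldata")
  ames$Sale_Price <- log10(ames$Sale_Price)

  set.seed(1)
  in_train <- sample(1:nrow(ames), 2000)
  ames_train <- ames[ in_train, ]
  ames_test  <- ames[-in_train, ]

  ames_rec <-
    recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |>
    step_normalize(all_numeric_predictors())

  set.seed(2)
  fit <- brulee_rln(ames_rec, data = ames_train, hidden_units = 20L, epochs = 30L)

  # Default: epoch = NULL, so the epoch with the smallest loss is used
  pred_best <- predict(fit, ames_test)

  # Explicit epoch (10 is an arbitrary choice for illustration)
  pred_ep10 <- predict(fit, ames_test, type = "numeric", epoch = 10L)

  # Hypothetical helper: root mean squared error on the held-out rows
  rmse <- function(pred) sqrt(mean((ames_test$Sale_Price - pred$.pred)^2))

  # Compare holdout error for the two sets of predictions
  print(c(best_epoch = rmse(pred_best), epoch_10 = rmse(pred_ep10)))
}
```

Both calls return a tibble with a single .pred column and one row per row of new_data, so the two results can be compared row by row.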
