brulee_saint() fits the SAINT (Self-Attention and Inter-sample Attention
Transformer) model from Somepalli et al (2021). SAINT applies multi-head
self-attention across both features (column attention) and samples within a
batch (row/inter-sample attention) to learn complex feature interactions.
Usage
brulee_saint(x, ...)
# Default S3 method
brulee_saint(x, ...)
# S3 method for class 'data.frame'
brulee_saint(
x,
y,
epochs = 100L,
num_embedding = 32L,
attention_type = "both",
num_attn_heads = 8L,
num_attn_blocks = 6L,
dropout_attn = 0.1,
dropout_hidden = 0.1,
dropout_last = 0,
row_attention_on_predict = TRUE,
hidden_units = 5,
hidden_activations = "relu",
penalty = 0.001,
mixture = 0,
validation = 0.1,
optimizer = "ADAMw",
learn_rate = 1e-04,
rate_schedule = "none",
momentum = 0,
batch_size = NULL,
class_weights = NULL,
stop_iter = 5,
verbose = FALSE,
device = NULL,
use_target_token = TRUE,
...
)
# S3 method for class 'matrix'
brulee_saint(
x,
y,
epochs = 100L,
num_embedding = 32L,
attention_type = "both",
num_attn_heads = 8L,
num_attn_blocks = 6L,
dropout_attn = 0.1,
dropout_hidden = 0.1,
dropout_last = 0,
row_attention_on_predict = TRUE,
hidden_units = 5,
hidden_activations = "relu",
penalty = 0.001,
mixture = 0,
validation = 0.1,
optimizer = "ADAMw",
learn_rate = 1e-04,
rate_schedule = "none",
momentum = 0,
batch_size = NULL,
class_weights = NULL,
stop_iter = 5,
verbose = FALSE,
device = NULL,
use_target_token = TRUE,
...
)
# S3 method for class 'formula'
brulee_saint(
formula,
data,
epochs = 100L,
num_embedding = 32L,
attention_type = "both",
num_attn_heads = 8L,
num_attn_blocks = 6L,
dropout_attn = 0.1,
dropout_hidden = 0.1,
dropout_last = 0,
row_attention_on_predict = TRUE,
hidden_units = 5,
hidden_activations = "relu",
penalty = 0.001,
mixture = 0,
validation = 0.1,
optimizer = "ADAMw",
learn_rate = 1e-04,
rate_schedule = "none",
momentum = 0,
batch_size = NULL,
class_weights = NULL,
stop_iter = 5,
verbose = FALSE,
device = NULL,
use_target_token = TRUE,
...
)
# S3 method for class 'recipe'
brulee_saint(
x,
data,
epochs = 100L,
num_embedding = 32L,
attention_type = "both",
num_attn_heads = 8L,
num_attn_blocks = 6L,
dropout_attn = 0.1,
dropout_hidden = 0.1,
dropout_last = 0,
row_attention_on_predict = TRUE,
hidden_units = 5,
hidden_activations = "relu",
penalty = 0.001,
mixture = 0,
validation = 0.1,
optimizer = "ADAMw",
learn_rate = 1e-04,
rate_schedule = "none",
momentum = 0,
batch_size = NULL,
class_weights = NULL,
stop_iter = 5,
verbose = FALSE,
device = NULL,
use_target_token = TRUE,
...
)Arguments
- x
Depending on the context:
A data frame of predictors.
A matrix of predictors.
A recipe specifying a set of preprocessing steps created from
recipes::recipe().
The predictor data should be standardized (e.g. centered or scaled).
- ...
Options to pass to the learning rate schedulers via
set_learn_rate(). For example, thereductionorstepsarguments toschedule_step()could be passed here.- y
When
xis a data frame or matrix,yis the outcome specified as:A data frame with 1 column (numeric or factor).
A matrix with numeric column (numeric or factor).
A vector (numeric or factor).
- epochs
An integer for the number of epochs of training.
- num_embedding
An integer for the dimension of the initial embedding layer that encodes the original predictors. Each feature (categorical or continuous) is mapped to a vector of this dimension. Must be >= 1.
- attention_type
A character string for the type of attention to use. Options are:
"column": Column attention only (attends across features). This is the SAINT-s variant."row": Row/inter-sample attention only (attends across samples within a batch). This is the SAINT-i variant."both": Alternates between column and row attention in each transformer block. This is the full SAINT model.
- num_attn_heads
An integer for the number of parallel attention heads used in both column and row attention. Must be >= 1.
- num_attn_blocks
An integer for the number of sequential transformer blocks (depth). Must be >= 1.
- dropout_attn
A number in
[0, 1)for the dropout rate applied to attention weights during training.A number in
[0, 1)for the dropout rate applied within the feed-forward layers of each transformer block.- dropout_last
A number in
[0, 1)for the dropout rate applied between the last hidden layer and the output head. Only has effect whenhidden_unitsis notNULL. Default is 0 (no dropout).- row_attention_on_predict
A logical value. Should row (inter-sample) attention be applied during prediction? Default is
TRUE, matching the training-time behavior. WhenFALSE, row attention is bypassed at predict time so that predictions for a given row do not depend on what other rows are in the prediction set; column attention is used on its own. This is only relevant whenattention_typeis"row"or"both".An integer vector for the number of units in optional hidden layers between the transformer backbone and the output head. When
NULL(the default), no hidden layers are added and the pooled transformer output is projected directly to the output.A character vector of activation functions for the hidden layers. Must be the same length as
hidden_unitsor a single value that will be recycled. Seebrulee_activations()for options.- penalty
The amount of weight decay (i.e., L2 regularization).
- mixture
Proportion of Lasso Penalty (type: double, default: 0.0). A value of mixture = 1 corresponds to a pure lasso model, while mixture = 0 indicates ridge regression (a.k.a weight decay). Must be zero for optimizers
"ADAMw","RMSprop","Adadelta".- validation
The proportion of the data randomly assigned to a validation set.
- optimizer
The method used in the optimization procedure. Possible choices are
"SGD","ADAMw","Adadelta","Adagrad","RMSprop", and"LBFGS"."LBFGS"is the only second-order method and does not use batches.- learn_rate
A positive number that controls the initial rapidity that the model moves along the descent path. Values around 0.1 or less are typical.
- rate_schedule
A single character value for how the learning rate should change as the optimization proceeds. Possible values are
"none"(the default),"decay_time","decay_expo","cyclic"and"step". Seeschedule_decay_time()for more details.- momentum
A positive number usually on
[0.50, 0.99]for the momentum parameter in gradient descent. (optimizers"SGD", and"RMSprop"only, ignored otherwise).- batch_size
An integer for the number of training set points in each batch. (
optimizer != "LBFGS"only, ignored otherwise)- class_weights
Numeric class weights (classification only). The value can be:
A named numeric vector (in any order) where the names are the outcome factor levels.
An unnamed numeric vector assumed to be in the same order as the outcome factor levels.
A single numeric value for the least frequent class in the training data and all other classes receive a weight of one.
- stop_iter
A non-negative integer for how many iterations with no improvement before stopping.
- verbose
A logical that prints out the iteration history.
- device
A single character string for the device to train on (e.g.,
"cpu"or"cuda"for GPU). IfNULL, the function will use the GPU if available, otherwise CPU. See training_efficiency.- use_target_token
A logical value. When
TRUE(the default), a learnable target token ([CLS]in the SAINT paper) is prepended to each sample's feature sequence and only its final-layer embedding is fed to the head. This matches the architecture described in the SAINT paper (Section 3 and Figure 1); see the Target Token Pooling section in Details. WhenFALSE, the head instead consumes the concatenation of every feature token, which matches the SAINT reference implementation at https://github.com/somepago/saint.- formula
A formula specifying the outcome term(s) on the left-hand side, and the predictor term(s) on the right-hand side.
- data
When a recipe or formula is used,
datais specified as:A data frame containing both the predictors and the outcome.
Value
A brulee_saint object with elements:
models_obj: a serialized raw vector for the torch module.estimates: a list of matrices with the model parameter estimates per epoch.best_epoch: an integer for the epoch with the smallest loss.loss: A vector of loss values (MSE for regression, negative log- likelihood for classification) at each epoch.dim: A list of data dimensions and feature metadata.y_stats: A list of summary statistics for numeric outcomes.parameters: A list of some tuning parameter values.device: A character string for the device used during training.blueprint: Thehardhatblueprint data.
Details
Architecture
The SAINT architecture has three stages:
Embedding layer: Categorical features are mapped through per-feature embedding tables. Continuous features are passed through per-feature MLPs (1 -> 100 ->
num_embedding). These initial embeddings are per-feature; there is a distinct embedding MLP for each predictor. Also, see the "Target Token Pooling" section below.Transformer backbone: A stack of
num_attn_blockstransformer layers. Each layer contains multi-head self-attention followed by a feed-forward network with GeGLU activation. Forattention_type = "both", each block alternates between column attention (across features) and row attention (across samples within the batch).Output head: Pools the transformer output (either the target token's embedding or the flattened concatenation of all feature embeddings, controlled by
use_target_token) and projects it through optional hidden layers to the output dimension.
There is a summary() methods that can provide details of the architecture
for a specific model fit.
Differences in this implementation and the orignal paper:
Pretraining isn't supported.
Attention Types
Column attention (
"column"): Standard self-attention over features. Each feature embedding attends to all other feature embeddings.Row attention (
"row"): inter-sample attention. Reshapes the batch so that each sample's full feature representation becomes a single token, then applies attention across all samples in the batch.Both (
"both"): Alternates between column and row attention in each transformer block. This is the full SAINT model.
Target Token Pooling
Borrowing from BERT, SAINT prepends a learnable target token (the
paper calls it [CLS]) to each sample's feature sequence before the
transformer. With embeddings E(x_i^{(1)}), ..., E(x_i^{(n)}) for the
n predictors of sample i, the input sequence becomes
[target, E(x_i^{(1)}), E(x_i^{(2)}), ..., E(x_i^{(n)})]
giving n + 1 tokens of dimension num_embedding. The target token has
no input value; it is a free parameter of the model that is trained
alongside the rest of the network. Column attention lets every feature
token attend to the target and vice versa, so the target slot accumulates
a contextual summary of the sample. When attention_type is "row" or
"both", inter-sample attention sees the full n + 1 token sequence per
sample, so the target slot also exchanges information across samples in
the batch.
After the transformer backbone, the head reads only the final-layer
embedding of the target token (the first position) and feeds it through
the optional hidden_units MLP and the output layer. This is what the
paper describes in Figure 1: "We take the contextual embeddings from
SAINT and pass only the embedding correspond to the CLS token through an
MLP to obtain the final prediction."
With use_target_token = FALSE, no target token is added and the head
instead consumes the concatenation of all n feature tokens. That
option is provided because the SAINT reference Python implementation
(https://github.com/somepago/saint) departs from the paper and uses
flatten-pooling; it is kept available for compatibility with that code
path and for users who want the original brulee behavior.
Row Attention at Prediction Time
Row attention computations adjust the internal embeddings based on the rows
that are available at any given time. During training, the other rows in the
batch are used to compute attention. After training, when predict() is
called, the default behavior is to keep row attention on, mirroring the
training-time computation. Because row attention is computed across the
samples present in a given call, predictions for a row depend on what
other rows are passed alongside it. To get batch-independent predictions
(where the prediction for a given row is the same regardless of what
other rows are in the input), set row_attention_on_predict to FALSE;
row attention is then bypassed at predict time and column attention is
used on its own.
Learning Rates
The learning rate can be set to constant (the default) or dynamically set
via a learning rate scheduler (via the rate_schedule). Using
rate_schedule = 'none' uses the learn_rate argument.
Other Notes
Unlike other brulee models, brulee_saint() natively handles factor
predictors via learned embeddings. Factor columns are automatically detected
and embedded, while numeric columns pass through per-feature MLPs. There is
no need to pre-encode factors as indicators.
When the outcome is a number, the function internally standardizes the outcome data to have mean zero and a standard deviation of one. The prediction function creates predictions on the original scale.
By default, training halts when the validation loss increases for at least
stop_iter iterations. If validation = 0 the training set loss is used.
The model objects are saved for each epoch so that the number of epochs can
be efficiently tuned. The predict() method for this model has an
epoch argument (which defaults to the epoch with the best loss value).
References
Somepalli, G., Goldblum, M., Schwarzschild, A., Bruss, C. B., & Goldstein, T. (2021). SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training. arXiv preprint arXiv:2106.01342.
Examples
# \donttest{
pkgs <- c("recipes", "yardstick", "modeldata")
if (torch::torch_is_installed() & rlang::is_installed(pkgs)) {
set.seed(87261)
tr_data <- modeldata::sim_regression(500, method = "worley_1987")
te_data <- modeldata::sim_regression(50, method = "worley_1987")
library(recipes)
rec <- recipe(outcome ~ ., data = te_data) |>
step_normalize(all_numeric_predictors())
set.seed(389)
fit <- brulee_saint(
rec,
data = te_data,
hidden_unit = 5,
dropout_hidden = 0.2,
num_embedding = 3,
num_attn_heads = 5,
num_attn_blocks = 4,
dropout_attn = 0.2,
epochs = 50L,
batch_size = 32L,
learn_rate = 0.01,
optimize = "SGD",
verbose = TRUE
)
autoplot(fit)
summary(fit)
library(yardstick)
predict(fit, te_data) |>
dplyr::bind_cols(te_data) |>
rsq(outcome, .pred)
}
#> epoch: 00, learn rate: 0.01000, Loss (scaled): 0.311
#> epoch: 01, learn rate: 0.01000, Loss (scaled): 0.295
#> epoch: 02, learn rate: 0.01000, Loss (scaled): 0.288
#> epoch: 03, learn rate: 0.01000, Loss (scaled): 0.289
#> epoch: 04, learn rate: 0.01000, Loss (scaled): 0.298
#> epoch: 05, learn rate: 0.01000, Loss (scaled): 0.315
#> epoch: 06, learn rate: 0.01000, Loss (scaled): 0.289
#> epoch: 07, learn rate: 0.01000, Loss (scaled): 0.305
#> SAINT architecture
#> inputs: 8 (0 categorical, 8 numeric) | output dim: 1
#> attention: both | embedding dim: 3 | target token: TRUE
#>
#> Embedding layer
#> Target token (1 x 3) 3 params
#> 8 x MLP(1 -> 100 -> 3) 4,024 params
#>
#> Transformer backbone (4 blocks, column + row attention)
#> Block 1:
#> Column attention:
#> LayerNorm(3) 6 params
#> Attention(dim=3, heads=5) 963 params
#> LayerNorm(3) 6 params
#> FeedForward(3, GEGLU) 135 params
#> Row attention:
#> LayerNorm(27) 54 params
#> Attention(dim=27, heads=5) 34,587 params
#> LayerNorm(27) 54 params
#> FeedForward(27, GEGLU) 8,991 params
#> Block 2:
#> Column attention:
#> LayerNorm(3) 6 params
#> Attention(dim=3, heads=5) 963 params
#> LayerNorm(3) 6 params
#> FeedForward(3, GEGLU) 135 params
#> Row attention:
#> LayerNorm(27) 54 params
#> Attention(dim=27, heads=5) 34,587 params
#> LayerNorm(27) 54 params
#> FeedForward(27, GEGLU) 8,991 params
#> Block 3:
#> Column attention:
#> LayerNorm(3) 6 params
#> Attention(dim=3, heads=5) 963 params
#> LayerNorm(3) 6 params
#> FeedForward(3, GEGLU) 135 params
#> Row attention:
#> LayerNorm(27) 54 params
#> Attention(dim=27, heads=5) 34,587 params
#> LayerNorm(27) 54 params
#> FeedForward(27, GEGLU) 8,991 params
#> Block 4:
#> Column attention:
#> LayerNorm(3) 6 params
#> Attention(dim=3, heads=5) 963 params
#> LayerNorm(3) 6 params
#> FeedForward(3, GEGLU) 135 params
#> Row attention:
#> LayerNorm(27) 54 params
#> Attention(dim=27, heads=5) 34,587 params
#> LayerNorm(27) 54 params
#> FeedForward(27, GEGLU) 8,991 params
#>
#> Hidden layers
#> Linear(3 -> 5) 20 params
#> ReLU 0 params
#>
#> Output head
#> Linear(5 -> 1) 6 params
#>
#> Total parameters: 183,237
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 rsq standard 0.00145
# }
