Title: | Explainable Boosting Machines |
---|---|
Description: | An interface to the 'Python' 'InterpretML' framework for fitting explainable boosting machines (EBMs); see Nori et al. (2019) <doi:10.48550/arXiv.1909.09223> for details. EBMs are a modern type of generalized additive model that use tree-based, cyclic gradient boosting with automatic interaction detection. They are often as accurate as state-of-the-art blackbox models while remaining completely interpretable. |
Authors: | Brandon M. Greenwell [aut, cre] |
Maintainer: | Brandon M. Greenwell <[email protected]> |
License: | MIT + file LICENSE |
Version: | 0.1.0 |
Built: | 2025-03-06 06:20:54 UTC |
Source: | https://github.com/bgreenwell/ebm |
If possible, coerces its input to an ebm object.
as.ebm(x, ...)
x |
An object that inherits from an appropriate class (e.g., a fitted EBM from the Python interpret library). |
... |
Additional optional arguments. (Currently ignored.) |
An ebm object that can be used with the associated methods, such as plot() and predict().
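A minimal sketch of one intended use, assuming as.ebm() accepts EBMs fit directly in Python through reticulate (an assumption based on the description above):

## Not run: 
library(reticulate)

# Fit an ExplainableBoostingRegressor directly in Python via reticulate
glassbox <- import("interpret.glassbox")
pyfit <- glassbox$ExplainableBoostingRegressor()
pyfit$fit(subset(mtcars, select = -mpg), mtcars$mpg)

# Coerce to an "ebm" object so the associated R methods can be used
fit <- as.ebm(pyfit)
plot(fit, term = "cyl")
## End(Not run)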
This function is an R wrapper for the explainable boosting functions in the Python interpret library. It trains an Explainable Boosting Machine (EBM) model, which is a tree-based, cyclic gradient boosting generalized additive model with automatic interaction detection. EBMs are often as accurate as state-of-the-art blackbox models while remaining completely interpretable.
ebm(
  formula,
  data,
  max_bins = 1024L,
  max_interaction_bins = 64L,
  interactions = 0.9,
  exclude = NULL,
  validation_size = 0.15,
  outer_bags = 16L,
  inner_bags = 0L,
  learning_rate = 0.04,
  greedy_ratio = 10,
  cyclic_progress = FALSE,
  smoothing_rounds = 500L,
  interaction_smoothing_rounds = 100L,
  max_rounds = 25000L,
  early_stopping_rounds = 100L,
  early_stopping_tolerance = 1e-05,
  min_samples_leaf = 4L,
  min_hessian = 0,
  reg_alpha = 0,
  reg_lambda = 0,
  max_delta_step = 0,
  gain_scale = 5,
  min_cat_samples = 10L,
  cat_smooth = 10,
  missing = "separate",
  max_leaves = 2L,
  monotone_constraints = NULL,
  objective = c("auto", "log_loss", "rmse", "poisson_deviance",
    "tweedie_deviance:variance_power=1.5", "gamma_deviance",
    "pseudo_huber:delta=1.0", "rmse_log"),
  n_jobs = -1L,
  random_state = 42L,
  ...
)
formula |
A formula of the form y ~ x1 + x2 + ... specifying the response and predictor variables. |
data |
A data frame containing the variables in the model. |
max_bins |
Max number of bins per feature for the main effects stage. Default is 1024. |
max_interaction_bins |
Max number of bins per feature for interaction terms. Default is 64. |
interactions |
Interaction terms to be included in the model. Current options include an integer (the number of pairwise interactions to automatically detect and include), a fraction between 0 and 1 (the number of interactions as a proportion of the number of features), or a list of specific feature pairs. Default is 0.9. |
exclude |
Features or terms to be excluded. Default is NULL. |
validation_size |
Validation set size, used for early stopping during boosting and needed to create outer bags. Either an integer (number of samples to set aside) or a fraction between 0 and 1 (proportion of samples). Default is 0.15. |
outer_bags |
Number of outer bags. Outer bags are used to generate error bounds and help smooth the graphs. Default is 16. |
inner_bags |
Number of inner bags. Default is 0, which turns off inner bagging. |
learning_rate |
Learning rate for boosting. Default is 0.04. |
greedy_ratio |
The proportion of greedy boosting steps relative to cyclic boosting steps. A value of 0 disables greedy boosting. Default is 10. |
cyclic_progress |
This parameter specifies the proportion of the boosting cycles that will actively contribute to improving the model's performance. It is expressed as a logical or a numeric value between 0 and 1. Default is FALSE. |
smoothing_rounds |
Number of initial highly regularized rounds to set the basic shape of the main effect feature graphs. Default is 500. |
interaction_smoothing_rounds |
Number of initial highly regularized rounds to set the basic shape of the interaction effect feature graphs during fitting. Default is 100. |
max_rounds |
Total number of boosting rounds, with one boosting step per term in each round. Default is 25000. |
early_stopping_rounds |
Number of rounds with no improvement needed to trigger early stopping. A value of 0 turns off early stopping, in which case boosting occurs for exactly max_rounds rounds. Default is 100. |
early_stopping_tolerance |
Tolerance that dictates the smallest delta required for a round to be considered an improvement (which prevents the algorithm from stopping early on negligible gains). Default is 1e-05. |
min_samples_leaf |
Minimum number of samples allowed in the leaves. Default is 4. |
min_hessian |
Minimum hessian required to consider a potential split valid. Default is 0.0. |
reg_alpha |
L1 regularization. Default is 0.0. |
reg_lambda |
L2 regularization. Default is 0.0. |
max_delta_step |
Used to limit the max output of tree leaves; <=0.0 means no constraint. Default is 0.0. |
gain_scale |
Scale factor to apply to nominal categoricals. A scale factor above 1.0 will cause the algorithm to focus more on the nominal categoricals. Default is 5.0. |
min_cat_samples |
Minimum number of samples required to treat a category separately. Categories below this threshold are combined with other low-count categories. Default is 10. |
cat_smooth |
Used for categorical features. This can reduce the effect of noise in categorical features, especially for categories with limited data. Default is 10.0. |
missing |
Method for handling missing values during boosting. Default is "separate". |
max_leaves |
Maximum number of leaves allowed in each tree. Default is 2. |
monotone_constraints |
Default is NULL. This parameter allows you to specify monotonic constraints for each feature's relationship with the target variable during model fitting. However, it is generally recommended to apply monotonic constraints post fit rather than during training. |
objective |
The objective function to optimize. Current options include "auto", "log_loss", "rmse", "poisson_deviance", "tweedie_deviance:variance_power=1.5", "gamma_deviance", "pseudo_huber:delta=1.0", and "rmse_log". Default is "auto". |
n_jobs |
Number of jobs to run in parallel. Default is -1. Negative integers are interpreted following joblib's formula (n_cpus + 1 + n_jobs). |
random_state |
Random state. Set to NULL for non-repeatable sequences. Default is 42. |
... |
Additional optional arguments. (Currently ignored.) |
In short, EBMs have the general form

g(E[y]) = \beta_0 + \sum_j f_j(x_j) + \sum_{j \neq k} f_{jk}(x_j, x_k),

where g is a link function that allows the model to handle various response types (e.g., the logit link for logistic regression or the log link used with Poisson deviance for modeling counts and rates); \beta_0 is a constant intercept (or bias term); f_j is the term contribution (or shape function) for predictor x_j (i.e., it captures the main effect of x_j on y); and f_{jk} is the term contribution for the pair of predictors x_j and x_k (i.e., it captures their joint effect, or pairwise interaction effect, on y).
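Because the model is additive, each prediction decomposes into the intercept plus the fitted term contributions. The following sketch illustrates this using the predict() method documented later; the exact structure of the type = "terms" output (a matrix with one column per term) is an assumption here.

## Not run: 
fit <- ebm(mpg ~ ., data = mtcars, objective = "rmse")
link <- predict(fit, newdata = mtcars, type = "link")
contrib <- predict(fit, newdata = mtcars, type = "terms")
# If the decomposition holds, this difference is constant across rows and
# equals the intercept (bias) term
head(link - rowSums(contrib))
## End(Not run)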
An object of class "EBM" for which there are print(), predict(), plot(), and merge() methods.
## Not run: 
#
# Regression example
#

# Fit a default EBM regressor
fit <- ebm(mpg ~ ., data = mtcars, objective = "rmse")

# Generate some predictions
head(predict(fit, newdata = mtcars))
head(predict(fit, newdata = mtcars, se_fit = TRUE))

# Show global summary and GAM shape functions
plot(fit)  # term importance scores
plot(fit, term = "cyl")
plot(fit, term = "cyl", interactive = TRUE)

# Explain prediction for first observation
plot(fit, local = TRUE, X = subset(mtcars, select = -mpg)[1L, ])
## End(Not run)
A combination of geom_ribbon() and geom_step().
geom_stepribbon(
  mapping = NULL,
  data = NULL,
  stat = "identity",
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  inherit.aes = TRUE,
  ...
)
mapping |
Set of aesthetic mappings created by aes(). |
data |
The data to be displayed in this layer. There are three options: If NULL, the default, the data is inherited from the plot data as specified in the call to ggplot(). A data.frame, or other object, will override the plot data. A function will be called with a single argument, the plot data. |
stat |
The statistical transformation to use on the data for this layer, either as a string naming the stat (e.g., "identity") or as a ggproto Stat object. |
position |
Position adjustment, either as a string naming the adjustment (e.g., "jitter") or as the result of a call to a position adjustment function. |
na.rm |
If FALSE, the default, missing values are removed with a warning. If TRUE, missing values are silently removed. |
show.legend |
Logical. Should this layer be included in the legends? NA, the default, includes if any aesthetics are mapped; FALSE never includes; TRUE always includes. |
inherit.aes |
If FALSE, overrides the default aesthetics rather than combining with them. |
... |
Other arguments passed on to layer(). |
Taken from ldatools.
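A minimal usage sketch, assuming geom_stepribbon() honors the usual geom_ribbon() aesthetics (x, ymin, ymax):

## Not run: 
library(ggplot2)

# Toy step function with a +/-1 uncertainty band
df <- data.frame(x = 1:10, y = cumsum(rnorm(10)))
ggplot(df, aes(x = x, y = y)) +
  geom_stepribbon(aes(ymin = y - 1, ymax = y + 1), fill = "grey", alpha = 0.5) +
  geom_step()
## End(Not run)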
This function will install interpret along with all of its dependencies.
install_interpret(
  envname = "r-ebm",
  ...,
  extra_packages = c("plotly>=3.8.1"),
  python_version = ">=3.9,<=3.12",
  restart_session = TRUE
)
envname |
Name of or path to a Python virtual environment. |
... |
Additional optional arguments. (Currently ignored.) |
extra_packages |
Additional Python packages to install alongside interpret. |
python_version |
Passed on to virtualenv_starter(). |
restart_session |
Whether to restart the R session after installing (note this will only occur within RStudio). |
No return value, called for side effects.
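Typical one-time setup, using only arguments shown in the usage above:

## Not run: 
# Create the "r-ebm" virtual environment and install interpret along with
# plotly (for interactive graphics)
install_interpret()

# Pin a specific Python version instead (passed on to virtualenv_starter())
install_interpret(python_version = "3.11")
## End(Not run)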
Merge multiple EBMs together.
## S3 method for class 'EBM'
merge(x, y, ...)
x, y |
Fitted ebm objects that have been trained on similar data sets with the same set of features. |
... |
Additional ebm objects to be merged. |
A merged ebm object.
As of right now, the merge() function produces the following error message:
Error in py_repr(x) : AttributeError: 'ExplainableBoostingRegressor' object has no attribute 'cat_smooth' Run `reticulate::py_last_error()` for details.
This seems to be a bug in the underlying interpret library and does not prevent this function from working. The error message is seemingly just a side effect.
## Not run: 
# Generate list of EBMs with different random seeds
ebms <- lapply(1:3, FUN = function(i) {
  ebm(mpg ~ ., data = mtcars, outer_bags = 1, random_state = i, obj = "rmse")
})

# Merge EBMs into one and plot term contribution for `cyl`
merged <- do.call(merge, args = ebms)
plot(merged, term = "cyl")
## End(Not run)
Provides a visualization (static or interactive) for a given explanation or set of explanations.
## S3 method for class 'EBM'
plot(
  x,
  term = NULL,
  local = FALSE,
  X = NULL,
  y = NULL,
  init_score = NULL,
  interactive = FALSE,
  n_terms = NULL,
  geom = c("point", "col", "bar"),
  mapping = NULL,
  aesthetics = list(),
  horizontal = FALSE,
  uncertainty = TRUE,
  width = 0.5,
  alpha = 0.5,
  fill = "grey",
  display = c("viewer", "markdown", "url"),
  viewer = c("browser", "rstudio"),
  full_dashboard = FALSE,
  ...
)
x |
A fitted ebm object. |
term |
Character string specifying which term to plot. For an interaction effect, you can supply the corresponding pair of feature names. Default is NULL. |
local |
Logical indicating whether to display local explanations (TRUE) or global explanations (FALSE). Default is FALSE. |
X |
Data frame or matrix of samples; required for local explanations. |
y |
Optional vector of response values corresponding to X. |
init_score |
Optional. Either a model that can generate scores or a vector of per-sample initialization scores. If per-sample scores are supplied, they should have the same length as the number of rows in X. |
interactive |
Logical indicating whether to produce an interactive plot based on HTML. Default is FALSE. |
n_terms |
Integer specifying the maximum number of variable importance scores to plot. Default is NULL. |
geom |
Character string specifying which type of plot to construct for terms associated with categorical features. Current options are "point", "col", and "bar". Default is "point". |
mapping |
Set of aesthetic mappings created by aes-related functions and/or tidy eval helpers. See example usage below. |
aesthetics |
List specifying additional arguments passed on to the layer. These are often aesthetics, used to set an aesthetic to a fixed value, like colour = "red" or size = 3. |
horizontal |
Logical indicating whether or not term plots for categorical features should be flipped horizontally. Default is FALSE. |
uncertainty |
Logical indicating whether or not to also display uncertainty via error bars on the main effect plots. Default is TRUE. |
width |
Numeric specifying the width of the error bars displayed in bar/dot plots for categorical features. Default is 0.5. |
alpha |
Numeric between 0 and 1 specifying the level of transparency to use when displaying uncertainty in plots for continuous features. Default is 0.5. |
fill |
Character string specifying the fill color to use when displaying uncertainty in plots for continuous features. Default is "grey". |
display |
Character string specifying how the results should be displayed whenever interactive = TRUE. Current choices are "viewer", "markdown", and "url". Default is "viewer". |
viewer |
Character string specifying how the results should be viewed. Current choices are "browser" and "rstudio". Default is "browser". |
full_dashboard |
Logical indicating whether or not to display the full interpret dashboard. Default is FALSE. |
... |
Additional optional arguments. Currently only passed on to levelplot() for heatmaps of interaction effects. |
When interactive = FALSE (the default), the output is either a ggplot object (when visualizing term importance scores or main effects) or a trellis object (when visualizing pairwise interaction effects). When interactive = TRUE, the return value depends on the display argument: when display = "url", a character string is returned giving the URL for the HTML-based visualization; otherwise, the results are viewed as requested (i.e., in a browser, the built-in viewer, or rendered HTML output).
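A short sketch of a few of the non-default options documented above. Note that cyl is numeric in mtcars, so treating it as categorical here is purely illustrative:

## Not run: 
fit <- ebm(mpg ~ ., data = mtcars, objective = "rmse")

# Flipped bar chart with a fixed fill color supplied via `aesthetics`
plot(fit, term = "cyl", geom = "col", horizontal = TRUE,
     aesthetics = list(fill = "steelblue"))

# Full interactive dashboard, viewed in the default browser
plot(fit, interactive = TRUE, full_dashboard = TRUE, viewer = "browser")
## End(Not run)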
Compute predicted values from a fitted explainable boosting machine.
## S3 method for class 'EBM'
predict(
  object,
  newdata,
  type = c("response", "link", "class", "terms"),
  se_fit = FALSE,
  init_score = NULL,
  ...
)
object |
A fitted ebm object. |
newdata |
A data frame in which to look for variables with which to predict. |
type |
The type of prediction required. Current options include "response", "link", "class", and "terms". Default is "response". |
se_fit |
Logical indicating whether or not standard errors are required. Ignored for multiclass outcomes. Note that standard errors are only available on the link scale. |
init_score |
Optional. Either a model that can generate scores or a vector of per-sample initialization scores. If per-sample scores are supplied, they should have the same length as the number of rows in newdata. |
... |
Additional optional arguments. (Currently ignored.) |
Either a vector, matrix, or list of results; see the type argument for details.
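A brief sketch of the non-default prediction types. The exact structure returned when se_fit = TRUE is an assumption, so str() is used to inspect it:

## Not run: 
fit <- ebm(mpg ~ ., data = mtcars, objective = "rmse")

# Per-term contributions on the link scale
head(predict(fit, newdata = mtcars, type = "terms"))

# Standard errors are only available on the link scale
str(predict(fit, newdata = mtcars, se_fit = TRUE))
## End(Not run)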
Display basic information about a fitted ebm object.
## S3 method for class 'EBM'
print(x, digits = max(3L, getOption("digits") - 3L), ...)
x |
A fitted ebm object. |
digits |
The number of significant digits to use when printing. |
... |
Additional optional arguments. |
Invisibly returns the printed ebm object.