Using flashlight

Michael Mayer

2019-10-12

library(flashlight)      # model interpretation
library(MetricsWeighted) # metrics
library(dplyr)           # data prep
library(moderndive)      # data
library(caret)           # data split
library(xgboost)         # gradient boosting
library(ranger)          # random forest

Introduction

In contrast to classic statistical modelling techniques like linear regression, modern machine learning approaches tend to produce black box results. In areas like visual computing or natural language processing, this is not an issue since the focus usually lies on prediction: either the predictions are sufficiently useful in practice or the model won’t be used. However, in areas where the purpose of a statistical model is also to explain or validate underlying theories (e.g. in medicine, economics, and biology), black box models are of little use.

Thus, there is a need to shed light into these black boxes, i.e. to explain machine learning models as well as possible. An excellent reference is the online book of Christoph Molnar [1]. Of special interest are model agnostic approaches that work for any kind of modelling technique, e.g. a linear regression, a neural net or a tree-based method. The only requirement is the availability of a prediction function, i.e. a function that takes a data set and returns predictions.

This is the purpose of the R package flashlight, which is inspired by the beautiful DALEX package (https://CRAN.R-project.org/package=DALEX).

The main properties of flashlight:

  1. It is simple, yet flexible.

  2. It offers model agnostic tools like model performance, variable importance, ICE profiles, partial dependence, further effects plots, and variable contribution breakdown for single observations.

  3. It allows assessing multiple models in parallel.

  4. It supports “group by” operations.

  5. It works with case weights.

Currently, models with numeric or binary response are supported.

We will now give a brief introduction to machine learning explanations and then illustrate them with the flashlight package.

Background

Important model agnostic machine learning explanations include the following aspects, amongst many others.

Model performance

How precise are models if applied to unseen data? This aspect is of key interest for basically any supervised machine learning model and helps to identify the best models or, if applied to subgroups, to identify problematic segments with low performance.

Variable importance

Which variables are particularly relevant for the model? This aspect is helpful in different ways. Firstly, it might help to simplify the full modelling process by eliminating input variables with low explanatory power that are difficult to assess. Secondly, it is informative in its own right. Thirdly, it might help to identify problems in the data structure: if one variable is extremely relevant and all others are not, there might be some sort of information leakage from the response. That said, variable importance considerations are very relevant for quality assurance as well.

Different modelling techniques offer different measures of variable importance. In linear models, we consider e.g. F-test statistics or p values; in tree-based methods, it is the number of splits or average split gains, etc. A model agnostic way to assess this is called permutation importance: for each input variable \(X\), its values are randomly shuffled and the drop in performance with respect to a scoring function is calculated. The more important a variable, the larger the drop. If a variable can be shuffled without any impact on model precision, it is completely irrelevant. The method is described in Fisher et al. 2018 [2].
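
As a concrete illustration, here is a minimal hand-rolled sketch of permutation importance for a single variable; the real implementation in light_importance additionally supports grouping and case weights. A simple linear model on iris serves as example:

# Minimal permutation importance sketch (illustration only)
perm_importance <- function(fit, df, v, y) {
  rmse <- function(a, p) sqrt(mean((a - p)^2))
  base <- rmse(df[[y]], predict(fit, df))  # performance on intact data
  df[[v]] <- sample(df[[v]])               # shuffle one input variable
  rmse(df[[y]], predict(fit, df)) - base   # increase in error = importance
}

fit <- lm(Sepal.Length ~ ., data = iris)
perm_importance(fit, iris, v = "Petal.Length", y = "Sepal.Length")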

Effects of input variables

In linear regression, the fitted model consists of an affine linear function in the inputs. Its coefficients immediately tell us how the response is expected to change if the value of one single input variable \(X\) is systematically being adapted. How to describe such effects of a variable \(X\) for more complex models that include non-linearities and high-order interactions?

One approach is to study Individual Conditional Expectation (ICE) profiles of selected observations: they show how the prediction of observation \(i\) reacts when the input variable \(X\) is systematically being changed, see [3]. The more the profiles differ in shape or slope, the stronger the interaction effects. For a linear regression without interactions, all such profiles would be parallel. ICE profiles centered at one point (“c-ICE” profiles) help to detect interactions even better.

If many ICE profiles are averaged, we get partial dependence profiles, which can be viewed as the average effect of variable \(X\), pooled over all interactions. Partial dependence plots were introduced in Friedman’s seminal 2001 article on gradient boosting [4].
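
Both concepts are easy to emulate by hand. The following sketch (again using a simple linear model on iris for illustration) derives ICE profiles from predictions on modified copies of the data and averages them into a partial dependence curve:

fit <- lm(Sepal.Length ~ ., data = iris)
grid <- seq(0.1, 2.5, by = 0.3)  # evaluation points for Petal.Width

# ICE: predict each observation along the grid
ice <- sapply(grid, function(g) {
  df <- iris
  df$Petal.Width <- g            # set the variable to the grid value for all rows
  predict(fit, df)
})                               # rows = observations, columns = grid points

# Partial dependence: average the ICE profiles
plot(grid, colMeans(ice), type = "b", xlab = "Petal.Width", ylab = "average prediction")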

Studying ICE and partial dependence profiles makes sense as long as it makes sense to investigate the effect of \(X\) while holding all other predictors fixed. The more unnatural such a ceteris paribus assumption is, the less reasonable the results of ICE and partial dependence. Accumulated local effects (ALE) profiles [5] try to overcome this weakness of ICE and partial dependence considerations.

ALE profiles at positions \(x_i\) are calculated as follows (a small sketch follows after the list):

  1. First, left derivatives \(\Delta_i\) are estimated for all \(i\): calculate the slope of the partial dependence line from \(x_i-\varepsilon\) to \(x_i\), based on the data points in \([x_i-\varepsilon, x_i]\).

  2. The uncalibrated ALE value at \(x_i\) is the (integrated/accumulated) partial sum \(\sum_{j \le i} \Delta_j\).

  3. In the calibration step, the effects are shifted to the level of the response variable or its predictions.
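
To make the recipe concrete, here is a rough sketch of steps 1 and 2 for a single numeric variable, with step 3 (calibration) omitted. It simplifies the actual algorithm (equal-width bins, no case weights, loose treatment of bin borders) and again uses a linear model on iris:

fit <- lm(Sepal.Length ~ ., data = iris)
v <- "Petal.Width"
breaks <- seq(min(iris[[v]]), max(iris[[v]]), length.out = 11)  # bin borders

# Step 1: average local slope per bin, using only observations inside the bin
delta <- sapply(seq_len(length(breaks) - 1), function(k) {
  in_bin <- iris[[v]] >= breaks[k] & iris[[v]] <= breaks[k + 1]
  lo <- hi <- iris[in_bin, , drop = FALSE]
  lo[[v]] <- breaks[k]
  hi[[v]] <- breaks[k + 1]
  mean(predict(fit, hi) - predict(fit, lo))
})

# Step 2: accumulate the local effects to an uncalibrated ALE curve
plot(breaks[-1], cumsum(delta), type = "b", xlab = v, ylab = "uncalibrated ALE")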

An alternative to partial dependence and ALE profiles is to look at the combined effect of \(X\), including the effects from all other predictors. Such effects are estimated by averaging the predictions within values of the predictor \(X\) of interest. If such a prediction profile differs considerably from the observed average response, this might be a sign of model underfit. This visualization is sometimes called a “marginal plot” or “M plot”, see [5]. Residual profiles immediately show such misfits as well. In classic statistical modelling, these sorts of plots are called “fitted versus covariable” plots or “residual versus covariable” plots.

Besides looking at average profiles, it is often also revealing to consider quartile profiles or to visualize partial dependence, response and prediction profiles in the same plot.

Variable contribution breakdown for single observations

Instead of studying global variable importance and effects, different techniques exist to disentangle how a single prediction can be decomposed into additive effects of the input variables, e.g. LIME, LIVE, SHAP or breakdown (Gosiewska and Biecek [6]); see [1] and [6] for an overview. The flashlight package currently supports the breakdown method.

It works as follows: First, the visit order \((x_1, ..., x_m)\) of the variables is specified. Then, in the query data, column \(x_1\) is set to the constant value of \(x_1\) in the observation to be explained. The change in the (weighted) average predicted value on the query data measures the contribution of \(x_1\) to the prediction. This procedure is iterated over all \(x_i\) until, eventually, all rows in the query data are identical to the observation to be explained.

A complication with this approach is that the visit order is relevant, at least for non-additive models. Ideally, the algorithm could be repeated for all \(m!\) possible visit orders and the results averaged per variable. This is basically what SHAP values do, see e.g. [6] for an explanation. Unfortunately, there is no efficient way to do this in a model agnostic manner. Thus, we use the short-cut described in [6] and implemented in the iBreakDown package (https://CRAN.R-project.org/package=iBreakDown): there, the variables \(x_i\) are sorted by the size of their contribution in the same way as in the breakdown algorithm, but without iteration, i.e. starting from the original query data for each variable \(x_i\).
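
In code, one step of the breakdown recipe boils down to very little. Here is a minimal sketch for a single variable, reusing a linear model on iris (query data = iris, observation to explain = row 2):

fit <- lm(Sepal.Length ~ ., data = iris)
obs <- iris[2, ]    # observation to explain
query <- iris       # query/reference data

# Contribution of Petal.Width as the first visited variable:
# fix the column at the observation's value and track the average prediction
before <- mean(predict(fit, query))
query$Petal.Width <- obs$Petal.Width
after <- mean(predict(fit, query))
after - before      # contribution of Petal.Width

Iterating this over the remaining variables (keeping the already fixed columns) yields the full breakdown.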

The flashlight package offers these tools in a very simple way.

Installation of flashlight

From CRAN:

install.packages("flashlight")

Latest version from github:

library(devtools)
install_github("mayer79/flashlight")

Teaser

Let’s start with an iris example.

# Fit model
fit <- lm(Sepal.Length ~ ., data = iris)

# Make flashlight
fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols",
                 metrics = list(rmse = rmse, `R-squared` = r_squared))

# Performance: rmse and R-squared
plot(light_performance(fl), fill = "darkred")

plot(light_performance(fl, by = "Species"), fill = "darkred")


# Variable importance by increase in rmse
imp <- light_importance(fl)
plot(imp, fill = "darkred")

plot(light_importance(fl, by = "Species")) +
   scale_fill_viridis_d(begin = 0.2, end = 0.8)

most_important(imp, 2)
#> [1] "Petal.Length" "Species"

# ICE profiles for Petal.Width
plot(light_ice(fl, v = "Petal.Width"))

plot(light_ice(fl, v = "Petal.Width", center = TRUE))

plot(light_ice(fl, v = "Petal.Width", by = "Species"))


# Partial dependence profiles for Petal.Width
plot(light_profile(fl, v = "Petal.Width"))

plot(light_profile(fl, v = "Petal.Width", by = "Species"))


# Accumulated local effects (ALE) profiles for Petal.Width
plot(light_profile(fl, v = "Petal.Width", type = "ale"))

plot(light_profile(fl, v = "Petal.Width", by = "Species", type = "ale"))


# Prediction, response and residual profiles
plot(light_profile(fl, v = "Petal.Width", type = "response", stats = "quartiles"))

plot(light_profile(fl, v = "Petal.Width", type = "predicted"))

plot(light_profile(fl, v = "Petal.Width", type = "residual", stats = "quartiles"))


# Response profiles, prediction profiles, partial dependence, and ALE profiles in one
plot(light_effects(fl, v = "Petal.Width"), use = "all")


# Variable contribution breakdown for single observation
plot(light_breakdown(fl, new_obs = iris[2, ]))

flashlights and multiflashlights

The process of using the flashlight package is as follows:

  1. Define a flashlight for each model. This is basically a list with optional components relevant for model interpretation:

    • model: The fitted model object like e.g. the one returned by lm.

    • data: A data set used to evaluate model agnostic tools, e.g. the validation data.

    • y: The name of the variable in data representing the model response.

    • predict_function: A function taking model and data and returning numeric predictions.

    • linkinv: Inverse link function used to retransform the values returned by predict_function. Defaults to the identity, function(z) z.

    • w: The name of the variable in data representing the case weights.

    • by: A character vector of names of grouping variables in data. These will be used to stratify all results.

    • metrics: A named list of metrics. These functions need to be available in the workspace and require the arguments actual, predicted, w (case weights) as well as a placeholder … for further arguments. All metrics available in the R package MetricsWeighted are suitable; an own metric is sketched at the end of this section.

    • label: The label of the model. This is the only required input when building the flashlight.

  2. Calculate relevant information by calling the key functions:

    • light_performance: Calculates performance measures regarding different metrics, possibly within subgroups and weighted by case weights.

    • light_importance: Calculates variable importance (worsening in performance by random shuffling) for all variables or a subset of them, possibly within subgroups and using case weights. The names of the most important variables can be extracted by calling most_important on the result of light_importance.

    • light_ice: Calculates ICE profiles across a couple of observations, possibly within groups.

    • light_profile: Calculates partial dependence profiles across a covariable, possibly within groups, by calling light_ice and aggregating the results. The function is flexible: it can also be used to generate ALE, response, residual or prediction profiles, or to calculate (weighted) quartiles instead of (weighted) means.

    • light_effects: Combines partial dependence, response and prediction profiles. ALE profiles can be added as well.

    • light_breakdown: Calculates variable contribution breakdown for a single observation.

  3. Plot the results: Each of these functions offers a plot method with minimal visualization of the results through ggplot2. The resulting plot can be customized by adding theme and other ggplot elements. If this customization is insufficient, you can extract the data slot of the object returned by the above key functions and build your own plot.

In practice, multiple flashlights are defined and evaluated in parallel. With the help of a multiflashlight object, the flashlight package provides as much support as possible to avoid redundancy. It can be used to combine fully specified flashlights or, and this is the more interesting option, to take minimally defined flashlights (e.g. only label, model and predict_function) and add common arguments like y, by, data and/or w (case weights) in the call to multiflashlight. If necessary, the completed flashlights contained in the multiflashlight can be extracted again by $.

All key functions are defined for both flashlight and multiflashlight objects.
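
As an aside, writing an own metric for the metrics slot is straightforward; the required signature was described above. A minimal sketch (a hand-rolled weighted mean absolute error, although MetricsWeighted already ships mae):

# Custom metric following the flashlight convention: actual, predicted, w, ...
my_mae <- function(actual, predicted, w = NULL, ...) {
  if (is.null(w)) w <- rep(1, length(actual))
  weighted.mean(abs(actual - predicted), w)
}

fl_custom <- flashlight(model = lm(Sepal.Length ~ ., data = iris), data = iris,
                        y = "Sepal.Length", label = "ols",
                        metrics = list(mae = my_mae))
plot(light_performance(fl_custom))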

Example

As an illustration, we use the data set house_prices with information on 21,613 houses sold in King County between May 2014 and May 2015. It is shipped with the R package moderndive.

The first few observations look as follows:

head(house_prices)
#> # A tibble: 6 x 21
#>   id    date        price bedrooms bathrooms sqft_living sqft_lot floors
#>   <chr> <date>      <dbl>    <int>     <dbl>       <int>    <int>  <dbl>
#> 1 7129~ 2014-10-13 2.22e5        3      1           1180     5650      1
#> 2 6414~ 2014-12-09 5.38e5        3      2.25        2570     7242      2
#> 3 5631~ 2015-02-25 1.80e5        2      1            770    10000      1
#> 4 2487~ 2014-12-09 6.04e5        4      3           1960     5000      1
#> 5 1954~ 2015-02-18 5.10e5        3      2           1680     8080      1
#> 6 7237~ 2014-05-12 1.23e6        4      4.5         5420   101930      1
#> # ... with 13 more variables: waterfront <lgl>, view <int>,
#> #   condition <fct>, grade <fct>, sqft_above <int>, sqft_basement <int>,
#> #   yr_built <int>, yr_renovated <int>, zipcode <fct>, lat <dbl>,
#> #   long <dbl>, sqft_living15 <int>, sqft_lot15 <int>

Thus, we have access to a lot of relevant information like size, condition, and location of the objects. We want to use these variables to predict the (log) house prices by linear regression, a random forest, and gradient boosting, and to shed some light on these models.

We use 70% of the data to fit the models and 20% to evaluate their performance and to explain them. The remaining 10% we keep untouched.

Data preparation

Let’s do some data preparation common for all models under consideration.

prep <- transform(house_prices, 
                  log_price = log(price),
                  grade = as.integer(as.character(grade)),
                  year = factor(lubridate::year(date)),
                  age = lubridate::year(date) - yr_built,
                  zipcode = as.numeric(as.character(zipcode)),
                  waterfront = factor(waterfront, levels = c(FALSE, TRUE), labels = c("no", "yes")))

x <- c("grade", "year", "age", "sqft_living", "sqft_lot", "zipcode", 
       "condition", "waterfront")

Modelling

The random forest can work with this data structure directly. For the linear model, however, we need a small function with additional feature engineering, i.e. log transforming some inputs and grouping the zipcodes into larger categories. Similarly, for XGBoost, such a wrapper function turns non-numeric input variables into numeric ones. We will make use of these functions for both data preparation and prediction.

# Data wrapper for the linear model:
# log transform skewed size variables and group zipcodes
prep_lm <- function(data) {
  data %>% 
    mutate(sqft_living = log(sqft_living),
           sqft_lot = log(sqft_lot),
           zipcode = factor(zipcode %/% 10))
}

# Data wrapper for xgboost
prep_xgb <- function(data, x) {
  data %>% 
    select_at(x) %>% 
    mutate_if(Negate(is.numeric), as.integer) %>% 
    data.matrix()
}

Then, we split the data and train our models.

# Train / valid / test split (70% / 20% / 10%)
set.seed(56745)
ind <- caret::createFolds(prep[["log_price"]], k = 10, list = FALSE)

train <- prep[ind >= 4, ]
valid <- prep[ind %in% 2:3, ]
test <- prep[ind == 1, ]

(form <- reformulate(x, "log_price"))
#> log_price ~ grade + year + age + sqft_living + sqft_lot + zipcode + 
#>     condition + waterfront
fit_lm <- lm(update.formula(form, . ~ . + I(sqft_living^2)), data = prep_lm(train))

# Random forest
fit_rf <- ranger(form, data = train, seed = 8373)
cat("R-squared OOB:", fit_rf$r.squared)
#> R-squared OOB: 0.7842605

# Gradient boosting
dtrain <- xgb.DMatrix(prep_xgb(train, x), label = train[["log_price"]])
dvalid <- xgb.DMatrix(prep_xgb(valid, x), label = valid[["log_price"]])

params <- list(learning_rate = 0.5,
               max_depth = 6,
               alpha = 1,
               lambda = 1,
               colsample_bytree = 0.8)

fit_xgb <- xgb.train(params, 
                     data = dtrain,
                     watchlist = list(train = dtrain, valid = dvalid),
                     nrounds = 200, 
                     print_every_n = 100,
                     objective = "reg:linear",
                     seed = 2698)
#> [1]  train-rmse:6.291441 valid-rmse:6.284242 
#> [101]    train-rmse:0.135278 valid-rmse:0.184681 
#> [200]    train-rmse:0.114447 valid-rmse:0.187792

Creating the flashlights

Let’s initialize one flashlight per model. Thanks to individual prediction functions, any model can be used in flashlight, even h2o and keras models.

fl_mean <- flashlight(model = mean(train$log_price), label = "mean", 
                      predict_function = function(mod, X) rep(mod, nrow(X)))
fl_lm <- flashlight(model = fit_lm, label = "lm", 
                    predict_function = function(mod, X) predict(mod, prep_lm(X)))
fl_rf <- flashlight(model = fit_rf, label = "rf",
                    predict_function = function(mod, X) predict(mod, X)$predictions)
fl_xgb <- flashlight(model = fit_xgb, label = "xgb",
                     predict_function = function(mod, X) predict(mod, prep_xgb(X, x)))
print(fl_xgb)
#> 
#> Flashlight xgb 
#> 
#> Model:            Yes
#> y:            No
#> w:            No
#> by:           No
#> data dim:         No
#> predict_fct default:  FALSE
#> linkinv default:  TRUE
#> metrics:      rmse

What about all the other relevant elements of a flashlight, like the underlying data, the response name, metrics, and retransformation functions? We could pass them to each of our flashlights. Or we can combine the flashlights to a multiflashlight and pass the additional common arguments there.

fls <- multiflashlight(list(fl_mean, fl_lm, fl_rf, fl_xgb), y = "log_price", linkinv = exp, 
                       data = valid, metrics = list(rmse = rmse, `R-squared` = r_squared))

We can even extract the completed flashlights from the multiflashlight as if it were a list (actually, it is a list with additional class multiflashlight).

fl_lm <- fls$lm

Assess performance

Let’s compare the models regarding their validation performance.

perf <- light_performance(fls)
perf
#> 
#> I am an object with class(es) light_performance_multi, light_performance, light, list 
#> 
#> Tibbles:
#> 
#>  data 
#> # A tibble: 8 x 3
#>   metric        value label
#>   <fct>         <dbl> <fct>
#> 1 rmse       0.522    mean 
#> 2 R-squared -0.000116 mean 
#> # ... with 6 more rows
plot(perf)

Surprise, surprise: XGBoost is the winner! Now, black bars look a bit sad, and we would also like to remove the x label.

plot(perf, fill = "darkred") +
  xlab(element_blank())

The plot “philosophy” of flashlight is to provide simple graphics with minimal ggplot tuning, so that you can add your own modifications. If you are completely unhappy with the proposed plot (e.g. you favour a scatterplot over a barplot), extract the data slot of perf and create the figure from scratch:

head(perf$data)
#> # A tibble: 6 x 3
#>   metric        value label
#>   <fct>         <dbl> <fct>
#> 1 rmse       0.522    mean 
#> 2 R-squared -0.000116 mean 
#> 3 rmse       0.290    lm   
#> 4 R-squared  0.691    lm   
#> 5 rmse       0.241    rf   
#> 6 R-squared  0.786    rf

perf$data %>% 
  ggplot(aes(x = label, y = value, group = metric, color = metric)) +
  geom_point() +
  scale_color_viridis_d(begin = 0.2, end = 0.6)

The same logic holds for all other main functions in the flashlight package.

For performance considerations, the minimal required elements in the (multi-)flashlight are “y”, “predict_function”, “model”, “data” and “metrics”. The latter two can also be passed on the fly.
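
For instance, to judge the final models on the untouched test data with a different metric, both can be passed on the fly (a quick sketch reusing the objects from above; mae is taken from MetricsWeighted):

# Test performance with MAE, passing data and metrics on the fly
plot(light_performance(fls, data = test, metrics = list(mae = mae)))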

Variable importance

Now let’s study the variable importance of the explainers. By default, it is shown with respect to the first metric in the explainers; in our case, that is the root-mean-squared error.

(imp <- light_importance(fls, n_max = 1000))
#> 
#> I am an object with class(es) light_importance_multi, light_importance, light, list 
#> 
#> Tibbles:
#> 
#>  data 
#> # A tibble: 92 x 6
#>   variable metric value_shuffled label value_original value
#>   <chr>    <fct>           <dbl> <fct>          <dbl> <dbl>
#> 1 id       rmse            0.530 mean           0.530     0
#> 2 date     rmse            0.530 mean           0.530     0
#> # ... with 90 more rows
plot(imp)

Oops, what happened? Too many variables were tested for the permutation drop in rmse, namely all variables in the data set except the response. While this can be useful in certain situations, here we will just pass the vector x of covariables. Furthermore, we replace the metric by the mean-squared error.

(imp <- light_importance(fls, v = x, metric = list(mse = mse)))
#> 
#> I am an object with class(es) light_importance_multi, light_importance, light, list 
#> 
#> Tibbles:
#> 
#>  data 
#> # A tibble: 32 x 6
#>   variable metric value_shuffled label value_original value
#>   <chr>    <fct>           <dbl> <fct>          <dbl> <dbl>
#> 1 grade    mse             0.272 mean           0.272     0
#> 2 year     mse             0.272 mean           0.272     0
#> # ... with 30 more rows
plot(imp, fill = "darkred")

If we just want to extract the names of the three most relevant variables, we do the following:

most_important(imp, top_m = 3)
#> [1] "grade"       "sqft_living" "zipcode"

What about the drop in R-squared? You don’t have to update the multiflashlight with that new property. Instead, you can pass it to light_importance on the fly. flashlight does not know whether higher or lower values of a scoring function are better, so you need to pass that information manually.

imp_r2 <- light_importance(fls, metric = list(r_squared = r_squared), 
                           v = x, lower_is_better = FALSE)
plot(imp_r2, fill = "darkred") +
  ggtitle("Drop in R-squared")

Minimal required elements in the (multi-)flashlight are the same as in light_performance.

Note: If the calculations take too long (e.g. with large query data), set n_max to some reasonable value. light_importance will then randomly pick n_max rows and use only these to assess importance.

Individual conditional expectation

How do predictions change when sqft_living changes alone? We can investigate this question by looking at “Individual Conditional Expectation” (ICE) profiles of a couple of observations.

cp <- light_ice(fls, v = "sqft_living", n_max = 30, seed = 35)
plot(cp, alpha = 0.2)

The XGBoost profiles look wild; for real applications, setting monotonicity constraints would be worth considering.

Note: Setting seed to a fixed value ensures that all flashlights consider the same rows. Alternatives are to pass a small subset of the data to light_ice (and calculate all of its profiles) or to select fixed rows by passing row indices through the indices argument.
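
For instance, a fixed selection of the first 20 rows of the data in the flashlights could look like this:

# Same rows for every flashlight, no random sampling involved
plot(light_ice(fls, v = "sqft_living", indices = 1:20), alpha = 0.2)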

Centered ICE profiles (“c-ICE”) can help to increase visibility of interactions.

cp <- light_ice(fls, v = "sqft_living", n_max = 30, seed = 35, center = TRUE)
plot(cp, alpha = 0.2)

Partial dependence profiles

If many ICE profiles (in our case 1000) are averaged, we get an impression of the average effect of the considered variable. Such curves are called partial dependence profiles (PD) or partial dependence plots.

pd <- light_profile(fls, v = "sqft_living")
pd
#> 
#> I am an object with class(es) light_profile_multi, light_profile, light, list 
#> 
#> Tibbles:
#> 
#>  data 
#> # A tibble: 36 x 5
#>   sqft_living counts   value label type              
#>         <dbl>  <int>   <dbl> <fct> <fct>             
#> 1         500   1000 464764. mean  partial dependence
#> 2        1500   1000 464764. mean  partial dependence
#> # ... with 34 more rows
plot(pd)

The light_profile function offers different ways to specify the evaluation points of the profiles, e.g. by explicitly passing such points.

pd <- light_profile(fls, v = "sqft_living", pd_evaluate_at = seq(1000, 4000, by = 100))
plot(pd)

For discrete variables:

pd <- light_profile(fls, v = "condition")
plot(pd)

Accumulated local effects (ALE)

An approximation of main effects without the ceteris paribus clause are accumulated local effects (ALE) profiles [5]. They are based on accumulating local partial dependence slopes.

ale <- light_profile(fls, v = "sqft_living", type = "ale")
ale
#> 
#> I am an object with class(es) light_profile_multi, light_profile, light, list 
#> 
#> Tibbles:
#> 
#>  data 
#> # A tibble: 32 x 5
#>   sqft_living counts   value label type 
#>         <dbl>  <int>   <dbl> <fct> <fct>
#> 1        1500   1000 464764. mean  ale  
#> 2        2500   1000 464764. mean  ale  
#> # ... with 30 more rows
plot(ale)

Interestingly, for large houses, the effects of the random forest and XGBoost are now much steeper (and closer to the linear model) than the effects from partial dependence.

Note: While equally sized x-breaks are easy to read, quantile binning usually leads to more stable results.

plot(light_profile(fls, v = "sqft_living", type = "ale", cut_type = "quantile"))

In order to calculate ICEs, PDs and ALEs, the following elements need to be available in the (multi-)flashlight: “predict_function”, “model”, “linkinv” and “data”. “data” can also be passed on the fly.

Profiles of predicted values, residuals, and response

We can use the function light_profile not only to create partial dependence profiles but also to get profiles of predicted values (“M plots”), responses or residuals. Additionally, we can either use averages or quartiles as summary statistics.

Average predicted values versus the living area are as follows:

format_y <- function(x) format(x, big.mark = "'", scientific = FALSE)

pvp <- light_profile(fls, v = "sqft_living", type = "predicted", format = "fg", big.mark = "'")
plot(pvp) +
  scale_y_continuous(labels = format_y)

Note the formatting of the y values as well as the formatC options format = "fg" and big.mark passed to the constructor of the x labels in order to improve the basic appearance. We will recycle some of these settings for the next plots.

Similarly, the average response profile (identical for all flashlights, so we show only one of them):

rvp <- light_profile(fl_lm, v = "sqft_living", type = "response", format = "fg") 
plot(rvp) +
  scale_y_continuous(labels = format_y)

Same, but quartiles:

rvp <- light_profile(fl_lm, v = "sqft_living", type = "response", 
                     stats = "quartiles", format = "fg") 
plot(rvp) +
  scale_y_continuous(labels = format_y)

What about residuals? First, we remove the “mean” flashlight by setting it to NULL.

fls$mean <- NULL
rvp <- light_profile(fls, v = "sqft_living", type = "residual", 
                     stats = "quartiles", format = "fg") 
plot(rvp) +
  scale_y_continuous(labels = format_y)

While the tree-based models have smaller residuals and medians close to 0, the linear model shows residual curvature that could be captured by representing sqft_living by more parameters.

If you are unhappy with the “group by” strategy, set swap_dim to TRUE.

plot(rvp, swap_dim = TRUE) +
  scale_y_continuous(labels = format_y)

For fewer bars, set n_bins in light_profile:

rvp <- light_profile(fls, v = "sqft_living", type = "residual", 
                     stats = "quartiles", format = "fg", n_bins = 5) 
plot(rvp, swap_dim = TRUE) +
  scale_y_continuous(labels = format_y)

In the same way as diverging ICE profiles hint at the presence of interactions, we can use the option stats = "quartiles" (together with pd_center = TRUE) to show the divergence of the centered ICE profiles as boxes (with predictions at log scale, in order to suppress interaction-like effects of the retransformation function):

rvp <- light_profile(fls, v = "sqft_living", use_linkinv = FALSE, 
                     stats = "quartiles", pd_center = TRUE) 
plot(rvp)

For prediction profiles, the same elements as for ICE/PDs are required, while for response profiles we need “y”, “linkinv” and “data”. “data” can also be passed on the fly.

Visualizing different types of profiles as “effects” plot

In assessing the model quality, it is often useful to visualize response, prediction, and partial dependence profiles (and optionally ALE profiles) in the same plot and for each input variable. The flashlight package offers the function light_effects to combine such profile plots:

eff <- light_effects(fl_lm, v = "condition") 
p <- plot(eff) +
  scale_y_continuous(labels = format_y)
p

Let’s add counts to see whether the gaps between the response and prediction profiles are problematic or just due to small samples.

plot_counts(p, eff, alpha = 0.2)

The biggest gaps occur with very rare conditions, so the model looks quite fine.

Note: Due to the retransformation from the log scale, the response profile is slightly higher than the profile of the predicted values. If we evaluated on the modelled log scale, that gap would vanish.

eff <- light_effects(fl_lm, v = "condition", linkinv = I) 
p <- plot(eff, use = "all") +
  scale_y_continuous(labels = format_y) +
  ggtitle("Effects plot on modelled log scale")
p

Besides adding counts to the figure, representing the observed responses as boxplots (without whiskers and outliers, in order to avoid a too large y scale) might help to judge whether there is a problematic misfit.

eff <- light_effects(fl_lm, v = "condition", stats = "quartiles") 
p <- plot(eff, rotate_x = FALSE) +
   scale_y_continuous(labels = format_y)
plot_counts(p, eff, fill = "blue", alpha = 0.2, width = 0.3)

The plot method of light_effects allows hiding certain plot elements if the figure looks too dense.

Variable contribution breakdown for single observation

Besides global effects, we can use light_breakdown to calculate the variable contributions to one single (log) prediction.

bd <- light_breakdown(fl_lm, new_obs = valid[1, ], v = x, n_max = 1000, seed = 74) 
plot(bd, size = 3)

We have set n_max to 1000 in order to save time. The only variable with a positive impact on the prediction is “age”; the other variables have a negative impact compared to the 1000 reference observations. We could limit the figure to the four most relevant variables. In this case, the final bar shows the combined impact of all other variables.

bd <- light_breakdown(fl_lm, new_obs = valid[1, ], v = x, n_max = 1000, seed = 74, top_m = 4) 
plot(bd)

Grouped calculations

A key feature of the flashlight package is its support for grouped results. You can initialize the (multi-)flashlight with the column names of one or more grouping variables, or ask for grouped calculations in all major flashlight functions. Plots are adapted accordingly.

fls <- multiflashlight(fls, by = "year")

# Performance
plot(light_performance(fls)) + 
  scale_fill_viridis_d(begin = 0.1, end = 0.9)


# With swapped dimension
plot(light_performance(fls), swap_dim = TRUE) + 
  scale_fill_viridis_d(begin = 0.1, end = 0.9)

  
# Importance
imp <- light_importance(fls, v = x)
plot(imp, top_m = 4)

plot(imp, swap_dim = TRUE)


# Effects: ICE
plot(light_ice(fls, v = "sqft_living", seed = 4345), 
     alpha = 0.8, facet_scales = "free_y") + 
  scale_color_viridis_d(begin = 0.1, end = 0.9) + 
  scale_y_continuous(labels = format_y)


# c-ICE
plot(light_ice(fls, v = "sqft_living", seed = 4345, center = TRUE), 
     alpha = 0.8, facet_scales = "free_y") + 
  scale_color_viridis_d(begin = 0.1, end = 0.9) + 
  scale_y_continuous(labels = format_y)


# Effects: Partial dependence
plot(light_profile(fls, v = "sqft_living"))

plot(light_profile(fls, v = "sqft_living"), swap_dim = TRUE)

plot(light_profile(fls, v = "sqft_living", stats = "quartiles", pd_center = TRUE))


# Effects: ALE
plot(light_profile(fls, v = "sqft_living", type = "ale"))

plot(light_profile(fls, v = "sqft_living", type = "ale"), swap_dim = TRUE)


# Effects: Combined plot (only one flashlight) 
# -> we need to manually pass "by" or update the single flashlight
z <- light_effects(fls, v = "sqft_living", format = "fg", 
                   stats = "quartiles", n_bins = 5, by = NULL)
p <- plot(z) + 
  scale_y_continuous(labels = format_y) +
  coord_cartesian(ylim = c(0, 3e6))
plot_counts(p, z, alpha = 0.2)


# Variable contribution breakdown for single observation (on log-scale)
# -> "by" selects the relevant rows in data/valid
plot(light_breakdown(fl_lm, new_obs = valid[1, ], v = x, top_m = 3))

Working with case weights

In many applications, case weights are involved. All main functions in flashlight deal with them automatically. The only thing you need to do is pass the name of the column holding the case weights when initializing the (multi-)flashlight.

Let’s go through the initial iris example again with (artificial) case weights:

# Add weight info to the flashlight
fl_weighted <- flashlight(fl, w = "Petal.Length", label = "ols weighted")
fls <- multiflashlight(list(fl, fl_weighted))

# Performance: rmse and R-squared
plot(light_performance(fls))

plot(light_performance(fls, by = "Species"))


# Variable importance by increase in rmse
plot(light_importance(fls, by = "Species"))


# ICE profiles for Petal.Width 
# (not affected by weights because nothing is being aggregated)
indices <- seq(10, 150, by = 10)
plot(light_ice(fls, v = "Petal.Width", indices = indices), alpha = 0.2)

plot(light_ice(fls, v = "Petal.Width", by = "Species", indices = indices))


# c-ICE -> lines overlap, no interactions at all
plot(light_ice(fls, v = "Petal.Width", indices = indices, center = TRUE), alpha = 0.2)

plot(light_ice(fls, v = "Petal.Width", by = "Species", indices = indices, center = TRUE))


# Partial dependence profiles for Petal.Width
plot(light_profile(fls, v = "Petal.Width"))

plot(light_profile(fls, v = "Petal.Width", by = "Species"))


# ALE profiles for Petal.Width
plot(light_profile(fls, v = "Petal.Width", type = "ale"))

plot(light_profile(fls, v = "Petal.Width", by = "Species", type = "ale"))


# Observed, predicted, and partial dependence profiles
plot(light_effects(fls, v = "Petal.Width"))

eff <- light_effects(fls, v = "Petal.Width", stats = "quartiles")
plot(eff) %>% 
  plot_counts(eff, alpha = 0.2, fill = "blue")


# Variable contribution breakdown for single observation
plot(light_breakdown(fls, new_obs = iris[2, ]), size = 2.5)

Binary classification

The flashlight package works for numeric responses including binary targets.

ir <- iris
ir$virginica <- ir$Species == "virginica"

fit <- glm(virginica ~ Sepal.Length + Petal.Width, data = ir, family = binomial)

# Make flashlight
fl <- flashlight(model = fit, data = ir, y = "virginica", label = "lr",
                 metrics = list(logLoss = logLoss, AUC = AUC), 
                 predict_function = function(m, d) predict(m, d, type = "response"))

# Performance: logLoss and AUC
plot(light_performance(fl), fill = "darkred")


# Variable importance by increase in logLoss
plot(light_importance(fl, v = c("Sepal.Length", "Petal.Width")), fill = "darkred")


# ICE profiles for Petal.Width
plot(light_ice(fl, v = "Petal.Width"), alpha = 0.4)


# c-ICE profiles for Petal.Width
plot(light_ice(fl, v = "Petal.Width", center = TRUE), alpha = 0.4)


# Partial dependence profiles for Petal.Width
plot(light_profile(fl, v = "Petal.Width"))


# ALE profiles for Petal.Width
plot(light_profile(fl, v = "Petal.Width", type = "ale", cut_type = "quantile"))


# Observed, predicted, and partial dependence profiles
eff <- light_effects(fl, v = "Petal.Width")
plot_counts(plot(eff, use = "all"), eff, alpha = 0.2)


# Variable contribution breakdown for single observation
plot(light_breakdown(fl, new_obs = ir[2, ], v = c("Sepal.Length", "Petal.Width")))

References

[1] Molnar C. (2019). Interpretable machine learning. A Guide for Making Black Box Models Explainable (https://christophm.github.io/interpretable-ml-book/).

[2] Fisher A., Rudin C., Dominici F. (2018). All Models are Wrong but many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. ArXiv (https://arxiv.org/abs/1801.01489).

[3] Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical learning with plots of individual conditional expectation. Journal of Computational and Graphical Statistics, 24:1 (https://doi.org/10.1080/10618600.2014.907095).

[4] Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232 (https://doi.org/10.1214/aos/1013203451).

[5] Apley D. W. (2016). Visualizing the effects of predictor variables in black box supervised learning models. ArXiv (https://arxiv.org/abs/1612.08468).

[6] Gosiewska A. and Biecek P. (2019). iBreakDown: Uncertainty of model explanations for non-additive predictive models. ArXiv (https://arxiv.org/abs/1903.11420).