Homework Unit 11: Model Comparisons and Other Explanatory Goals

Author

TA Key

Published

April 11, 2024

Introduction

This file serves as the answer key for the Unit_11 homework. Unit 11 (Model Comparisons and Other Explanatory Goals) in the course web book contains all materials required for this assignment.

In this assignment, we demonstrate how to perform a model comparison to determine the importance of a pair of focal variables (mother's and father's education). We use a Bayesian approach to evaluate the effect of those focal variables.


Setup

Handle conflicts

options(conflicts.policy = "depends.ok")
devtools::source_url("https://github.com/jjcurtin/lab_support/blob/main/fun_ml.R?raw=true")
tidymodels_conflictRules()

Load required packages

library(tidyverse) 
library(tidymodels)
library(tidyposterior)
library(DALEX, exclude= "explain")
library(DALEXtra)
library(xfun, include.only = "cache_rds")

Source function scripts

devtools::source_url("https://github.com/jjcurtin/lab_support/blob/main/fun_plots.R?raw=true")
devtools::source_url("https://github.com/jjcurtin/lab_support/blob/main/fun_eda.R?raw=true")

Specify other global settings

theme_set(theme_classic())
options(tibble.width = Inf, dplyr.print_max=Inf)
rerun_setting <- FALSE

Paths

path_data <- "homework/unit_11"

Set up parallel processing

cl <- parallel::makePSOCKcluster(parallel::detectCores(logical = FALSE))
doParallel::registerDoParallel(cl)
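If you want to release these workers when you are done, you can stop the cluster at the end of the script. This is optional housekeeping rather than part of the assignment; both calls below are standard functions from the parallel and foreach packages.

# run after all model fitting is complete to release the parallel workers
parallel::stopCluster(cl)
foreach::registerDoSEQ()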

Read in data

Read in the student_perf_cln.csv data file as data_all. Set variable classes as appropriate.

data_all <- read_csv(here::here(path_data, "student_perf_cln.csv"),
                     col_types = cols()) 

data_all |> 
  skim_some()
Data summary
Name data_all
Number of rows 395
Number of columns 31
_______________________
Column type frequency:
character 17
numeric 14
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
school 0 1 2 2 0 2 0
sex 0 1 1 1 0 2 0
address 0 1 5 5 0 2 0
family_size 0 1 14 18 0 2 0
parent_cohabit 0 1 5 8 0 2 0
mother_job 0 1 5 8 0 5 0
father_job 0 1 5 8 0 5 0
school_reason 0 1 4 10 0 4 0
guardian 0 1 5 6 0 3 0
school_support 0 1 2 3 0 2 0
family_support 0 1 2 3 0 2 0
paid 0 1 2 3 0 2 0
activities 0 1 2 3 0 2 0
nursery 0 1 2 3 0 2 0
higher 0 1 2 3 0 2 0
internet 0 1 2 3 0 2 0
romantic 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate p0 p100
age 0 1 15 22
mother_educ 0 1 0 4
father_educ 0 1 0 4
travel_time 0 1 1 4
study_time 0 1 1 4
failures 0 1 0 3
family_rel_quality 0 1 1 5
free_time 0 1 1 5
go_out 0 1 1 5
weekday_alcohol 0 1 1 5
weekend_alcohol 0 1 1 5
health 0 1 1 5
absences 0 1 0 75
grade 0 1 0 20
focal_levels <- c("none", "primary", "middle", "secondary", "higher")
data_all <- data_all |> 
  mutate(across(ends_with("_educ"), ~ factor(.x, levels = 0:4, labels = focal_levels)),
         across(where(is.character), as_factor)) 

data_all |> 
  skim_some()
Data summary
Name data_all
Number of rows 395
Number of columns 31
_______________________
Column type frequency:
factor 19
numeric 12
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
school 0 1 FALSE 2 gp: 349, ms: 46
sex 0 1 FALSE 2 f: 208, m: 187
address 0 1 FALSE 2 urb: 307, rur: 88
family_size 0 1 FALSE 2 gre: 281, les: 114
parent_cohabit 0 1 FALSE 2 tog: 354, apa: 41
mother_educ 0 1 FALSE 5 hig: 131, mid: 103, sec: 99, pri: 59
father_educ 0 1 FALSE 5 mid: 115, sec: 100, hig: 96, pri: 82
mother_job 0 1 FALSE 5 oth: 141, ser: 103, at_: 59, tea: 58
father_job 0 1 FALSE 5 oth: 217, ser: 111, tea: 29, at_: 20
school_reason 0 1 FALSE 4 cou: 145, hom: 109, rep: 105, oth: 36
guardian 0 1 FALSE 3 mot: 273, fat: 90, oth: 32
school_support 0 1 FALSE 2 no: 344, yes: 51
family_support 0 1 FALSE 2 yes: 242, no: 153
paid 0 1 FALSE 2 no: 214, yes: 181
activities 0 1 FALSE 2 yes: 201, no: 194
nursery 0 1 FALSE 2 yes: 314, no: 81
higher 0 1 FALSE 2 yes: 375, no: 20
internet 0 1 FALSE 2 yes: 329, no: 66
romantic 0 1 FALSE 2 no: 263, yes: 132

Variable type: numeric

skim_variable n_missing complete_rate p0 p100
age 0 1 15 22
travel_time 0 1 1 4
study_time 0 1 1 4
failures 0 1 0 3
family_rel_quality 0 1 1 5
free_time 0 1 1 5
go_out 0 1 1 5
weekday_alcohol 0 1 1 5
weekend_alcohol 0 1 1 5
health 0 1 1 5
absences 0 1 0 75
grade 0 1 0 20

I am going to look at my variables to see if I will need to make any transformations to them. I am only going to use a subset of my data as an eyeball set to reduce the risk of overfitting.

data_eda <- initial_split(data_all, prop = 3/4) |> 
  assessment()
data_eda |> 
  select(where(is.numeric)) |>
  names() |> 
  map(\(name) plot_bar(df = data_eda, x = name)) |> 
  cowplot::plot_grid(plotlist = _, ncol = 4)

data_eda |> 
  select(where(is.factor)) |>
  names() |> 
  map(\(name) plot_bar(df = data_eda, x = name)) |> 
  cowplot::plot_grid(plotlist = _, ncol = 4)

I decided to dichotomize travel_time and failures since there was not much representation among the higher categories. I opted to keep the rest of the ordinal variables as numeric. They were relatively normally distributed, and handling them like unordered categorical variables with dummy features would have created many, many features. Additionally, given that none of these were our focal variables, it seemed simpler to keep these background features less complex. However, it's worth noting that you should check these kinds of relationships (e.g., whether an ordinal score relates roughly linearly to the outcome) before deciding how to handle ordinal scores! A quick sketch of one such check appears after the code below.

data_all <- data_all |> 
  mutate(failures = if_else(failures == 0, "no", "yes"),
         travel_time = if_else(travel_time == 1, "less_than_1_hour", "more_than_1_hour"))
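Here is one quick way you could eyeball such a relationship in the EDA set. This is a minimal sketch (study_time is just an illustrative choice): tabulate the mean grade at each level of an ordinal score and check whether it changes roughly monotonically.

data_eda |> 
  group_by(study_time) |> 
  summarize(mean_grade = mean(grade), n = n())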

Set up splits

Split data_all into repeated k-fold cross-validation splits using 3 repeats and 10 folds. You do not need to set a new seed. Save your splits as an object named splits.

set.seed(123456)
splits <- vfold_cv(data_all, v = 10, repeats = 3)

Build recipes

Recipe 1: Full model

rec_full <- recipe(grade ~ ., data = data_all) |>
  step_scale(all_numeric_predictors()) |> 
  step_dummy(all_nominal_predictors()) |> 
  step_nzv(all_predictors())

Let's make a feature matrix now on data_eda to check our features.

rec_full |> 
  prep(data_eda) |>  
  bake(NULL) |> 
  skim_all()
Data summary
Name bake(prep(rec_full, data_…
Number of rows 99
Number of columns 43
_______________________
Column type frequency:
numeric 43
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 skew kurtosis
age 0 1 13.59 1.00 12.22 13.03 13.85 14.66 17.11 0.53 0.14
travel_time 0 1 1.62 0.79 1.00 1.00 1.00 2.00 4.00 1.15 0.65
study_time 0 1 2.44 1.00 1.19 1.78 2.37 2.37 4.75 0.60 -0.12
failures 0 1 0.39 0.81 0.00 0.00 0.00 0.00 3.00 2.08 3.41
family_rel_quality 0 1 5.14 1.00 1.28 5.11 5.11 5.75 6.39 -0.92 1.62
free_time 0 1 2.98 1.00 0.90 2.69 2.69 3.59 4.49 -0.21 -0.66
go_out 0 1 2.60 1.00 0.87 1.74 2.60 3.47 4.34 0.28 -0.81
weekday_alcohol 0 1 1.57 1.00 1.03 1.03 1.03 2.06 5.14 1.91 2.85
weekend_alcohol 0 1 1.71 1.00 0.77 0.77 1.54 2.31 3.85 0.67 -0.76
health 0 1 2.63 1.00 0.73 2.20 2.94 3.67 3.67 -0.50 -1.02
absences 0 1 0.76 1.00 0.00 0.00 0.29 1.08 4.34 1.71 2.44
grade 0 1 10.00 4.92 0.00 8.00 10.00 13.00 19.00 -0.64 -0.04
school_ms 0 1 0.12 0.33 0.00 0.00 0.00 0.00 1.00 2.29 3.26
sex_m 0 1 0.45 0.50 0.00 0.00 0.00 1.00 1.00 0.18 -1.99
address_rural 0 1 0.19 0.40 0.00 0.00 0.00 0.00 1.00 1.54 0.38
family_size_less_or_equal_to_3 0 1 0.31 0.47 0.00 0.00 0.00 1.00 1.00 0.79 -1.38
parent_cohabit_together 0 1 0.89 0.32 0.00 1.00 1.00 1.00 1.00 -2.44 3.98
mother_educ_primary 0 1 0.13 0.34 0.00 0.00 0.00 0.00 1.00 2.15 2.65
mother_educ_middle 0 1 0.31 0.47 0.00 0.00 0.00 1.00 1.00 0.79 -1.38
mother_educ_secondary 0 1 0.30 0.46 0.00 0.00 0.00 1.00 1.00 0.84 -1.30
mother_educ_higher 0 1 0.24 0.43 0.00 0.00 0.00 0.00 1.00 1.18 -0.60
father_educ_primary 0 1 0.21 0.41 0.00 0.00 0.00 0.00 1.00 1.39 -0.08
father_educ_middle 0 1 0.33 0.47 0.00 0.00 0.00 1.00 1.00 0.70 -1.53
father_educ_secondary 0 1 0.26 0.44 0.00 0.00 0.00 1.00 1.00 1.06 -0.88
father_educ_higher 0 1 0.19 0.40 0.00 0.00 0.00 0.00 1.00 1.54 0.38
mother_job_other 0 1 0.36 0.48 0.00 0.00 0.00 1.00 1.00 0.56 -1.71
mother_job_services 0 1 0.30 0.46 0.00 0.00 0.00 1.00 1.00 0.84 -1.30
mother_job_teacher 0 1 0.13 0.34 0.00 0.00 0.00 0.00 1.00 2.15 2.65
father_job_other 0 1 0.57 0.50 0.00 0.00 1.00 1.00 1.00 -0.26 -1.95
father_job_services 0 1 0.26 0.44 0.00 0.00 0.00 1.00 1.00 1.06 -0.88
father_job_at_home 0 1 0.05 0.22 0.00 0.00 0.00 0.00 1.00 4.04 14.49
school_reason_other 0 1 0.09 0.29 0.00 0.00 0.00 0.00 1.00 2.80 5.92
school_reason_home 0 1 0.26 0.44 0.00 0.00 0.00 1.00 1.00 1.06 -0.88
school_reason_reputation 0 1 0.23 0.42 0.00 0.00 0.00 0.00 1.00 1.25 -0.45
guardian_father 0 1 0.24 0.43 0.00 0.00 0.00 0.00 1.00 1.18 -0.60
guardian_other 0 1 0.06 0.24 0.00 0.00 0.00 0.00 1.00 3.63 11.27
school_support_no 0 1 0.86 0.35 0.00 1.00 1.00 1.00 1.00 -2.03 2.13
family_support_yes 0 1 0.60 0.49 0.00 0.00 1.00 1.00 1.00 -0.39 -1.87
paid_yes 0 1 0.45 0.50 0.00 0.00 0.00 1.00 1.00 0.18 -1.99
activities_yes 0 1 0.44 0.50 0.00 0.00 0.00 1.00 1.00 0.22 -1.97
nursery_no 0 1 0.21 0.41 0.00 0.00 0.00 0.00 1.00 1.39 -0.08
internet_yes 0 1 0.83 0.38 0.00 1.00 1.00 1.00 1.00 -1.71 0.95
romantic_yes 0 1 0.33 0.47 0.00 0.00 0.00 1.00 1.00 0.70 -1.53

Things I pay attention to when looking at my feature matrix and skim_all() output:

  • Are all my predictors numeric? Even categorical variables should be numeric at this point because they should have been dummy-coded.

  • Are there any missing values?

  • Is the sample size what I would expect it to be for the data set? Remember data_eda is roughly 25% of data_all (the assessment portion of a 3/4 split). A sketch for scripting these checks follows this list.
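If you prefer to script these checks rather than scan the skim output, here is a minimal sketch (feat_eda is just an illustrative name for the baked features):

feat_eda <- rec_full |> 
  prep(data_eda) |> 
  bake(NULL)

# all features numeric (i.e., categorical variables dummy-coded)?
feat_eda |> map_lgl(is.numeric) |> all()

# any missing values?
sum(is.na(feat_eda))

# expected sample size (roughly 25% of data_all)?
nrow(feat_eda)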

Recipe 2: Compact model

Hopefully there wouldn’t be any issues with our compact recipe given that it builds directly from rec_full, but it’s worth skimming just in case.

rec_compact <- rec_full |>
  step_rm(starts_with("mother_educ"), starts_with("father_educ"))

rec_compact |> 
  prep(data_eda) |>  
  bake(NULL) |> 
  skim_all()
Data summary
Name bake(prep(rec_compact, da…
Number of rows 99
Number of columns 35
_______________________
Column type frequency:
numeric 35
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 skew kurtosis
age 0 1 13.59 1.00 12.22 13.03 13.85 14.66 17.11 0.53 0.14
travel_time 0 1 1.62 0.79 1.00 1.00 1.00 2.00 4.00 1.15 0.65
study_time 0 1 2.44 1.00 1.19 1.78 2.37 2.37 4.75 0.60 -0.12
failures 0 1 0.39 0.81 0.00 0.00 0.00 0.00 3.00 2.08 3.41
family_rel_quality 0 1 5.14 1.00 1.28 5.11 5.11 5.75 6.39 -0.92 1.62
free_time 0 1 2.98 1.00 0.90 2.69 2.69 3.59 4.49 -0.21 -0.66
go_out 0 1 2.60 1.00 0.87 1.74 2.60 3.47 4.34 0.28 -0.81
weekday_alcohol 0 1 1.57 1.00 1.03 1.03 1.03 2.06 5.14 1.91 2.85
weekend_alcohol 0 1 1.71 1.00 0.77 0.77 1.54 2.31 3.85 0.67 -0.76
health 0 1 2.63 1.00 0.73 2.20 2.94 3.67 3.67 -0.50 -1.02
absences 0 1 0.76 1.00 0.00 0.00 0.29 1.08 4.34 1.71 2.44
grade 0 1 10.00 4.92 0.00 8.00 10.00 13.00 19.00 -0.64 -0.04
school_ms 0 1 0.12 0.33 0.00 0.00 0.00 0.00 1.00 2.29 3.26
sex_m 0 1 0.45 0.50 0.00 0.00 0.00 1.00 1.00 0.18 -1.99
address_rural 0 1 0.19 0.40 0.00 0.00 0.00 0.00 1.00 1.54 0.38
family_size_less_or_equal_to_3 0 1 0.31 0.47 0.00 0.00 0.00 1.00 1.00 0.79 -1.38
parent_cohabit_together 0 1 0.89 0.32 0.00 1.00 1.00 1.00 1.00 -2.44 3.98
mother_job_other 0 1 0.36 0.48 0.00 0.00 0.00 1.00 1.00 0.56 -1.71
mother_job_services 0 1 0.30 0.46 0.00 0.00 0.00 1.00 1.00 0.84 -1.30
mother_job_teacher 0 1 0.13 0.34 0.00 0.00 0.00 0.00 1.00 2.15 2.65
father_job_other 0 1 0.57 0.50 0.00 0.00 1.00 1.00 1.00 -0.26 -1.95
father_job_services 0 1 0.26 0.44 0.00 0.00 0.00 1.00 1.00 1.06 -0.88
father_job_at_home 0 1 0.05 0.22 0.00 0.00 0.00 0.00 1.00 4.04 14.49
school_reason_other 0 1 0.09 0.29 0.00 0.00 0.00 0.00 1.00 2.80 5.92
school_reason_home 0 1 0.26 0.44 0.00 0.00 0.00 1.00 1.00 1.06 -0.88
school_reason_reputation 0 1 0.23 0.42 0.00 0.00 0.00 0.00 1.00 1.25 -0.45
guardian_father 0 1 0.24 0.43 0.00 0.00 0.00 0.00 1.00 1.18 -0.60
guardian_other 0 1 0.06 0.24 0.00 0.00 0.00 0.00 1.00 3.63 11.27
school_support_no 0 1 0.86 0.35 0.00 1.00 1.00 1.00 1.00 -2.03 2.13
family_support_yes 0 1 0.60 0.49 0.00 0.00 1.00 1.00 1.00 -0.39 -1.87
paid_yes 0 1 0.45 0.50 0.00 0.00 0.00 1.00 1.00 0.18 -1.99
activities_yes 0 1 0.44 0.50 0.00 0.00 0.00 1.00 1.00 0.22 -1.97
nursery_no 0 1 0.21 0.41 0.00 0.00 0.00 0.00 1.00 1.39 -0.08
internet_yes 0 1 0.83 0.38 0.00 1.00 1.00 1.00 1.00 -1.71 0.95
romantic_yes 0 1 0.33 0.47 0.00 0.00 0.00 1.00 1.00 0.70 -1.53

Fit models

Set up a hyperparameter tuning grid

tune_grid <- expand_grid(penalty = exp(seq(-8, .4, length.out = 500)),
                         mixture = seq(0, 1, length.out = 6))

Fit the full model

Use rec_full, splits, and tune_grid to fit GLMs across folds. Use R squared as your metric. Save your model fits as fits_full.

fits_full <- linear_reg(penalty = tune(),
                        mixture = tune()) |>
  set_engine("glmnet") |>
  set_mode("regression") |> 
  tune_grid(preprocessor = rec_full,
                resamples = splits,
                grid = tune_grid,
                metrics = metric_set(rsq))

Make sure you considered a good range of hyperparameter values

plot_hyperparameters(fits_full, hp1 = "penalty", hp2 = "mixture", metric = "rsq")

Select best model configuration

select_best(fits_full)
# A tibble: 1 × 3
  penalty mixture .config                
    <dbl>   <dbl> <chr>                  
1    1.49       0 Preprocessor1_Model0500

Print the mean R squared of the best model configuration across the 30 held-out folds.

collect_metrics(fits_full, summarize = TRUE) |> 
  filter(.config == select_best(fits_full)$.config)
# A tibble: 1 × 8
  penalty mixture .metric .estimator  mean     n std_err .config                
    <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  
1    1.49       0 rsq     standard   0.159    30  0.0180 Preprocessor1_Model0500

Fit the compact model

Use rec_compact, splits, and tune_grid to fit GLMs across folds. Use R squared as your metric. Save your model fits as fits_compact.

fits_compact <- linear_reg(penalty = tune(),
                        mixture = tune()) |>
  set_engine("glmnet") |>
  set_mode("regression") |> 
  tune_grid(preprocessor = rec_compact,
                resamples = splits,
                grid = tune_grid,
                metrics = metric_set(rsq))

Select best model configuration

select_best(fits_compact)
# A tibble: 1 × 3
  penalty mixture .config                
    <dbl>   <dbl> <chr>                  
1    1.49       0 Preprocessor1_Model0500

Print the mean R squared of the best model configuration across the 30 held-out folds.

collect_metrics(fits_compact, summarize = TRUE) |> 
  filter(.config == select_best(fits_compact)$.config)
# A tibble: 1 × 8
  penalty mixture .metric .estimator  mean     n std_err .config                
    <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  
1    1.49       0 rsq     standard   0.159    30  0.0162 Preprocessor1_Model0500

Model comparison with the Bayesian approach

You will now compare your models using Bayesian parameter estimation.

Gather performance estimates

First, make a data frame containing the 30 performance estimates from the held-out folds for your full and compact models. Hint: you can use the code above but change summarize = TRUE to summarize = FALSE. Select down so you only have the following variables: id, id2, and .estimate. Rename .estimate to full or compact before joining your estimates.

rsq_full <-
  collect_metrics(fits_full, summarize = FALSE) |>
  filter(.config == select_best(fits_full)$.config) |> 
  select(id, id2, full = .estimate)

rsq_compact <-
     collect_metrics(fits_compact, summarize = FALSE) |>
     filter(.config == select_best(fits_compact)$.config) |> 
     select(id, id2, compact = .estimate)

rsq_all <- rsq_full |> 
  full_join(rsq_compact, by = c("id", "id2")) |> 
  print()
# A tibble: 30 × 4
   id      id2            full  compact
   <chr>   <chr>         <dbl>    <dbl>
 1 Repeat1 Fold01 0.0921       0.0604  
 2 Repeat1 Fold02 0.0654       0.0905  
 3 Repeat1 Fold03 0.0531       0.0432  
 4 Repeat1 Fold04 0.233        0.256   
 5 Repeat1 Fold05 0.00167      0.0159  
 6 Repeat1 Fold06 0.283        0.254   
 7 Repeat1 Fold07 0.261        0.236   
 8 Repeat1 Fold08 0.374        0.354   
 9 Repeat1 Fold09 0.188        0.216   
10 Repeat1 Fold10 0.144        0.201   
11 Repeat2 Fold01 0.235        0.212   
12 Repeat2 Fold02 0.220        0.184   
13 Repeat2 Fold03 0.0386       0.0335  
14 Repeat2 Fold04 0.134        0.167   
15 Repeat2 Fold05 0.229        0.246   
16 Repeat2 Fold06 0.202        0.172   
17 Repeat2 Fold07 0.182        0.178   
18 Repeat2 Fold08 0.0969       0.144   
19 Repeat2 Fold09 0.0000000330 0.000586
20 Repeat2 Fold10 0.380        0.320   
21 Repeat3 Fold01 0.162        0.190   
22 Repeat3 Fold02 0.0957       0.111   
23 Repeat3 Fold03 0.126        0.154   
24 Repeat3 Fold04 0.0672       0.0776  
25 Repeat3 Fold05 0.0548       0.0705  
26 Repeat3 Fold06 0.144        0.151   
27 Repeat3 Fold07 0.0778       0.0651  
28 Repeat3 Fold08 0.171        0.147   
29 Repeat3 Fold09 0.274        0.252   
30 Repeat3 Fold10 0.185        0.163   

Posterior probabilities

Next, derive the posterior probabilities for the R squared of each of these two models. Because we used repeated k-fold cross-validation, we need to add two random intercepts to our model: one for repeats and one for folds.

set.seed(102030)

pp <- tidyposterior::perf_mod(rsq_all, 
                    formula = statistic ~ model + (1 | id2/id),
                    iter = 3000, chains = 4,  
                    hetero_var = TRUE,
                    adapt_delta = 0.999)  

SAMPLING FOR MODEL 'continuous' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.001322 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 13.22 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:    1 / 3000 [  0%]  (Warmup)
Chain 1: Iteration:  300 / 3000 [ 10%]  (Warmup)
Chain 1: Iteration:  600 / 3000 [ 20%]  (Warmup)
Chain 1: Iteration:  900 / 3000 [ 30%]  (Warmup)
Chain 1: Iteration: 1200 / 3000 [ 40%]  (Warmup)
Chain 1: Iteration: 1500 / 3000 [ 50%]  (Warmup)
Chain 1: Iteration: 1501 / 3000 [ 50%]  (Sampling)
Chain 1: Iteration: 1800 / 3000 [ 60%]  (Sampling)
Chain 1: Iteration: 2100 / 3000 [ 70%]  (Sampling)
Chain 1: Iteration: 2400 / 3000 [ 80%]  (Sampling)
Chain 1: Iteration: 2700 / 3000 [ 90%]  (Sampling)
Chain 1: Iteration: 3000 / 3000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 5.756 seconds (Warm-up)
Chain 1:                4.903 seconds (Sampling)
Chain 1:                10.659 seconds (Total)
Chain 1: 

SAMPLING FOR MODEL 'continuous' NOW (CHAIN 2).
Chain 2: 
Chain 2: Gradient evaluation took 2.1e-05 seconds
Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.21 seconds.
Chain 2: Adjust your expectations accordingly!
Chain 2: 
Chain 2: 
Chain 2: Iteration:    1 / 3000 [  0%]  (Warmup)
Chain 2: Iteration:  300 / 3000 [ 10%]  (Warmup)
Chain 2: Iteration:  600 / 3000 [ 20%]  (Warmup)
Chain 2: Iteration:  900 / 3000 [ 30%]  (Warmup)
Chain 2: Iteration: 1200 / 3000 [ 40%]  (Warmup)
Chain 2: Iteration: 1500 / 3000 [ 50%]  (Warmup)
Chain 2: Iteration: 1501 / 3000 [ 50%]  (Sampling)
Chain 2: Iteration: 1800 / 3000 [ 60%]  (Sampling)
Chain 2: Iteration: 2100 / 3000 [ 70%]  (Sampling)
Chain 2: Iteration: 2400 / 3000 [ 80%]  (Sampling)
Chain 2: Iteration: 2700 / 3000 [ 90%]  (Sampling)
Chain 2: Iteration: 3000 / 3000 [100%]  (Sampling)
Chain 2: 
Chain 2:  Elapsed Time: 6.126 seconds (Warm-up)
Chain 2:                3.153 seconds (Sampling)
Chain 2:                9.279 seconds (Total)
Chain 2: 

SAMPLING FOR MODEL 'continuous' NOW (CHAIN 3).
Chain 3: 
Chain 3: Gradient evaluation took 2.1e-05 seconds
Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.21 seconds.
Chain 3: Adjust your expectations accordingly!
Chain 3: 
Chain 3: 
Chain 3: Iteration:    1 / 3000 [  0%]  (Warmup)
Chain 3: Iteration:  300 / 3000 [ 10%]  (Warmup)
Chain 3: Iteration:  600 / 3000 [ 20%]  (Warmup)
Chain 3: Iteration:  900 / 3000 [ 30%]  (Warmup)
Chain 3: Iteration: 1200 / 3000 [ 40%]  (Warmup)
Chain 3: Iteration: 1500 / 3000 [ 50%]  (Warmup)
Chain 3: Iteration: 1501 / 3000 [ 50%]  (Sampling)
Chain 3: Iteration: 1800 / 3000 [ 60%]  (Sampling)
Chain 3: Iteration: 2100 / 3000 [ 70%]  (Sampling)
Chain 3: Iteration: 2400 / 3000 [ 80%]  (Sampling)
Chain 3: Iteration: 2700 / 3000 [ 90%]  (Sampling)
Chain 3: Iteration: 3000 / 3000 [100%]  (Sampling)
Chain 3: 
Chain 3:  Elapsed Time: 5.662 seconds (Warm-up)
Chain 3:                2.747 seconds (Sampling)
Chain 3:                8.409 seconds (Total)
Chain 3: 

SAMPLING FOR MODEL 'continuous' NOW (CHAIN 4).
Chain 4: 
Chain 4: Gradient evaluation took 2.1e-05 seconds
Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.21 seconds.
Chain 4: Adjust your expectations accordingly!
Chain 4: 
Chain 4: 
Chain 4: Iteration:    1 / 3000 [  0%]  (Warmup)
Chain 4: Iteration:  300 / 3000 [ 10%]  (Warmup)
Chain 4: Iteration:  600 / 3000 [ 20%]  (Warmup)
Chain 4: Iteration:  900 / 3000 [ 30%]  (Warmup)
Chain 4: Iteration: 1200 / 3000 [ 40%]  (Warmup)
Chain 4: Iteration: 1500 / 3000 [ 50%]  (Warmup)
Chain 4: Iteration: 1501 / 3000 [ 50%]  (Sampling)
Chain 4: Iteration: 1800 / 3000 [ 60%]  (Sampling)
Chain 4: Iteration: 2100 / 3000 [ 70%]  (Sampling)
Chain 4: Iteration: 2400 / 3000 [ 80%]  (Sampling)
Chain 4: Iteration: 2700 / 3000 [ 90%]  (Sampling)
Chain 4: Iteration: 3000 / 3000 [100%]  (Sampling)
Chain 4: 
Chain 4:  Elapsed Time: 5.182 seconds (Warm-up)
Chain 4:                3.208 seconds (Sampling)
Chain 4:                8.39 seconds (Total)
Chain 4: 

Remember to always check your warnings in case they don't print! You can address divergence warnings by adjusting the adapt_delta, iter, and chains values. If you do get a warning, be sure to clear it before re-running the code. You can use assign("last.warning", NULL, envir = baseenv())

warnings()

Graph posterior probabilities

We can look at our posterior probability ranges using the following code.

pp_tidy <- pp |> 
  tidy(seed = 123) 

pp_tidy |> 
  group_by(model) |> 
  summarize(mean = mean(posterior),
            lower = quantile(posterior, probs = .025), 
            upper = quantile(posterior, probs = .975)) |> 
  mutate(model = factor(model, levels = c("full", "compact"))) |> 
  arrange(model)
# A tibble: 2 × 4
  model    mean lower upper
  <fct>   <dbl> <dbl> <dbl>
1 full    0.159 0.124 0.194
2 compact 0.159 0.124 0.194

Display your posterior probabilities using both a density plot and a histogram. Choose plots that would be the most useful to display your results to a collaborator.

pp_tidy |> 
  ggplot() + 
  geom_density(aes(x = posterior, color = model)) +
  xlab("R Squared")

pp_tidy |> 
  ggplot() + 
  geom_histogram(aes(x = posterior, fill = model), color = "white", alpha = 0.4,
                 bins = 50, position = "identity") +
  xlab("R Squared")

Since we get a lot of overlap here, it's probably more useful to use the faceted plots. We can also calculate 95% credible intervals (HDI) and add those to our plots!

ci <- pp_tidy |> 
  summary() |> 
  mutate(y = 450)

pp_tidy |> 
  ggplot(aes(x = posterior)) + 
  geom_histogram(aes(x = posterior, fill = model), color = "white", bins = 50) +  
  geom_segment(mapping = aes(y = y+50, yend = y-50, x = mean, xend = mean,
                           color = model),
               data = ci) +
  geom_segment(mapping = aes(y = y, yend = y, x = lower, xend = upper, color = model),
                data = ci) +
  facet_wrap(~ model, ncol = 1) +
  theme(legend.position = "none") +
  ylab("Count") +
  xlab("R Squared")

Determine if the full model has better performance

Calculate the probability that the full model is performing better (i.e., with higher R squared) than the compact model.

pp_contrast <- pp |> 
  contrast_models(seed = 12) |>
  summary(size = .01) |> 
  glimpse()
Rows: 1
Columns: 9
$ contrast    <chr> "full vs compact"
$ probability <dbl> 0.5098333
$ mean        <dbl> 0.0002022965
$ lower       <dbl> -0.009144342
$ upper       <dbl> 0.009586029
$ size        <dbl> 0.01
$ pract_neg   <dbl> 0.037
$ pract_equiv <dbl> 0.92
$ pract_pos   <dbl> 0.043
pp |> 
  contrast_models(seed = 12) |> 
  ggplot(aes(x = difference)) + 
  geom_histogram(bins = 50, color = "white", fill = "light grey")+
  geom_vline(aes(xintercept = -.01), linetype = "dashed") + 
  geom_vline(aes(xintercept = .01), linetype = "dashed")

Here is what we find:

  • The mean increase in R squared is essentially 0 (about 0.0002). This means that on average, the full model is yielding R squared values that are only about 0.0002 higher than the compact model.

  • Our 95% HDI is -0.01 to 0.01. This means there is a 95% probability that the true difference in R squared falls between -0.01 and 0.01.

  • The full model is not meaningfully superior to the compact model. The probability that the full model performs better than the compact model at all is only 51%, and the probability that it is better by at least .01 R squared (pract_pos) is about 4%. Our practical equivalence score is 0.92, indicating that 92% of the credible difference values fall within our ROPE (i.e., the probability that these models do not differ meaningfully is 92%). Therefore the two models are practically equivalent; dropping the focal education features does not meaningfully change performance. A sketch for recomputing these probabilities by hand follows.
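To make these probabilities concrete, here is a minimal sketch that recomputes them directly from the posterior differences; it should closely match the summary() output above.

pp |> 
  contrast_models(seed = 12) |> 
  summarize(prob_full_better = mean(difference > 0),
            pract_neg = mean(difference < -.01),
            pract_equiv = mean(abs(difference) <= .01),
            pract_pos = mean(difference > .01))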

Feature importance

You will now look at feature importance in your full model using Shapley values.

Prep data

Create a feature matrix from your full data set and fit a model on all of the data

feat_all <- rec_full |> 
  prep(data_all) |> 
  bake(data_all) 

fit_all_data <- linear_reg(penalty = select_best(fits_full)$penalty,
                           mixture = select_best(fits_full)$mixture) |>
  set_engine("glmnet") |>
  set_mode("regression") |> 
  fit(grade ~ ., data = feat_all)

Pull out your features and outcome to be used in calculating feature importance

x <- feat_all |> select(-grade) |>  glimpse()
Rows: 395
Columns: 44
$ age                            <dbl> 14.10611, 13.32244, 11.75509, 11.75509,…
$ study_time                     <dbl> 2.383108, 2.383108, 2.383108, 3.574661,…
$ family_rel_quality             <dbl> 4.461007, 5.576258, 4.461007, 3.345755,…
$ free_time                      <dbl> 3.003418, 3.003418, 3.003418, 2.002279,…
$ go_out                         <dbl> 3.5929924, 2.6947443, 1.7964962, 1.7964…
$ weekday_alcohol                <dbl> 1.122660, 1.122660, 2.245321, 1.122660,…
$ weekend_alcohol                <dbl> 0.7764599, 0.7764599, 2.3293796, 0.7764…
$ health                         <dbl> 2.1578024, 2.1578024, 2.1578024, 3.5963…
$ absences                       <dbl> 0.7497099, 0.4998066, 1.2495165, 0.2499…
$ school_ms                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ sex_m                          <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, …
$ address_rural                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ family_size_less_or_equal_to_3 <dbl> 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, …
$ parent_cohabit_together        <dbl> 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, …
$ mother_educ_primary            <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ mother_educ_middle             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, …
$ mother_educ_secondary          <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, …
$ mother_educ_higher             <dbl> 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, …
$ father_educ_primary            <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
$ father_educ_middle             <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, …
$ father_educ_secondary          <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, …
$ father_educ_higher             <dbl> 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, …
$ mother_job_health              <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
$ mother_job_other               <dbl> 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, …
$ mother_job_services            <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, …
$ mother_job_teacher             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
$ father_job_other               <dbl> 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, …
$ father_job_services            <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
$ father_job_at_home             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ school_reason_other            <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ school_reason_home             <dbl> 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, …
$ school_reason_reputation       <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, …
$ guardian_father                <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
$ guardian_other                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ travel_time_more_than_1_hour   <dbl> 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
$ failures_yes                   <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ school_support_no              <dbl> 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, …
$ family_support_yes             <dbl> 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, …
$ paid_yes                       <dbl> 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, …
$ activities_yes                 <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, …
$ nursery_no                     <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ higher_no                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ internet_yes                   <dbl> 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, …
$ romantic_yes                   <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
y <- feat_all |> pull(grade)

Use the following code to define your predictor function (predict_wrapper) and explainer object (explain_full)

predict_wrapper <- function(model, newdata) {
  predict(model, newdata) |>  
    pull(.pred)
}

explain_full <- explain_tidymodels(fit_all_data, 
                                   data = x, 
                                   y = y, 
                                   predict_function = predict_wrapper)
Preparation of a new explainer is initiated
  -> model label       :  model_fit  (  default  )
  -> data              :  395  rows  44  cols 
  -> data              :  tibble converted into a data.frame 
  -> target variable   :  395  values 
  -> predict function  :  predict_function 
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package parsnip , ver. 1.1.1 , task regression (  default  ) 
  -> predicted values  :  numerical, min =  4.661688 , mean =  10.41519 , max =  15.13141  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -12.60379 , mean =  -8.085694e-15 , max =  8.618674  
  A new explainer has been created!  

Calculate Shapley values for a single participant

First, we will look at Shapley values for a single participant. For this example, let's look at the last participant in our data set.

We print the raw feature values for the last participant.

obs_num <- nrow(feat_all) 

x1 <- x |> 
  slice(obs_num) |> 
  glimpse() 
Rows: 1
Columns: 44
$ age                            <dbl> 14.88978
$ study_time                     <dbl> 1.191554
$ family_rel_quality             <dbl> 3.345755
$ free_time                      <dbl> 2.002279
$ go_out                         <dbl> 2.694744
$ weekday_alcohol                <dbl> 3.367981
$ weekend_alcohol                <dbl> 2.32938
$ health                         <dbl> 3.596337
$ absences                       <dbl> 0.6247582
$ school_ms                      <dbl> 1
$ sex_m                          <dbl> 1
$ address_rural                  <dbl> 0
$ family_size_less_or_equal_to_3 <dbl> 1
$ parent_cohabit_together        <dbl> 1
$ mother_educ_primary            <dbl> 1
$ mother_educ_middle             <dbl> 0
$ mother_educ_secondary          <dbl> 0
$ mother_educ_higher             <dbl> 0
$ father_educ_primary            <dbl> 1
$ father_educ_middle             <dbl> 0
$ father_educ_secondary          <dbl> 0
$ father_educ_higher             <dbl> 0
$ mother_job_health              <dbl> 0
$ mother_job_other               <dbl> 1
$ mother_job_services            <dbl> 0
$ mother_job_teacher             <dbl> 0
$ father_job_other               <dbl> 0
$ father_job_services            <dbl> 0
$ father_job_at_home             <dbl> 1
$ school_reason_other            <dbl> 0
$ school_reason_home             <dbl> 0
$ school_reason_reputation       <dbl> 0
$ guardian_father                <dbl> 1
$ guardian_other                 <dbl> 0
$ travel_time_more_than_1_hour   <dbl> 0
$ failures_yes                   <dbl> 0
$ school_support_no              <dbl> 1
$ family_support_yes             <dbl> 0
$ paid_yes                       <dbl> 0
$ activities_yes                 <dbl> 0
$ nursery_no                     <dbl> 0
$ higher_no                      <dbl> 0
$ internet_yes                   <dbl> 1
$ romantic_yes                   <dbl> 0

Generate Shapley values for this participant and output the raw results. This calculation takes a little while to run, so it is worth caching with cache_rds; a sketch of that appears after the code below.

sv <- predict_parts(explain_full, 
                  new_observation = x1,
                  type = "shap",
                  B = 25)
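If you want to cache this single-participant calculation, here is a minimal sketch wrapping the same call in cache_rds (the file name is just illustrative):

sv <- cache_rds(
  expr = {
    predict_parts(explain_full, 
                  new_observation = x1,
                  type = "shap",
                  B = 25)
  },
  rerun = rerun_setting,
  dir = "cache/",
  file = "shaps_single")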

Plot the Shapley values for this participant

plot(sv)

Calculate mean absolute Shapley values across all participants

Now we will look at feature importance across all participants. Since this process is extremely computationally intensive, we are going to select a random set of participants to demonstrate this on.

Use the function slice_sample() with n set to 20 to get a random subset of observations from x.

set.seed(101)
x_sample <- x |> 
  slice_sample(n = 20) |> 
  glimpse()
Rows: 20
Columns: 44
$ age                            <dbl> 13.32244, 14.88978, 11.75509, 12.53877,…
$ study_time                     <dbl> 3.574661, 2.383108, 4.766215, 1.191554,…
$ family_rel_quality             <dbl> 5.576258, 4.461007, 4.461007, 4.461007,…
$ free_time                      <dbl> 4.004557, 5.005696, 3.003418, 3.003418,…
$ go_out                         <dbl> 3.592992, 1.796496, 3.592992, 1.796496,…
$ weekday_alcohol                <dbl> 1.122660, 2.245321, 1.122660, 1.122660,…
$ weekend_alcohol                <dbl> 2.3293796, 1.5529197, 0.7764599, 3.1058…
$ health                         <dbl> 2.8770699, 2.8770699, 2.8770699, 3.5963…
$ absences                       <dbl> 0.8746615, 0.3748549, 0.7497099, 0.7497…
$ school_ms                      <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, …
$ sex_m                          <dbl> 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, …
$ address_rural                  <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, …
$ family_size_less_or_equal_to_3 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
$ parent_cohabit_together        <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, …
$ mother_educ_primary            <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
$ mother_educ_middle             <dbl> 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, …
$ mother_educ_secondary          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, …
$ mother_educ_higher             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
$ father_educ_primary            <dbl> 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, …
$ father_educ_middle             <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
$ father_educ_secondary          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
$ father_educ_higher             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
$ mother_job_health              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ mother_job_other               <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
$ mother_job_services            <dbl> 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, …
$ mother_job_teacher             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ father_job_other               <dbl> 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, …
$ father_job_services            <dbl> 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
$ father_job_at_home             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ school_reason_other            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ school_reason_home             <dbl> 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, …
$ school_reason_reputation       <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
$ guardian_father                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
$ guardian_other                 <dbl> 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, …
$ travel_time_more_than_1_hour   <dbl> 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, …
$ failures_yes                   <dbl> 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, …
$ school_support_no              <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
$ family_support_yes             <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, …
$ paid_yes                       <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, …
$ activities_yes                 <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
$ nursery_no                     <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, …
$ higher_no                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
$ internet_yes                   <dbl> 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, …
$ romantic_yes                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, …

Calculate Shapley values for each observation (in your reduced feature set) and glimpse() the resulting tibble.

get_shaps <- function(df1){
  predict_parts(explain_full, 
                new_observation = df1,
                type = "shap",
                B = 25) |> 
    filter(B == 0) |> 
    select(variable_name, variable_value, contribution) |> 
    as_tibble()
}

tictoc::tic()
local_shaps <- cache_rds(
  expr = {
    x_sample |>
      mutate(id = row_number()) |>
      nest(.by = id, .key = "dfs") |>  
      mutate(shapleys = map(dfs, \(df1) get_shaps(df1))) |>
      select(-dfs) |>
      unnest(shapleys)
  },
  rerun = rerun_setting,
  dir = "cache/",
  file = "shaps_local")
tictoc::toc()
0.01 sec elapsed
local_shaps |> 
  glimpse()
Rows: 880
Columns: 4
$ id             <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ variable_name  <chr> "absences", "activities_yes", "address_rural", "age", "…
$ variable_value <chr> "0.8747", "1", "0", "13.32", "0", "5.576", "0", "1", "1…
$ contribution   <dbl> 0.054730727, -0.093000006, 0.084335484, -0.062996325, 0…

Plot the mean absolute Shapley values across participants

local_shaps |>
  mutate(contribution = abs(contribution)) |>
  group_by(variable_name) |>
  summarize(mean_shap = mean(contribution)) |>
  arrange(desc(mean_shap)) |>
  mutate(variable_name = factor(variable_name),
         variable_name = fct_reorder(variable_name, mean_shap)) |>
  ggplot(aes(x = variable_name, y = mean_shap)) +
  geom_point() +
  coord_flip()

Top takeaways from this plot:
We can see that across a subset of participants, previous class failures and being male are the variables with the most impact on grade prediction. These variables contribute an average absolute change of around 1.0 and .47 points, respectively, from the mean predicted grade. If we look at our plot of Shapley values for participant 395, we can see which direction these features went for this individual. For example, for this participant we can see that not having previous failures (failures_yes = 0) and being male were influential features with a positive effect on the predicted grade. However, the amount of importance and even the direction may vary quite a bit across individuals. Sina plots can be helpful to visualize this variance; a sketch is included below. One other thing I notice when looking at our focal variables: mother's education level appears relatively more important than father's education. It could be interesting to explore that more!