ShaRP for classification on large datasets with mixed data types

This example showcases a more complex setting, where we will develop and interpret a classification model using a larger dataset with both categorical and continuous features.

sharp is designed to operate over the unprocessed input space, to ensure that every “Frankenstein” point generated to compute feature contributions is plausible. This means that the function producing the scores (or class predictions) should take the raw dataset as input, and every preprocessing step leading to the black-box predictions/scores should be included within it.
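To make the idea of a “Frankenstein” point concrete, the sketch below builds one by hand: it copies the sample being explained and swaps in some feature values from a reference sample, all in the raw input space. It is only an illustration, not ShaRP's internal implementation; the helper frankenstein and the two toy rows are hypothetical. The second half shows why mixing in one-hot space is problematic: taking a single indicator column from the reference can leave a row with two active categories.

```python
import pandas as pd

# Two raw (unprocessed) rows with mixed feature types: the sample being
# explained and a reference sample. Values are hypothetical.
sample = pd.Series({"checking_status": "<0", "duration": 6, "purpose": "radio/tv"})
reference = pd.Series({"checking_status": "no checking", "duration": 12, "purpose": "education"})


def frankenstein(sample, reference, features):
    """Hypothetical helper: copy `sample`, then take `features` from `reference`."""
    hybrid = sample.copy()
    hybrid[features] = reference[features]
    return hybrid


# In raw space, the hybrid is still a valid combination of categories.
hybrid = frankenstein(sample, reference, ["checking_status"])
print(hybrid.to_dict())
# {'checking_status': 'no checking', 'duration': 6, 'purpose': 'radio/tv'}

# In one-hot space, swapping only one indicator column yields an implausible
# row in which two categories of checking_status are active at once.
encoded = pd.get_dummies(pd.DataFrame([sample, reference])["checking_status"]).astype(int)
bad = encoded.iloc[0].copy()
bad["no checking"] = encoded.iloc[1]["no checking"]
print(int(bad.sum()))  # 2 active categories
```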

In this example we use the German Credit dataset. We start by importing the required libraries.

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sharp import ShaRP

Next, let’s load the data from OpenML (data ID 31). The German Credit dataset classifies people, described by a set of attributes, as good or bad credit risks.

df = fetch_openml(data_id=31, parser="auto")["frame"]
df.head(5)
  | checking_status | duration | credit_history | purpose | credit_amount | savings_status | employment | installment_commitment | personal_status | other_parties | residence_since | property_magnitude | age | other_payment_plans | housing | existing_credits | job | num_dependents | own_telephone | foreign_worker | class
0 | <0 | 6 | critical/other existing credit | radio/tv | 1169 | no known savings | >=7 | 4 | male single | none | 4 | real estate | 67 | none | own | 2 | skilled | 1 | yes | yes | good
1 | 0<=X<200 | 48 | existing paid | radio/tv | 5951 | <100 | 1<=X<4 | 2 | female div/dep/mar | none | 2 | real estate | 22 | none | own | 1 | skilled | 1 | none | yes | bad
2 | no checking | 12 | critical/other existing credit | education | 2096 | <100 | 4<=X<7 | 2 | male single | none | 3 | real estate | 49 | none | own | 1 | unskilled resident | 2 | none | yes | good
3 | <0 | 42 | existing paid | furniture/equipment | 7882 | <100 | 4<=X<7 | 2 | male single | guarantor | 4 | life insurance | 45 | none | for free | 1 | skilled | 2 | none | yes | good
4 | <0 | 24 | delayed previously | new car | 4870 | <100 | 1<=X<4 | 3 | male single | none | 4 | no known property | 53 | none | for free | 2 | skilled | 2 | none | yes | bad


Separate the input features X from the target y, then split the data into training and test sets:

X = df.drop(columns="class")
y = df["class"]

categorical_features = X.dtypes.apply(
    lambda dtype: isinstance(dtype, pd.CategoricalDtype)
).values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=42
)
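The categorical_features mask above is a numpy boolean array with one entry per column, which is what lets the column transformer below use both the mask and its negation (~categorical_features) to select columns. The sketch below applies the same expression to a tiny hypothetical DataFrame so the result is easy to inspect.

```python
import pandas as pd

# Hypothetical toy frame: one categorical column and one numeric column.
X_toy = pd.DataFrame({
    "purpose": pd.Categorical(["radio/tv", "education", "new car"]),
    "duration": [6, 12, 24],
})

# Same expression as above: True where the column has a pandas categorical dtype.
mask = X_toy.dtypes.apply(lambda dtype: isinstance(dtype, pd.CategoricalDtype)).values

print(mask)                            # [ True False]
print(X_toy.columns[mask].tolist())    # ['purpose']
print(X_toy.columns[~mask].tolist())   # ['duration'] (~ works because mask is a numpy array)
```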

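Now we set up the model. Here, we use a pipeline to chain all the preprocessing steps with the classifier; because the model is fitted on X_train.values (a numpy array), the column transformer selects columns with boolean masks rather than column names. A scikit-learn pipeline is not required to use sharp, though: it is sufficient to pass any function that takes a numpy array of raw inputs, applies all the preprocessing steps, and returns the model’s predictions.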

transformer = ColumnTransformer(
    transformers=[
        ("onehot", OneHotEncoder(sparse_output=False), categorical_features),
        ("minmax", MinMaxScaler(), ~categorical_features),
    ],
    remainder="passthrough",
    n_jobs=-1,
)
classifier = LogisticRegression(random_state=42)
model = make_pipeline(transformer, classifier)
model.fit(X_train.values, y_train.values)
Pipeline(steps=[('columntransformer',
                 ColumnTransformer(n_jobs=-1, remainder='passthrough',
                                   transformers=[('onehot',
                                                  OneHotEncoder(sparse_output=False),
                                                  array([ True, False,  True,  True, False,  True,  True, False,  True,
        True, False,  True, False,  True,  True, False,  True, False,
        True,  True])),
                                                 ('minmax', MinMaxScaler(),
                                                  array([False,  True, False, False,  True, False, False,  True, False,
       False,  True, False,  True, False, False,  True, False,  True,
       False, False]))])),
                ('logisticregression', LogisticRegression(random_state=42))])


We can now use sharp to explain our model’s predictions! If the dataset is too large, there are a few options to reduce the computational cost: setting the n_jobs parameter to parallelize the computation, setting a value for sample_size, or setting measure="unary".
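A rough back-of-the-envelope count shows why measure="unary" helps. The sketch assumes that a unary measure intervenes on one feature at a time per reference row, that an exact Shapley computation enumerates every subset of features per reference row (ShaRP may use approximations instead), and that sample_size caps the number of reference rows. The two counting functions are hypothetical and count hybrid points per explained sample. Note that n_jobs does not reduce this count; it only spreads the work across cores.

```python
def unary_evaluations(n_features, n_references):
    # Unary: replace one feature at a time, once per reference row.
    return n_features * n_references


def exact_shapley_evaluations(n_features, n_references):
    # Exact Shapley: every subset of features is a coalition,
    # evaluated once per reference row.
    return 2 ** n_features * n_references


d = 20  # number of input features in the German Credit data
for m in (100, 10):  # all 100 test rows, or a hypothetical sample_size=10
    print(m, unary_evaluations(d, m), exact_shapley_evaluations(d, m))
# 100 2000 104857600
# 10 200 10485760
```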

xai = ShaRP(
    qoi="flip",
    target_function=model.predict,
    measure="unary",
    sample_size=None,
    random_state=42,
    n_jobs=-1,
    verbose=1,
)
xai.fit(X_test)

unary_values = pd.DataFrame(xai.all(X_test), columns=X.columns)
unary_values
100%|██████████| 100/100 [00:05<00:00, 19.47it/s]
   | checking_status | duration | credit_history | purpose | credit_amount | savings_status | employment | installment_commitment | personal_status | other_parties | residence_since | property_magnitude | age | other_payment_plans | housing | existing_credits | job | num_dependents | own_telephone | foreign_worker
0  | 0.00 | 0.20 | 0.08 | 0.30 | 0.08 | 0.00 | 0.0 | 0.49 | 0.07 | 0.00 | 0.0 | 0.16 | 0.00 | 0.17 | 0.2 | 0.03 | 0.00 | 0.0 | 0.00 | 0.00
1  | 0.68 | 0.09 | 0.29 | 0.71 | 0.29 | 0.24 | 0.2 | 0.34 | 0.00 | 0.03 | 0.0 | 0.34 | 0.10 | 0.00 | 0.0 | 0.00 | 0.02 | 0.0 | 0.00 | 0.02
2  | 0.46 | 0.00 | 0.36 | 0.65 | 0.00 | 0.07 | 0.0 | 0.00 | 0.00 | 0.03 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.02
3  | 0.32 | 0.07 | 0.00 | 0.24 | 0.02 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00
4  | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00
...| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ...
95 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00
96 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00
97 | 0.00 | 0.45 | 0.08 | 0.00 | 0.36 | 0.64 | 0.8 | 0.00 | 0.49 | 0.05 | 0.0 | 0.16 | 0.24 | 0.17 | 0.2 | 0.35 | 0.00 | 0.0 | 0.67 | 0.00
98 | 0.36 | 0.00 | 0.29 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00 | 0.0 | 0.00 | 0.00
99 | 0.46 | 0.00 | 0.00 | 0.70 | 0.00 | 0.24 | 0.0 | 0.00 | 0.51 | 0.03 | 0.0 | 0.00 | 0.00 | 0.83 | 0.1 | 0.00 | 0.02 | 0.0 | 0.00 | 0.02

100 rows × 20 columns
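Each entry in the table above lies between 0 and 1. One way to read it, assuming the "flip" QoI combined with the unary measure scores a feature by how often replacing that feature alone (with values taken from the reference rows) changes the predicted class, is sketched below from scratch. This is our reading rather than a description of ShaRP's code; unary_flip_contributions, toy_predict, and the toy rows are hypothetical, so check the ShaRP documentation for the exact definition.

```python
import numpy as np


def unary_flip_contributions(predict, x, X_ref):
    """For each feature j, the fraction of reference rows r for which
    replacing x[j] with r[j] changes the predicted class of x."""
    x = np.asarray(x, dtype=object)
    X_ref = np.asarray(X_ref, dtype=object)
    original = predict(x.reshape(1, -1))[0]
    contributions = np.zeros(x.shape[0])
    for j in range(x.shape[0]):
        hybrids = np.tile(x, (X_ref.shape[0], 1))  # one copy of x per reference row
        hybrids[:, j] = X_ref[:, j]                # intervene on feature j only
        contributions[j] = np.mean(predict(hybrids) != original)
    return contributions


def toy_predict(X):
    # Hypothetical rule: loans longer than 24 months are "bad".
    X = np.asarray(X, dtype=object)
    return np.where(X[:, 1].astype(float) > 24, "bad", "good")


# Columns: checking_status, duration (hypothetical reference rows).
X_ref = np.array([["<0", 6], ["no checking", 48], ["<0", 12], ["0<=X<200", 36]], dtype=object)
x = np.array(["<0", 12], dtype=object)

contrib = unary_flip_contributions(toy_predict, x, X_ref)
print(dict(zip(["checking_status", "duration"], contrib.tolist())))
# {'checking_status': 0.0, 'duration': 0.5}
```

Under this reading, a row of zeros (such as rows 4, 95, and 96 above) would mean that no single-feature intervention changed that prediction.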



Finally, we can plot the mean contributions of each feature:

plt.style.use("seaborn-v0_8-whitegrid")

fig, ax = plt.subplots()
xai.plot.bar(unary_values.mean(), ax=ax)
ax.set_ylim(bottom=0)
ax.tick_params(labelrotation=90)
fig.tight_layout()
plt.show()
[Figure: plot mixed data. Bar plot of the mean unary contribution of each feature.]
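To read the same information as a ranked list rather than a plot, the mean contributions can be sorted with pandas. In the sketch below, the small DataFrame stands in for unary_values and reuses the first three rows of three columns from the table above, purely so the snippet runs on its own; with the real unary_values, the last two lines apply unchanged.

```python
import pandas as pd

# Stand-in for `unary_values`: first three rows of three columns from the table above.
unary_values = pd.DataFrame({
    "checking_status": [0.00, 0.68, 0.46],
    "duration": [0.20, 0.09, 0.00],
    "purpose": [0.30, 0.71, 0.65],
})

# Mean contribution per feature, largest first.
ranking = unary_values.mean().sort_values(ascending=False)
print(ranking.index.tolist())  # ['purpose', 'checking_status', 'duration']
```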

Total running time of the script: (0 minutes 8.865 seconds)

Gallery generated by Sphinx-Gallery