import re
import pandas as pd
import numpy as np
from sharp.utils._utils import _optional_import
blue_rgb = np.array([0, 0.54337757, 0.98337906])
light_blue_rgb = np.array([127.0, 196, 252]) / 255
red_rgb = np.array([1, 0, 0.31796406])
light_red_rgb = np.array([255.0, 127, 167]) / 255
def format_value(s, format_str):
"""Strips trailing zeros and uses a unicode minus sign."""
if not issubclass(type(s), str):
s = format_str % s
s = re.sub(r"\.?0+$", "", s)
if s[0] == "-":
s = "\u2212" + s[1:]
return s
def _waterfall(shap_values, max_display=10, show=False): # noqa
matplotlib = _optional_import("matplotlib")
plt = _optional_import("matplotlib.pyplot")
# Turn off interactive plot
if show is False:
plt.ioff()
base_values = float(shap_values["base_values"])
features = (
np.array(shap_values["features"])
if shap_values["features"] is not None
else np.array(shap_values["values"])
)
feature_names = shap_values["feature_names"]
# lower_bounds = shap_values["lower_bounds"]
# upper_bounds = shap_values["upper_bounds"]
values = shap_values["values"]
# init variables we use for tracking the plot locations
num_features = min(max_display, len(values))
# row_height = 0.5
rng = range(num_features - 1, -1, -1)
order = np.argsort(-np.abs(values))
pos_lefts = []
pos_inds = []
pos_widths = []
pos_low = []
pos_high = []
neg_lefts = []
neg_inds = []
neg_widths = []
neg_low = []
neg_high = []
loc = base_values + values.sum()
yticklabels = ["" for _ in range(num_features + 1)]
# size the plot based on how many features we are plotting
# plt.gcf().set_size_inches(8, num_features * row_height + 1.5)
# see how many individual (vs. grouped at the end) features we are plotting
if num_features == len(values):
num_individual = num_features
else:
num_individual = num_features - 1
# compute the locations of the individual features and plot the dashed connecting
# lines
for i in range(num_individual):
sval = values.iloc[order.iloc[i]]
loc -= sval
if sval >= 0:
pos_inds.append(rng[i])
pos_widths.append(sval)
# if lower_bounds is not None:
# pos_low.append(lower_bounds[order[i]])
# pos_high.append(upper_bounds[order[i]])
pos_lefts.append(loc)
else:
neg_inds.append(rng[i])
neg_widths.append(sval)
# if lower_bounds is not None:
# neg_low.append(lower_bounds[order[i]])
# neg_high.append(upper_bounds[order[i]])
neg_lefts.append(loc)
if num_individual != num_features or i + 4 < num_individual:
plt.plot(
[loc, loc],
[rng[i] - 1 - 0.4, rng[i] + 0.4],
color="#bbbbbb",
linestyle="--",
linewidth=0.5,
zorder=-1,
)
if features is None:
yticklabels[rng[i]] = feature_names[order.iloc[i]]
else:
if np.issubdtype(type(features[order.iloc[i]]), np.number):
yticklabels[rng[i]] = (
format_value(float(features[order.iloc[i]]), "%0.03f")
+ " = "
+ feature_names[order.iloc[i]]
)
else:
yticklabels[rng[i]] = (
str(features[order.iloc[i]])
+ " = "
+ str(feature_names[order.iloc[i]])
)
# add a last grouped feature to represent the impact of all the features we didn't
# show
if num_features < len(values):
yticklabels[0] = "%d other features" % (len(values) - num_features + 1)
remaining_impact = base_values - loc
if remaining_impact < 0:
pos_inds.append(0)
pos_widths.append(-remaining_impact)
pos_lefts.append(loc + remaining_impact)
else:
neg_inds.append(0)
neg_widths.append(-remaining_impact)
neg_lefts.append(loc + remaining_impact)
points = (
pos_lefts
+ list(np.array(pos_lefts) + np.array(pos_widths))
+ neg_lefts
+ list(np.array(neg_lefts) + np.array(neg_widths))
)
dataw = np.max(points) - np.min(points)
# draw invisible bars just for sizing the axes
label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths])
plt.barh(
pos_inds,
np.array(pos_widths) + label_padding + 0.02 * dataw,
left=np.array(pos_lefts) - 0.01 * dataw,
color=red_rgb,
alpha=0,
)
label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths])
plt.barh(
neg_inds,
np.array(neg_widths) + label_padding - 0.02 * dataw,
left=np.array(neg_lefts) + 0.01 * dataw,
color=blue_rgb,
alpha=0,
)
# define variable we need for plotting the arrows
head_length = 0.08
bar_width = 0.8
xlen = plt.xlim()[1] - plt.xlim()[0]
fig = plt.gcf()
ax = plt.gca()
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
width = bbox.width
bbox_to_xscale = xlen / width
hl_scaled = bbox_to_xscale * head_length
renderer = fig.canvas.get_renderer()
# draw the positive arrows
for i in range(len(pos_inds)):
dist = pos_widths[i]
arrow_obj = plt.arrow(
pos_lefts[i],
pos_inds[i],
max(dist - hl_scaled, 0.000001),
0,
head_length=min(dist, hl_scaled),
color=red_rgb,
width=bar_width,
head_width=bar_width,
)
if pos_low is not None and i < len(pos_low):
plt.errorbar(
pos_lefts[i] + pos_widths[i],
pos_inds[i],
xerr=np.array(
[[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]
),
ecolor=light_red_rgb,
)
txt_obj = plt.text(
pos_lefts[i] + 0.5 * dist,
pos_inds[i],
format_value(pos_widths[i], "%+0.02f"),
horizontalalignment="center",
verticalalignment="center",
color="white",
fontsize=12,
)
text_bbox = txt_obj.get_window_extent(renderer=renderer)
arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
# if the text overflows the arrow then draw it after the arrow
if text_bbox.width > arrow_bbox.width:
txt_obj.remove()
txt_obj = plt.text(
pos_lefts[i] + (5 / 72) * bbox_to_xscale + dist,
pos_inds[i],
format_value(pos_widths[i], "%+0.02f"),
horizontalalignment="left",
verticalalignment="center",
color=red_rgb,
fontsize=12,
)
# draw the negative arrows
for i in range(len(neg_inds)):
dist = neg_widths[i]
arrow_obj = plt.arrow(
neg_lefts[i],
neg_inds[i],
-max(-dist - hl_scaled, 0.000001),
0,
head_length=min(-dist, hl_scaled),
color=blue_rgb,
width=bar_width,
head_width=bar_width,
)
if neg_low is not None and i < len(neg_low):
plt.errorbar(
neg_lefts[i] + neg_widths[i],
neg_inds[i],
xerr=np.array(
[[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]
),
ecolor=light_blue_rgb,
)
txt_obj = plt.text(
neg_lefts[i] + 0.5 * dist,
neg_inds[i],
format_value(neg_widths[i], "%+0.02f"),
horizontalalignment="center",
verticalalignment="center",
color="white",
fontsize=12,
)
text_bbox = txt_obj.get_window_extent(renderer=renderer)
arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
# if the text overflows the arrow then draw it after the arrow
if text_bbox.width > arrow_bbox.width:
txt_obj.remove()
txt_obj = plt.text(
neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist,
neg_inds[i],
format_value(neg_widths[i], "%+0.02f"),
horizontalalignment="right",
verticalalignment="center",
color=blue_rgb,
fontsize=12,
)
# draw the y-ticks twice, once in gray and then again with just the feature names in
# black
# The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
ytick_pos = list(range(num_features)) + list(np.arange(num_features) + 1e-8)
plt.yticks(
ytick_pos,
yticklabels[:-1] + [l_.split("=")[-1] for l_ in yticklabels[:-1]],
fontsize=13,
)
# put horizontal lines for each feature row
for i in range(num_features):
plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
# mark the prior expected value and the model prediction
plt.axvline(
base_values,
0,
1 / num_features,
color="#bbbbbb",
linestyle="--",
linewidth=0.5,
zorder=-1,
)
fx = base_values + values.sum()
plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)
# clean up the main axis
plt.gca().xaxis.set_ticks_position("bottom")
plt.gca().yaxis.set_ticks_position("none")
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)
plt.gca().spines["left"].set_visible(False)
ax.tick_params(labelsize=13)
# plt.xlabel("\nModel output", fontsize=12)
# draw the E[f(X)] tick mark
xmin, xmax = ax.get_xlim()
ax2 = ax.twiny()
ax2.set_xlim(xmin, xmax)
ax2.set_xticks(
[base_values, base_values + 1e-8]
) # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
ax2.set_xticklabels(
["\n$E[f(X)]$", "\n$ = " + format_value(base_values, "%0.03f") + "$"],
fontsize=12,
ha="left",
)
ax2.spines["right"].set_visible(False)
ax2.spines["top"].set_visible(False)
ax2.spines["left"].set_visible(False)
# draw the f(x) tick mark
ax3 = ax2.twiny()
ax3.set_xlim(xmin, xmax)
# The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
ax3.set_xticks([base_values + values.sum(), base_values + values.sum() + 1e-8])
ax3.set_xticklabels(
["$f(x)$", "$ = " + format_value(fx, "%0.03f") + "$"], fontsize=12, ha="left"
)
tick_labels = ax3.xaxis.get_majorticklabels()
tick_labels[0].set_transform(
tick_labels[0].get_transform()
+ matplotlib.transforms.ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans)
)
tick_labels[1].set_transform(
tick_labels[1].get_transform()
+ matplotlib.transforms.ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans)
)
tick_labels[1].set_color("#999999")
ax3.spines["right"].set_visible(False)
ax3.spines["top"].set_visible(False)
ax3.spines["left"].set_visible(False)
# adjust the position of the E[f(X)] = x.xx label
tick_labels = ax2.xaxis.get_majorticklabels()
tick_labels[0].set_transform(
tick_labels[0].get_transform()
+ matplotlib.transforms.ScaledTranslation(-20 / 72.0, 0, fig.dpi_scale_trans)
)
tick_labels[1].set_transform(
tick_labels[1].get_transform()
+ matplotlib.transforms.ScaledTranslation(
22 / 72.0, -1 / 72.0, fig.dpi_scale_trans
)
)
tick_labels[1].set_color("#999999")
# color the y tick labels that have the feature values as gray
# (these fall behind the black ones with just the feature name)
tick_labels = ax.yaxis.get_majorticklabels()
for i in range(num_features):
tick_labels[i].set_color("#999999")
if show:
plt.show()
else:
return plt.gcf()
[docs]
def waterfall(xai, idx, mean_shapley_value=0):
# ind_values = xai.qoi_.X[idx]
ind_cont = xai.individual(idx)
feature_names = xai.feature_names_.astype(str).tolist()
rank_dict = {
"upper_bounds": None,
"lower_bounds": None,
"features": None, # pd.Series(feature_names),
"data": None, # pd.Series(ind_values, index=feature_names),
"base_values": mean_shapley_value,
"feature_names": feature_names,
"values": pd.Series(ind_cont, index=feature_names),
}
return _waterfall(rank_dict, max_display=10)