Model Explainability with SHAP in Python — Deep Dive
SHAP for Tree-Based Models
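TreeExplainer computes exact SHAP values for tree ensembles (XGBoost, LightGBM, CatBoost, scikit-learn trees). It uses a polynomial-time algorithm, Tree SHAP, that exploits the tree structure instead of enumerating every feature coalition. That makes it practical for production use.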
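A brute-force computation makes the cost concrete. The sketch below is illustrative and is not the shap library's code. `exact_shapley` is a hypothetical helper, and it assumes an interventional value function: features outside a coalition take their values from background rows. TreeExplainer's default path-dependent mode instead uses the trees' training cover. Because the sketch enumerates every coalition, the number of marginal contributions grows as roughly M·2^(M−1) for M features.

```python
import itertools
import math
from typing import Callable, Sequence


def exact_shapley(
    predict: Callable[[Sequence[float]], float],
    x: Sequence[float],
    background: Sequence[Sequence[float]],
) -> list[float]:
    """Exact Shapley values by enumerating all coalitions (exponential cost)."""
    m = len(x)

    def value(subset: frozenset[int]) -> float:
        # Features in the coalition take x's values; the rest come from each
        # background row, and the model output is averaged over those rows.
        total = 0.0
        for row in background:
            z = [x[i] if i in subset else row[i] for i in range(m)]
            total += predict(z)
        return total / len(background)

    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight |S|! (M - |S| - 1)! / M!
            weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
            for subset in itertools.combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi
```

Take the toy model f(z) = 2·z₀ + 3·z₁·z₂, explained at x = (1, 1, 1) against an all-zero background. The function returns (2.0, 1.5, 1.5). The interaction term's credit is split evenly between features 1 and 2, and the values sum to f(x) − E[f(background)] = 5.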
XGBoost Example
```python
import pandas as pd
import shap
import xgboost as xgb

# Train model
X_train = pd.read_parquet("data/train_features.parquet")
y_train = pd.read_parquet("data/train_labels.parquet")["target"]
model = xgb.XGBClassifier(n_estimators=200, max_depth=6)
model.fit(X_train, y_train)

# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)
# shap_values shape: (n_samples, n_features) for a binary classifier
# Each value = contribution of that feature to that prediction, in log-odds units

print(f"Base value (mean model output, log-odds): {explainer.expected_value:.4f}")
print(f"SHAP values shape: {shap_values.shape}")
```
Explaining a Single Prediction
```python
import numpy as np

# Explain prediction for the first sample
sample_idx = 0
sample = X_train.iloc[[sample_idx]]
prediction = model.predict_proba(sample)[0][1]

# Local accuracy: base value + sum of SHAP values = model output in log-odds,
# which the logistic sigmoid maps back to the predicted probability
log_odds = explainer.expected_value + shap_values[sample_idx].sum()
print(f"Predicted probability: {prediction:.4f}")
print(f"Base value (log-odds): {explainer.expected_value:.4f}")
print(f"Sum of SHAP values: {shap_values[sample_idx].sum():.4f}")
print(f"sigmoid(base + SHAP sum): {1 / (1 + np.exp(-log_odds)):.4f}")

# Top contributing features for this prediction
feature_contributions = pd.Series(
    shap_values[sample_idx], index=X_train.columns
).sort_values(key=abs, ascending=False)

print("\nTop 5 features:")
for feat, val in feature_contributions.head().items():
    direction = "↑" if val > 0 else "↓"
    print(f"  {feat}: {val:+.4f} {direction}")
```
Visualization
```python
# Calling the explainer returns a shap.Explanation that bundles SHAP values,
# base values, feature data, and feature names (needed for slicing by name)
explanation = explainer(X_train)

# Beeswarm: global feature importance with direction
shap.plots.beeswarm(explanation)

# Waterfall: single prediction breakdown
shap.plots.waterfall(explanation[sample_idx])

# Dependence plot: SHAP values of "income" against its feature value,
# colored by "age" to surface a possible interaction
shap.plots.scatter(explanation[:, "income"], color=explanation[:, "age"])
```
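By default, the beeswarm plot orders features by their mean absolute SHAP value across samples. You can compute the same ranking directly from the `shap_values` array, which is handy for logging or dashboards. `global_importance` is a hypothetical helper name.

```python
import numpy as np


def global_importance(shap_values: np.ndarray, feature_names: list[str]) -> list[tuple[str, float]]:
    """Rank features by mean |SHAP| across samples, largest first."""
    mean_abs = np.abs(shap_values).mean(axis=0)
    order = np.argsort(-mean_abs)  # descending order of importance
    return [(feature_names[i], float(mean_abs[i])) for i in order]
```

For example, `global_importance(shap_values, X_train.columns.tolist())[:10]` lists the ten features with the largest average impact.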
SHAP for Linear Models
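LinearExplainer covers linear models, including regularized ones such as L1- or L2-penalized logistic regression. Raw coefficients ignore each feature's baseline and any correlations between features. LinearExplainer measures contributions relative to the background data and can optionally account for correlations: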
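```python
from sklearn.linear_model import LogisticRegression

lr_model = LogisticRegression(C=0.1)  # small C = strong L2 regularization
lr_model.fit(X_train, y_train)

# X_train serves as background data; its feature means define the baseline
explainer = shap.LinearExplainer(lr_model, X_train)
shap_values = explainer.shap_values(X_train)  # log-odds units for LogisticRegression
```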
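For uncorrelated features, or with the default interventional setting that treats features as independent, LinearExplainer's value for feature i is coefficient × (feature value − feature mean). A feature sitting exactly at its background mean therefore contributes nothing. For correlated features, you can construct the explainer with `feature_perturbation="correlation_dependent"`. That mode redistributes credit using the covariance of the background data, giving attributions that respect the data distribution.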
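Under the interventional assumption, the linear case has a closed form: φᵢ = βᵢ(xᵢ − E[xᵢ]), with the base value equal to the model output at the feature means. The numpy sketch below computes it by hand. `linear_shap` is a hypothetical helper. For logistic regression the result is in log-odds space, and `coef`/`intercept` would be `lr_model.coef_[0]` and `lr_model.intercept_[0]`.

```python
import numpy as np


def linear_shap(coef: np.ndarray, intercept: float, X_background: np.ndarray, X: np.ndarray):
    """Interventional SHAP values for f(x) = coef @ x + intercept."""
    mean = X_background.mean(axis=0)
    base_value = float(coef @ mean + intercept)  # E[f(X)] equals f(E[X]) for a linear model
    shap_values = (X - mean) * coef              # broadcasts over rows of X
    return base_value, shap_values
```

Local accuracy holds by construction: `base_value + shap_values.sum(axis=1)` equals `X @ coef + intercept`.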
SHAP for Deep Learning
DeepExplainer (PyTorch)
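```python
import shap
import torch

# Assumes `model` is a trained torch.nn.Module and X_train / X_test are DataFrames
model.eval()

# Background: reference samples the explanations are measured against.
# A random sample avoids bias from any ordering in the first rows.
background = torch.tensor(X_train.sample(n=100, random_state=0).values, dtype=torch.float32)
test_samples = torch.tensor(X_test.iloc[:20].values, dtype=torch.float32)

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(test_samples)
```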
DeepExplainer (Deep SHAP) adapts DeepLIFT's modified backpropagation rules to approximate SHAP values relative to the background samples. It is much faster than KernelExplainer but works only with neural networks.
GradientExplainer (Alternative)
```python
# Same model, background, and test_samples as the DeepExplainer example
explainer = shap.GradientExplainer(model, background)
shap_values = explainer.shap_values(test_samples)
```
GradientExplainer uses expected gradients, an extension of integrated gradients. Instead of relying on a single baseline, it samples reference points from the background distribution. It tends to produce smoother explanations for image and text models.
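Expected gradients estimates φᵢ = E over x′ ∼ background, α ∼ U(0, 1) of (xᵢ − x′ᵢ) · ∂f/∂xᵢ(x′ + α(x − x′)). The numpy sketch below is a Monte Carlo version of that estimator for a single input. It is not shap's GradientExplainer implementation. `expected_gradients` and the user-supplied `grad_fn` (returning ∂f/∂z for a 1-D input) are hypothetical names.

```python
import numpy as np


def expected_gradients(grad_fn, x: np.ndarray, background: np.ndarray,
                       n_samples: int = 2000, seed: int = 0) -> np.ndarray:
    """Monte Carlo estimate of expected-gradients attributions for one input x."""
    rng = np.random.default_rng(seed)
    total = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        ref = background[rng.integers(len(background))]  # reference drawn from background
        alpha = rng.uniform()                            # random point on the path ref -> x
        point = ref + alpha * (x - ref)
        total += (x - ref) * grad_fn(point)              # (x - x') * gradient at that point
    return total / n_samples
```

For a linear model f(z) = w·z the gradient is constant, so each sample contributes w * (x − ref) and the estimate reduces to the linear SHAP formula averaged over the sampled references.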
KernelExplainer: Model-Agnostic
When no specialized explainer exists, KernelExplainer works with any model that exposes a predict function:
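```python
# Works with any callable that maps a 2-D array of rows to predictions;
# `model` here is any classifier with predict_proba (e.g. the XGBoost model)
def predict_fn(X):
    return model.predict_proba(X)[:, 1]

# A 100-row background sample keeps the number of model evaluations manageable
explainer = shap.KernelExplainer(predict_fn, shap.sample(X_train, 100))
shap_values = explainer.shap_values(X_test.iloc[:10])
```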
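KernelExplainer is far slower than the specialized explainers. Each explanation requires many model evaluations, roughly the number of sampled coalitions times the size of the background set. Use it only for small datasets or when no tree, linear, or deep explainer applies.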
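Under the hood, Kernel SHAP fits a weighted linear regression over feature coalitions using the Shapley kernel π(z) = (M − 1) / (C(M, |z|) · |z| · (M − |z|)). The numpy sketch below enumerates every coalition of a small game and solves that regression. In theory the empty and full coalitions get infinite weight; here a large finite weight stands in for it. With full enumeration the regression recovers the exact Shapley values. The real KernelExplainer samples coalitions instead and evaluates each one against the background data. `kernel_shap`, `shapley_kernel`, and `value_fn` are hypothetical names.

```python
import itertools
import math

import numpy as np

LARGE_WEIGHT = 1e6  # stand-in for the infinite weight on the empty and full coalitions


def shapley_kernel(m: int, s: int) -> float:
    """Shapley kernel weight for a coalition of size s out of m features."""
    if s == 0 or s == m:
        return LARGE_WEIGHT
    return (m - 1) / (math.comb(m, s) * s * (m - s))


def kernel_shap(value_fn, m: int):
    """Return (phi_0, phi) by weighted least squares over all 2**m coalitions.

    value_fn(frozenset_of_feature_indices) -> float is the coalition's value.
    """
    rows, targets, weights = [], [], []
    for s in range(m + 1):
        for subset in itertools.combinations(range(m), s):
            z = np.zeros(m)
            for i in subset:
                z[i] = 1.0  # binary indicator: feature i is "present"
            rows.append(z)
            targets.append(value_fn(frozenset(subset)))
            weights.append(shapley_kernel(m, s))
    design = np.column_stack([np.ones(len(rows)), np.array(rows)])  # intercept = phi_0
    sqrt_w = np.sqrt(np.array(weights))
    # Weighted least squares: scale each row and target by sqrt(weight)
    coef, *_ = np.linalg.lstsq(design * sqrt_w[:, None], np.array(targets) * sqrt_w, rcond=None)
    return float(coef[0]), coef[1:]
```

Consider a game that is additive plus a bonus of 4 when features 0 and 1 appear together. The regression splits that bonus evenly between the two features, just as exact Shapley values do.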
Feature Interaction Detection
SHAP interaction values extend standard SHAP values to capture pairwise feature interactions:
```python
import numpy as np

# Interaction values are only available from TreeExplainer
# (`model` here is the XGBoost classifier from the tree example)
tree_explainer = shap.TreeExplainer(model)
interaction_values = tree_explainer.shap_interaction_values(X_train.iloc[:500])
# Shape: (n_samples, n_features, n_features)
# Diagonal: main effect of each feature
# Off-diagonal: interaction effect, split equally between (i, j) and (j, i)
# Summing a row over j recovers that feature's standard SHAP value

# Find strongest interactions
mean_interactions = np.abs(interaction_values).mean(axis=0)
np.fill_diagonal(mean_interactions, 0)  # Remove main effects

# Print feature pairs above a threshold
features = X_train.columns
for i in range(len(features)):
    for j in range(i + 1, len(features)):
        if mean_interactions[i, j] > 0.01:  # arbitrary, dataset-dependent threshold
            print(f"{features[i]} × {features[j]}: {mean_interactions[i, j]:.4f}")
```
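Two properties of the interaction tensor are worth checking in practice. First, each sample's matrix sums across a row to that feature's standard SHAP value. Second, a pair's effect is split evenly between the (i, j) and (j, i) cells. The numpy sketch below verifies the first property and ranks pairs by their combined off-diagonal magnitude instead of applying a fixed threshold. `check_interaction_additivity` and `rank_interactions` are hypothetical helpers.

```python
import numpy as np


def check_interaction_additivity(interaction_values: np.ndarray, shap_values: np.ndarray,
                                 atol: float = 1e-4) -> bool:
    """Row sums of each (n_features x n_features) matrix should equal the SHAP values."""
    return bool(np.allclose(interaction_values.sum(axis=2), shap_values, atol=atol))


def rank_interactions(interaction_values: np.ndarray, feature_names: list[str], top_k: int = 5):
    """Rank feature pairs by mean |phi_ij| + mean |phi_ji| (both halves of the pair effect)."""
    mean_abs = np.abs(interaction_values).mean(axis=0)
    n = len(feature_names)
    pairs = [
        (feature_names[i], feature_names[j], float(mean_abs[i, j] + mean_abs[j, i]))
        for i in range(n)
        for j in range(i + 1, n)
    ]
    return sorted(pairs, key=lambda p: p[2], reverse=True)[:top_k]
```

Here `shap_values` must be the TreeExplainer SHAP values for the same rows as `interaction_values`.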
Bias Detection with SHAP
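SHAP can also help flag when a model may be using protected attributes indirectly. The heuristic below assumes the protected attributes are columns in the explained data. It looks for other features whose SHAP values move together with theirs: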
```python
import numpy as np
from scipy.stats import spearmanr


def detect_proxy_discrimination(
    shap_values: np.ndarray,
    feature_names: list[str],
    protected_features: list[str],
    proxy_threshold: float = 0.1,
) -> dict:
    """Detect features that may serve as proxies for protected attributes."""
    protected_indices = [feature_names.index(f) for f in protected_features]
    other_indices = [i for i in range(len(feature_names)) if i not in protected_indices]

    proxies = {}
    for p_idx in protected_indices:
        p_name = feature_names[p_idx]
        for o_idx in other_indices:
            o_name = feature_names[o_idx]
            # Rank correlation between the two features' SHAP columns;
            # a NaN result (e.g. a constant column) fails both checks and is skipped
            corr, p_value = spearmanr(shap_values[:, p_idx], shap_values[:, o_idx])
            if abs(corr) > proxy_threshold and p_value < 0.05:
                proxies[f"{o_name} ↔ {p_name}"] = {
                    "correlation": float(corr),
                    "p_value": float(p_value),
                }
    return proxies
```
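If a non-protected feature (like zip code) has SHAP values strongly correlated with those of a protected feature (like race), the model may be discriminating through a proxy. Treat a flag as a prompt for investigation rather than proof. Correlated attributions can also arise when two features are legitimately related to each other and to the target.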
Production Integration
Serving SHAP Explanations via API
```python
import pandas as pd
import shap
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Pre-load model and explainer once at startup.
# load_model() is a placeholder for however the trained XGBClassifier is loaded.
model = load_model()
explainer = shap.TreeExplainer(model)
# Column order used at training time; request dicts may arrive in any order
FEATURE_NAMES = model.get_booster().feature_names


class PredictionRequest(BaseModel):
    features: dict[str, float]


class ExplanationResponse(BaseModel):
    prediction: float                        # probability of the positive class
    base_value: float                        # expected model output, in log-odds
    feature_contributions: dict[str, float]  # SHAP values, in log-odds
    top_positive: list[dict]
    top_negative: list[dict]


@app.post("/predict-explain", response_model=ExplanationResponse)
def predict_with_explanation(request: PredictionRequest):
    missing = set(FEATURE_NAMES) - request.features.keys()
    if missing:
        raise HTTPException(status_code=422, detail=f"Missing features: {sorted(missing)}")

    # Reorder columns to match training before predicting and explaining
    features_df = pd.DataFrame([request.features])[FEATURE_NAMES]
    prediction = float(model.predict_proba(features_df)[0][1])
    shap_vals = explainer.shap_values(features_df)[0]

    contributions = dict(zip(FEATURE_NAMES, shap_vals.tolist()))
    sorted_contribs = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True)
    top_pos = [{"feature": k, "impact": v} for k, v in sorted_contribs if v > 0][:5]
    top_neg = [{"feature": k, "impact": v} for k, v in sorted_contribs if v < 0][:5]

    return ExplanationResponse(
        prediction=prediction,
        base_value=float(explainer.expected_value),
        feature_contributions=contributions,
        top_positive=top_pos,
        top_negative=top_neg,
    )
```
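The endpoint returns `prediction` as a probability. `base_value` and `feature_contributions`, however, come from TreeExplainer's default raw output, which for a binary XGBoost classifier is log-odds. A client or integration test can check that the two agree by passing the additive reconstruction through the logistic sigmoid. The sketch below assumes that binary, log-odds setup. `check_explanation_payload` is a hypothetical helper that operates on the decoded JSON response.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def check_explanation_payload(payload: dict, tol: float = 1e-3) -> bool:
    """True if sigmoid(base_value + sum of contributions) matches the reported probability."""
    log_odds = payload["base_value"] + sum(payload["feature_contributions"].values())
    return abs(_sigmoid(log_odds) - payload["prediction"]) <= tol
```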
Performance Considerations
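Computing SHAP values adds latency to every prediction. The figures below are rough orders of magnitude. Actual cost depends on model size (number and depth of trees, network size), the number of features, and the size of the background set: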
| Explainer | Typical Latency per Sample |
|---|---|
| TreeExplainer | 0.1-5ms |
| LinearExplainer | <0.1ms |
| DeepExplainer | 10-100ms |
| KernelExplainer | 1-60 seconds |
For real-time serving, TreeExplainer and LinearExplainer are fast enough to include inline. For slower explainers, compute explanations asynchronously and cache them.
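A minimal sketch of the asynchronous pattern, using only the standard library, follows. Predictions are returned immediately while explanation jobs run on a thread pool. Each job's `Future` doubles as the cached result, keyed by a request ID. `AsyncExplanationStore`, `explain_fn`, and the request-ID scheme are assumptions for illustration. A deployed service would also need to bound or expire the stored entries, and might prefer a task queue with a shared cache.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Lock
from typing import Any, Callable, Optional


class AsyncExplanationStore:
    """Run slow explanation jobs in the background and keep finished results."""

    def __init__(self, explain_fn: Callable[[Any], Any], max_workers: int = 4):
        self._explain_fn = explain_fn
        self._pool = ThreadPoolExecutor(max_workers=max_workers)
        self._futures: dict[str, Future] = {}
        self._lock = Lock()

    def submit(self, request_id: str, features: Any) -> None:
        """Schedule an explanation; a repeated ID reuses the existing job."""
        with self._lock:
            if request_id not in self._futures:
                self._futures[request_id] = self._pool.submit(self._explain_fn, features)

    def is_ready(self, request_id: str) -> bool:
        with self._lock:
            future = self._futures.get(request_id)
        return future is not None and future.done()

    def get(self, request_id: str, timeout: Optional[float] = None) -> Any:
        """Wait up to `timeout` seconds (None = forever) for the result.

        Raises KeyError for unknown IDs and TimeoutError if the job is not done in time.
        """
        with self._lock:
            future = self._futures[request_id]
        return future.result(timeout=timeout)

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)
```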
Caching Explanations
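```python
from functools import lru_cache

import numpy as np
import pandas as pd

# Column order the model was trained with (XGBoost model from earlier)
FEATURE_NAMES = model.get_booster().feature_names


@lru_cache(maxsize=10_000)
def _cached_shap_values(feature_key: tuple[float, ...]) -> tuple[float, ...]:
    """Cache SHAP explanations for repeated feature vectors."""
    row = pd.DataFrame([feature_key], columns=FEATURE_NAMES)
    # Return an immutable tuple so callers cannot mutate the cached result
    return tuple(explainer.shap_values(row)[0].tolist())


def get_explanation(features_df: pd.DataFrame) -> np.ndarray:
    # A hash cannot be decoded back into features, so the cache key is the
    # hashable tuple of the row's values, in training-column order
    key = tuple(float(v) for v in features_df[FEATURE_NAMES].iloc[0])
    return np.array(_cached_shap_values(key))
```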
SHAP vs Other Explanation Methods
| Method | Theoretical Foundation | Consistency | Speed | Model-Agnostic |
|---|---|---|---|---|
| SHAP | Shapley values (game theory) | Guaranteed for exact algorithms (e.g. TreeExplainer) | Varies by explainer | Yes (via KernelExplainer) |
| LIME | Local linear approximation | Not guaranteed | Fast | Yes |
| Impurity-based feature importance | Mean decrease in impurity | Not guaranteed; biased toward high-cardinality features | Fast | No (trees only) |
| Permutation importance | Prediction change on shuffle | Affected by correlated features | Medium | Yes |
| Integrated Gradients | Path integral of gradients | Satisfies completeness (attributions sum to the output difference) | Medium | No (differentiable models only) |
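SHAP’s mathematical guarantees (local accuracy, consistency, and missingness) make it the strongest choice for regulated environments where explanations must be defensible. These guarantees hold exactly for exact algorithms such as TreeExplainer. KernelExplainer and DeepExplainer produce approximations of SHAP values, so there the guarantees hold only approximately.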
One thing to remember: SHAP turns any model from an opaque decision-maker into an auditable system. Every prediction comes with a mathematically grounded, additive breakdown of how much each feature pushed the outcome up or down.