Model Explainability with SHAP in Python — Deep Dive

SHAP for Tree-Based Models

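TreeExplainer computes exact SHAP values for tree ensembles with Tree SHAP, a polynomial-time algorithm specific to decision trees. It avoids the exponential-time enumeration of feature subsets that the Shapley definition implies, which makes it practical for production use.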
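The sketch below computes the same quantity the slow way for a tiny hand-built tree. TREE, predict, expected_value, and brute_force_tree_shap are hypothetical names, not part of the shap API. expected_value follows the cover-weighted "path-dependent" rule from the Tree SHAP paper (TreeExplainer's default when no background data is passed): follow the split when the feature is in the coalition, otherwise average both children weighted by how many training rows reached them. Enumerating every coalition costs exponential time in the number of features; Tree SHAP reaches the same values in polynomial time by tracking how many subsets of each size flow down each path of the tree.

```python
import itertools
import math

# Hypothetical depth-2 tree: internal nodes split on x[feature] <= threshold,
# "cover" is the number of training rows that reached each node.
TREE = {
    "feature": 0, "threshold": 0.5, "cover": 100,
    "left": {
        "feature": 1, "threshold": 0.5, "cover": 60,
        "left": {"value": 1.0, "cover": 40},
        "right": {"value": 3.0, "cover": 20},
    },
    "right": {"value": 5.0, "cover": 40},
}

def predict(node, x):
    """Ordinary tree prediction: follow the decision path down to a leaf."""
    while "value" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["value"]

def expected_value(node, x, known):
    """E[f(x) | x_S], path-dependent: follow splits on known features,
    otherwise average both children weighted by training cover."""
    if "value" in node:
        return node["value"]
    if node["feature"] in known:
        child = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
        return expected_value(child, x, known)
    left, right = node["left"], node["right"]
    return (left["cover"] * expected_value(left, x, known)
            + right["cover"] * expected_value(right, x, known)) / node["cover"]

def brute_force_tree_shap(tree, x, n_features):
    """Exact Shapley values by enumerating every coalition: exponential cost."""
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(n_features):
            for subset in itertools.combinations(others, size):
                weight = (math.factorial(size) * math.factorial(n_features - size - 1)
                          / math.factorial(n_features))
                s = set(subset)
                gain = expected_value(tree, x, s | {i}) - expected_value(tree, x, s)
                phi[i] += weight * gain
    return phi

x = [0.2, 0.1, 7.0]  # feature 2 is never used by the tree
base = expected_value(TREE, x, set())
phi = brute_force_tree_shap(TREE, x, n_features=3)
print(base, phi, predict(TREE, x))
```

Feature 2 never appears in a split, so it receives exactly zero attribution, and base plus the sum of phi reproduces the tree's prediction (local accuracy).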

XGBoost Example

import shap
import xgboost as xgb
import pandas as pd

# Train model
X_train = pd.read_parquet("data/train_features.parquet")
y_train = pd.read_parquet("data/train_labels.parquet")["target"]
model = xgb.XGBClassifier(n_estimators=200, max_depth=6)
model.fit(X_train, y_train)

# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)

# shap_values shape: (n_samples, n_features)
# Each value = contribution of that feature to that prediction,
# in the model's raw output space (log-odds for XGBClassifier)
print(f"Base value (average prediction): {explainer.expected_value:.4f}")
print(f"SHAP values shape: {shap_values.shape}")

Explaining a Single Prediction

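import numpy as np

# Explain prediction for the first sample
sample_idx = 0
sample = X_train.iloc[[sample_idx]]
prediction = model.predict_proba(sample)[0][1]
# Raw model output (log-odds): the space in which SHAP values are additive
log_odds = model.predict(sample, output_margin=True)[0]
reconstructed = explainer.expected_value + shap_values[sample_idx].sum()

print(f"Predicted probability: {prediction:.4f}")
print(f"Model output (log-odds): {log_odds:.4f}")
print(f"Base value (log-odds): {explainer.expected_value:.4f}")
print(f"Sum of SHAP values: {shap_values[sample_idx].sum():.4f}")
# Local accuracy: base value + SHAP sum equals the log-odds output,
# and applying the sigmoid recovers the predicted probability
print(f"Base + SHAP sum: {reconstructed:.4f}")
print(f"sigmoid(Base + SHAP sum): {1 / (1 + np.exp(-reconstructed)):.4f}")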

# Top contributing features for this prediction
feature_contributions = pd.Series(
    shap_values[sample_idx], index=X_train.columns
).sort_values(key=abs, ascending=False)

print("\nTop 5 features:")
for feat, val in feature_contributions.head().items():
    direction = "↑" if val > 0 else "↓"
    print(f"  {feat}: {val:+.4f} {direction}")

Visualization

# Calling the explainer returns a shap.Explanation that bundles SHAP values,
# base values, feature values, and feature names (taken from the DataFrame)
explanation = explainer(X_train)

# Beeswarm: global feature importance with direction
shap.plots.beeswarm(explanation)

# Waterfall: single prediction breakdown
shap.plots.waterfall(explanation[sample_idx])

# Dependence plot: SHAP value of "income" against its value, colored by "age"
# to reveal interactions (assumes X_train has "income" and "age" columns)
shap.plots.scatter(explanation[:, "income"], color=explanation[:, "age"])

SHAP for Linear Models

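LinearExplainer computes SHAP values for linear models, including regularized ones such as the L2-penalized logistic regression below. Raw coefficients alone do not explain an individual prediction: they ignore how far each feature value sits from its typical value and, when features are correlated, how credit should be shared among them: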

from sklearn.linear_model import LogisticRegression

lr_model = LogisticRegression(C=0.1)
lr_model.fit(X_train, y_train)

explainer = shap.LinearExplainer(lr_model, X_train)
shap_values = explainer.shap_values(X_train)

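With the default interventional setting, LinearExplainer treats features as independent, and each SHAP value is coefficient × (feature value − background mean of that feature), expressed in log-odds for logistic regression. Accounting for correlated features requires the correlation-dependent mode (feature_perturbation="correlation_dependent"; newer shap releases express this choice through the masker argument instead), which uses the covariance of the background data to redistribute credit among features that move together, giving fairer attributions.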
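Under the interventional setting the arithmetic can be checked by hand. This NumPy sketch uses hypothetical coefficients, intercept, and background data rather than a fitted LogisticRegression, and linear_shap is an illustrative helper, not a shap function. The SHAP values plus the base value (the model's output at the background mean) add up exactly to the model's log-odds output.

```python
import numpy as np

rng = np.random.default_rng(0)
background = rng.normal(size=(200, 3))   # hypothetical background data
coef = np.array([0.8, -1.5, 0.3])        # hypothetical fitted coefficients
intercept = -0.2

def linear_shap(x, coef, background):
    """Interventional linear SHAP: coef * (x - background mean), per feature."""
    return coef * (x - background.mean(axis=0))

x = np.array([1.0, 0.5, -2.0])
phi = linear_shap(x, coef, background)
base_value = intercept + coef @ background.mean(axis=0)  # output at the mean input
model_output = intercept + coef @ x                       # log-odds for logistic regression

# Local accuracy: base value + sum of SHAP values reproduces the model output
print(phi, base_value + phi.sum(), model_output)
```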

SHAP for Deep Learning

DeepExplainer (PyTorch)

import torch
import shap

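# Assumes `model` is a trained torch.nn.Module that takes float32 feature tensors,
# and X_test is a held-out DataFrame with the same columns as X_train
model.eval()

# Background: random training rows that serve as the reference distribution
background = torch.tensor(X_train.sample(100, random_state=0).values, dtype=torch.float32)
test_samples = torch.tensor(X_test.iloc[:20].values, dtype=torch.float32)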

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(test_samples)

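DeepExplainer uses a modified backpropagation based on DeepLIFT to approximate Shapley values: it propagates attribution multipliers backward through the network for each background sample and averages the results. It is much faster than KernelExplainer but works only with deep learning models (PyTorch and TensorFlow/Keras).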
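The idea can be illustrated on a tiny network. The NumPy sketch below is a simplified illustration, not shap's implementation: a hypothetical one-hidden-layer ReLU network, the DeepLIFT Rescale rule for the ReLU (change in activation divided by change in pre-activation), and an average over background references. Because each reference's contributions sum to f(x) − f(ref), the averaged attributions sum to f(x) minus the mean background output. All weights and function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3))   # hypothetical hidden-layer weights (4 units, 3 inputs)
b1 = rng.normal(size=4)
w2 = rng.normal(size=4)        # hypothetical output weights
b2 = 0.1

def forward(x):
    """One-hidden-layer ReLU network with a scalar output."""
    return w2 @ np.maximum(W1 @ x + b1, 0.0) + b2

def deeplift_rescale(x, ref):
    """DeepLIFT Rescale-rule contributions of each input relative to one reference."""
    z, z_ref = W1 @ x + b1, W1 @ ref + b1
    dz = z - z_ref
    d_act = np.maximum(z, 0.0) - np.maximum(z_ref, 0.0)
    # Multiplier = change in activation / change in pre-activation,
    # falling back to the ReLU gradient when the pre-activation barely moves
    moved = np.abs(dz) > 1e-12
    mult = np.where(moved, d_act / np.where(moved, dz, 1.0), (z > 0).astype(float))
    # Chain multipliers back to the inputs and scale by the input difference
    return ((w2 * mult) @ W1) * (x - ref)

def deep_shap_sketch(x, background):
    """Average single-reference attributions over the background set."""
    return np.mean([deeplift_rescale(x, ref) for ref in background], axis=0)

background = rng.normal(size=(50, 3))
x = np.array([0.5, -1.0, 2.0])
phi = deep_shap_sketch(x, background)
expected = np.mean([forward(ref) for ref in background])
print(phi, expected + phi.sum(), forward(x))
```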

GradientExplainer (Alternative)

explainer = shap.GradientExplainer(model, background)
shap_values = explainer.shap_values(test_samples)

GradientExplainer uses expected gradients, an extension of integrated gradients: instead of integrating gradients along a straight path from a single baseline, it averages over baselines drawn from the background data and over points along each path. It tends to produce smoother explanations for image and text models.
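The estimator can be written out directly for a model whose gradient is known in closed form. In this NumPy sketch, the logistic model f, its weights, and the helper names are hypothetical, and the interpolation points α are placed on a midpoint grid rather than sampled at random, so treat it as an illustration of the expected-gradients formula rather than of shap's code. The attributions satisfy completeness: they sum to f(x) minus the mean background output.

```python
import numpy as np

w = np.array([1.5, -2.0, 0.5])   # hypothetical weights of a differentiable model
b = 0.3

def f(X):
    """Logistic model sigmoid(w . x + b), applied to one row or a batch."""
    return 1.0 / (1.0 + np.exp(-(X @ w + b)))

def grad_f(points):
    """Analytic gradient of f for a batch of points, shape (n, n_features)."""
    s = f(points)
    return (s * (1.0 - s))[:, None] * w

def expected_gradients(x, background, n_alphas=200):
    """Average (x - x') * grad f(x' + alpha * (x - x')) over background rows x'
    and interpolation points alpha in (0, 1)."""
    alphas = (np.arange(n_alphas) + 0.5) / n_alphas   # midpoint grid over (0, 1)
    phi = np.zeros_like(x)
    for ref in background:
        points = ref + alphas[:, None] * (x - ref)    # shape (n_alphas, n_features)
        phi += (x - ref) * grad_f(points).mean(axis=0)
    return phi / len(background)

rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))
x = np.array([1.0, 0.5, -1.0])
phi = expected_gradients(x, background)
# Completeness: attributions sum to f(x) minus the mean background output
print(phi, phi.sum(), f(x) - f(background).mean())
```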

KernelExplainer: Model-Agnostic

When no specialized explainer exists, KernelExplainer works with any model that exposes a predict function:

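import pandas as pd

# Works with any callable that maps a 2-D batch of rows to model outputs.
# KernelExplainer may pass NumPy arrays, so restore column names before predicting.
def predict_fn(X):
    X = pd.DataFrame(X, columns=X_train.columns)
    return model.predict_proba(X)[:, 1]

# Summarize the background to 100 rows: cost grows with background size
explainer = shap.KernelExplainer(predict_fn, shap.sample(X_train, 100))
shap_values = explainer.shap_values(X_test.iloc[:10])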

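KernelExplainer is far slower than the specialized explainers: for every row it explains, it evaluates the model on many sampled feature coalitions, each combined with every background row. Use it only for small batches or when no tree, linear, or deep explainer applies.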
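Kernel SHAP treats each coalition of features as a data point: features outside the coalition are filled in from background rows, the model is evaluated, and a weighted linear regression with Shapley-kernel weights is fitted whose coefficients are the SHAP values. The NumPy sketch below enumerates every coalition of a hypothetical three-feature model, so the regression recovers the exact Shapley values; KernelExplainer samples coalitions when there are too many to enumerate. The helper names are illustrative, and a large finite weight (big_weight, an assumption of this sketch) stands in for the infinite weight on the empty and full coalitions.

```python
import itertools
import math
import numpy as np

def model_fn(X):
    """Hypothetical black-box model with an interaction between features 0 and 1."""
    return X[:, 0] * X[:, 1] + 2.0 * X[:, 2]

def coalition_value(f, x, background, mask):
    """v(S): mean output with features in S set to x, the rest from each background row."""
    hybrid = np.where(mask.astype(bool), x, background)
    return f(hybrid).mean()

def kernel_shap_full(f, x, background, big_weight=1e6):
    """Kernel SHAP regression over *all* coalitions (feasible only for a few features)."""
    m = len(x)
    masks, weights = [], []
    for size in range(m + 1):
        for subset in itertools.combinations(range(m), size):
            mask = np.zeros(m)
            mask[list(subset)] = 1.0
            if size in (0, m):
                w = big_weight  # finite stand-in for the infinite weight
            else:
                w = (m - 1) / (math.comb(m, size) * size * (m - size))  # Shapley kernel
            masks.append(mask)
            weights.append(w)
    Z = np.array(masks)
    y = np.array([coalition_value(f, x, background, z) for z in Z])
    A = np.hstack([np.ones((len(Z), 1)), Z])   # first column fits the base value
    sw = np.sqrt(np.array(weights))
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]

def shapley_exact(f, x, background):
    """Reference answer from the classic Shapley formula."""
    m = len(x)
    phi = np.zeros(m)
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            for subset in itertools.combinations(others, size):
                weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
                mask = np.zeros(m)
                mask[list(subset)] = 1.0
                without_i = coalition_value(f, x, background, mask)
                mask[i] = 1.0
                phi[i] += weight * (coalition_value(f, x, background, mask) - without_i)
    return phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 2.0, -1.0])
base, phi = kernel_shap_full(model_fn, x, background)
print(base, phi)
print(shapley_exact(model_fn, x, background))
```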

Feature Interaction Detection

SHAP interaction values extend standard SHAP values to capture pairwise feature interactions:

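# Only available for TreeExplainer: `model` here is the XGBoost classifier from the first example
tree_explainer = shap.TreeExplainer(model)
interaction_values = tree_explainer.shap_interaction_values(X_train.iloc[:500])
# Shape: (n_samples, n_features, n_features)

# Diagonal: main effects; each row sums to that feature's standard SHAP value
# Off-diagonal: interaction effects, split equally between [i, j] and [j, i]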

# Find strongest interactions
import numpy as np

mean_interactions = np.abs(interaction_values).mean(axis=0)
np.fill_diagonal(mean_interactions, 0)  # Remove main effects

# Print pairs whose mean |interaction| exceeds a threshold ([i, j] holds half of the pair's total)
features = X_train.columns
for i in range(len(features)):
    for j in range(i + 1, len(features)):
        if mean_interactions[i, j] > 0.01:  # threshold
            print(f"{features[i]} × {features[j]}: {mean_interactions[i, j]:.4f}")
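SHAP interaction values follow the Shapley interaction index: for a pair (i, j), they average the second-order difference v(S ∪ {i, j}) − v(S ∪ {i}) − v(S ∪ {j}) + v(S) over coalitions S, split the result equally between [i, j] and [j, i], and put whatever remains of each feature's SHAP value on the diagonal. The brute-force NumPy sketch below uses an interventional value function (background rows fill in missing features) on a hypothetical model where only features 0 and 1 interact; all names are illustrative, and the cost is exponential in the number of features, unlike TreeExplainer's algorithm.

```python
import itertools
import math
import numpy as np

def model_fn(X):
    """Hypothetical model: features 0 and 1 interact, feature 2 is purely additive."""
    return X[:, 0] * X[:, 1] + 2.0 * X[:, 2]

def value(f, x, background, subset):
    """v(S): mean output with features in S fixed to x, the rest from background rows."""
    mask = np.zeros(len(x), dtype=bool)
    mask[list(subset)] = True
    return f(np.where(mask, x, background)).mean()

def shapley_values(f, x, background):
    """Ordinary SHAP values via the Shapley formula."""
    m = len(x)
    phi = np.zeros(m)
    for i in range(m):
        others = [k for k in range(m) if k != i]
        for size in range(m):
            for s in itertools.combinations(others, size):
                weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
                phi[i] += weight * (value(f, x, background, s + (i,)) - value(f, x, background, s))
    return phi

def shap_interaction_values(f, x, background):
    """Brute-force SHAP interaction matrix for one row, plus its SHAP values."""
    m = len(x)
    inter = np.zeros((m, m))
    for i, j in itertools.combinations(range(m), 2):
        others = [k for k in range(m) if k not in (i, j)]
        total = 0.0
        for size in range(m - 1):
            for s in itertools.combinations(others, size):
                # Shapley interaction weight, halved so the effect splits over [i, j] and [j, i]
                weight = (math.factorial(size) * math.factorial(m - size - 2)
                          / (2 * math.factorial(m - 1)))
                delta = (value(f, x, background, s + (i, j)) - value(f, x, background, s + (i,))
                         - value(f, x, background, s + (j,)) + value(f, x, background, s))
                total += weight * delta
        inter[i, j] = inter[j, i] = total
    phi = shapley_values(f, x, background)
    # Diagonal (main effect) = SHAP value minus that feature's interaction terms
    for i in range(m):
        inter[i, i] = phi[i] - inter[i].sum()
    return inter, phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 2.0, -1.0])
inter, phi = shap_interaction_values(model_fn, x, background)
print(np.round(inter, 3))
print(phi)
```

Because feature 2 enters the model additively, its off-diagonal entries come out as zero, while the [0, 1] entry captures the product term.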

Bias Detection with SHAP

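SHAP can help flag features that may be acting as proxies for protected attributes. The function below assumes the protected attributes are among the model's input features and compares their SHAP values with those of every other feature: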

import numpy as np
from scipy.stats import spearmanr


def detect_proxy_discrimination(
    shap_values: np.ndarray,
    feature_names: list[str],
    protected_features: list[str],
    proxy_threshold: float = 0.1
) -> dict:
    """Detect features that may serve as proxies for protected attributes."""

    protected_indices = [feature_names.index(f) for f in protected_features]
    other_indices = [i for i in range(len(feature_names)) if i not in protected_indices]

    proxies = {}
    for p_idx in protected_indices:
        p_name = feature_names[p_idx]
        for o_idx in other_indices:
            o_name = feature_names[o_idx]
            corr, p_value = spearmanr(shap_values[:, p_idx], shap_values[:, o_idx])
            if abs(corr) > proxy_threshold and p_value < 0.05:
                proxies[f"{o_name} → {p_name}"] = {
                    "correlation": corr,
                    "p_value": p_value
                }

    return proxies

If a non-protected feature (like zip code) has SHAP values highly correlated with a protected feature (like race), the model may be discriminating through proxies.

Production Integration

Serving SHAP Explanations via API

from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
import pandas as pd
import shap

app = FastAPI()

# Pre-load model and explainer at startup
# (load_model is a placeholder for your own loading code; an XGBoost classifier is assumed)
model = load_model()
explainer = shap.TreeExplainer(model)
# Column order used in training (requires the model to have been fit on a DataFrame)
FEATURE_NAMES = model.get_booster().feature_names

class PredictionRequest(BaseModel):
    features: dict[str, float]

class ExplanationResponse(BaseModel):
    prediction: float          # probability of the positive class
    base_value: float          # log-odds, the same space as the contributions
    feature_contributions: dict[str, float]
    top_positive: list[dict]
    top_negative: list[dict]

@app.post("/predict-explain", response_model=ExplanationResponse)
def predict_with_explanation(request: PredictionRequest):
    # Requests may list features in any order; reorder to match training
    features_df = pd.DataFrame([request.features])[FEATURE_NAMES]
    prediction = float(model.predict_proba(features_df)[0][1])
    shap_vals = explainer.shap_values(features_df)[0]

    contributions = dict(zip(FEATURE_NAMES, shap_vals.tolist()))

    sorted_contribs = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True)
    top_pos = [{"feature": k, "impact": v} for k, v in sorted_contribs if v > 0][:5]
    top_neg = [{"feature": k, "impact": v} for k, v in sorted_contribs if v < 0][:5]

    return ExplanationResponse(
        prediction=prediction,
        base_value=float(explainer.expected_value),
        feature_contributions=contributions,
        top_positive=top_pos,
        top_negative=top_neg
    )

Performance Considerations

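SHAP computation adds latency to predictions. Rough per-sample figures (actual latency depends on model size, number of features, background size, and hardware):

| Explainer | Typical latency per sample |
|---|---|
| TreeExplainer | 0.1-5 ms |
| LinearExplainer | <0.1 ms |
| DeepExplainer | 10-100 ms |
| KernelExplainer | 1-60 s |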

For real-time serving, TreeExplainer and LinearExplainer are fast enough to include inline. For slower explainers, compute explanations asynchronously and cache them.

Caching Explanations

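from functools import lru_cache
import numpy as np
import pandas as pd

FEATURE_NAMES = list(X_train.columns)  # column order the model was trained with

@lru_cache(maxsize=10_000)
def _cached_explanation(row_bytes: bytes) -> tuple[float, ...]:
    """Compute SHAP values once per distinct feature vector.

    The row's raw float64 bytes serve as the (hashable) cache key and are
    decoded back into a one-row DataFrame on a cache miss.
    """
    row = np.frombuffer(row_bytes, dtype=np.float64).reshape(1, -1)
    features = pd.DataFrame(row, columns=FEATURE_NAMES)
    # Return an immutable tuple so callers cannot modify the cached result
    return tuple(explainer.shap_values(features)[0].tolist())

def get_explanation(features_df: pd.DataFrame) -> np.ndarray:
    """Return SHAP values for a one-row DataFrame, reusing cached results."""
    row = features_df[FEATURE_NAMES].to_numpy(dtype=np.float64)
    return np.array(_cached_explanation(row.tobytes()))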

SHAP vs Other Explanation Methods

| Method | Theoretical foundation | Consistency | Speed | Model-agnostic |
|---|---|---|---|---|
| SHAP | Shapley values (game theory) | Guaranteed | Varies | Yes |
| LIME | Local linear approximation | Not guaranteed | Fast | Yes |
| Feature importance | Impurity-based | Biased toward high-cardinality features | Fast | No (trees only) |
| Permutation importance | Prediction change on shuffle | Affected by correlation | Medium | Yes |
| Integrated Gradients | Path integral of gradients | Guaranteed | Medium | No (differentiable models only) |

SHAP’s mathematical guarantees (local accuracy, consistency, and missingness) make it the strongest choice for regulated environments where explanations must be defensible.

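One thing to remember: SHAP turns an opaque model into an auditable one. Every prediction can come with a breakdown, grounded in Shapley values, of how much each feature pushed the outcome up or down from the base value; TreeExplainer and LinearExplainer compute that breakdown exactly, while KernelExplainer and the deep-learning explainers approximate it.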

