Model Monitoring and Drift Detection in Python — Deep Dive
Drift Detection with Evidently
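Evidently is a widely used open-source Python library for ML monitoring. It generates drift reports, data quality checks, and model performance dashboards. The examples below use the Report and TestSuite classes from evidently.report and evidently.test_suite; newer Evidently releases have reorganized these modules, so check the import paths against your installed version.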
Basic Drift Report
import pandas as pd
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset
# Reference: training data distribution
reference_data = pd.read_parquet("data/training_sample.parquet")
# Current: recent production data
current_data = pd.read_parquet("data/production_last_week.parquet")
report = Report(metrics=[DataDriftPreset()])
report.run(reference_data=reference_data, current_data=current_data)
report.save_html("drift_report.html")
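# Programmatic access to results.
# DataDriftPreset emits a dataset-level summary and a per-column table,
# so look each result up by the key it carries rather than by position.
result = report.as_dict()
metric_results = [m["result"] for m in result["metrics"]]
drift_detected = next(r["dataset_drift"] for r in metric_results if "dataset_drift" in r)
drift_table = next(r["drift_by_columns"] for r in metric_results if "drift_by_columns" in r)
drifted_features = [
    col["column_name"]
    for col in drift_table.values()
    if col["drift_detected"]
]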
print(f"Drift detected: {drift_detected}")
print(f"Drifted features: {drifted_features}")
Evidently Test Suites for CI/CD
from evidently.test_suite import TestSuite
from evidently.tests import (
    TestShareOfDriftedColumns,
    TestColumnDrift,
    TestShareOfMissingValues,
)
suite = TestSuite(tests=[
    TestShareOfDriftedColumns(lt=0.3),  # Less than 30% of columns drifted
    TestColumnDrift(column_name="income"),
    TestShareOfMissingValues(lt=0.05),
])
suite.run(reference_data=reference_data, current_data=current_data)
# Use in CI: fail if tests don't pass
if not suite.as_dict()["summary"]["all_passed"]:
    raise RuntimeError("Data drift tests failed; investigate before retraining")
Statistical Tests for Drift
Population Stability Index (PSI)
PSI is the industry standard for monitoring score distributions in finance:
import numpy as np
def calculate_psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Calculate Population Stability Index between two distributions."""
    # Create bins from reference distribution (quantile edges)
    breakpoints = np.percentile(reference, np.linspace(0, 100, bins + 1))
    # Open the outer bins so out-of-range production values are still counted
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf
    ref_counts = np.histogram(reference, bins=breakpoints)[0]
    cur_counts = np.histogram(current, bins=breakpoints)[0]
    # Convert to proportions with add-one smoothing to avoid log(0)
    ref_pct = (ref_counts + 1) / (len(reference) + bins)
    cur_pct = (cur_counts + 1) / (len(current) + bins)
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)
# Interpretation:
# PSI < 0.1: No significant drift
# 0.1 <= PSI < 0.25: Moderate drift — investigate
# PSI >= 0.25: Significant drift — action required
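To make the thresholds concrete, here is a small worked example on hand-picked bin proportions. The helper psi_from_proportions and the numbers are illustrative assumptions rather than part of the function above; the helper applies the same per-bin formula but skips binning and smoothing.

```python
import numpy as np

def psi_from_proportions(ref_pct, cur_pct) -> float:
    """PSI for two already-binned distributions (hypothetical helper)."""
    ref = np.asarray(ref_pct, dtype=float)
    cur = np.asarray(cur_pct, dtype=float)
    # Each bin contributes (current - reference) * ln(current / reference)
    return float(np.sum((cur - ref) * np.log(cur / ref)))

# Four equally populated reference bins vs. a production sample
# that has moved toward the first bins
reference_bins = [0.25, 0.25, 0.25, 0.25]
current_bins = [0.40, 0.30, 0.20, 0.10]

psi = psi_from_proportions(reference_bins, current_bins)
print(f"PSI = {psi:.3f}")  # about 0.228: moderate drift, just under the 0.25 line
```

Most of the total comes from the first and last bins, where the proportions moved furthest. That is typical of PSI: a few strongly shifted bins dominate the sum.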
Kolmogorov-Smirnov Test per Feature
import pandas as pd
from scipy import stats
def detect_drift_ks(
    reference_df: pd.DataFrame,
    current_df: pd.DataFrame,
    numerical_columns: list[str],
    alpha: float = 0.05
) -> dict:
    """Run a two-sample KS test on each numerical feature."""
    results = {}
    for col in numerical_columns:
        # Drop nulls first; missing values are better tracked as a separate metric
        stat, p_value = stats.ks_2samp(
            reference_df[col].dropna(),
            current_df[col].dropna()
        )
        results[col] = {
            "statistic": stat,
            "p_value": p_value,
            "drift_detected": p_value < alpha
        }
    return results
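As a quick sanity check, the same per-column loop can be run on synthetic data where the answer is known. This is a sketch under stated assumptions: the column names, sample size, and size of the shift are made up for illustration.

```python
import numpy as np
import pandas as pd
from scipy import stats

rng = np.random.default_rng(seed=0)
n = 2000

reference_df = pd.DataFrame({
    "age": rng.normal(40, 10, n),
    "income": rng.normal(50_000, 8_000, n),
})
# Production sample: "age" keeps its distribution, "income" shifts upward
current_df = pd.DataFrame({
    "age": rng.normal(40, 10, n),
    "income": rng.normal(56_000, 8_000, n),
})

alpha = 0.05
report = {}
for col in ["age", "income"]:
    stat, p_value = stats.ks_2samp(reference_df[col], current_df[col])
    report[col] = {
        "statistic": float(stat),
        "p_value": float(p_value),
        "drift_detected": bool(p_value < alpha),
    }

for col, r in report.items():
    print(f"{col}: D={r['statistic']:.3f}, p={r['p_value']:.3g}, drift={r['drift_detected']}")
```

With large samples the KS test flags even tiny shifts as significant, so it can help to look at the statistic itself (the largest gap between the two empirical CDFs) as an effect size, not only at the p-value.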
Jensen-Shannon Divergence for Categorical Features
import pandas as pd
from scipy.spatial.distance import jensenshannon
from collections import Counter
def js_divergence_categorical(
    reference: pd.Series,
    current: pd.Series,
    threshold: float = 0.1
) -> dict:
    """Jensen-Shannon comparison of two categorical distributions."""
    # Fix one category order so both probability vectors line up
    all_categories = list(set(reference.unique()) | set(current.unique()))
    ref_counts = Counter(reference)
    cur_counts = Counter(current)
    ref_total = sum(ref_counts.values())
    cur_total = sum(cur_counts.values())
    ref_dist = [ref_counts.get(cat, 0) / ref_total for cat in all_categories]
    cur_dist = [cur_counts.get(cat, 0) / cur_total for cat in all_categories]
    # SciPy's jensenshannon returns the JS distance, which is the square
    # root of the divergence. The threshold is applied to the distance.
    js_dist = jensenshannon(ref_dist, cur_dist)
    return {
        "js_distance": js_dist,
        "js_divergence": js_dist ** 2,
        "drift_detected": js_dist > threshold
    }
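One detail that is easy to miss is that scipy.spatial.distance.jensenshannon returns the Jensen-Shannon distance, which is the square root of the divergence. The sketch below checks this on a hand-computable case; the category shares are invented for illustration, and base=2 is passed so that both quantities fall between 0 and 1.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# Category shares over the same ordered categories, e.g. ["web", "app", "store"]
reference_share = np.array([0.5, 0.5, 0.0])
current_share = np.array([0.0, 0.5, 0.5])

distance = jensenshannon(reference_share, current_share, base=2)
divergence = distance ** 2

print(f"JS distance   = {distance:.4f}")    # 0.7071
print(f"JS divergence = {divergence:.4f}")  # 0.5000

# Completely disjoint distributions hit the upper bound of 1 in base 2
disjoint = jensenshannon([1.0, 0.0], [0.0, 1.0], base=2)
print(f"Disjoint distance = {disjoint:.4f}")  # 1.0000
```

Whichever quantity you alert on, keep the threshold consistent with it: a distance of 0.1 corresponds to a divergence of only 0.01.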
NannyML: Monitoring Without Ground Truth
NannyML estimates model performance without requiring labels using Confidence-Based Performance Estimation (CBPE):
import nannyml as nml
import pandas as pd
# Reference data: period where you know the ground truth
reference = pd.read_parquet("data/reference_with_labels.parquet")
# Analysis data: production data without labels
analysis = pd.read_parquet("data/production_no_labels.parquet")
estimator = nml.CBPE(
    y_pred_proba="predicted_probability",
    y_pred="predicted_class",
    y_true="actual_class",  # Only needed for reference
    problem_type="classification_binary",
    metrics=["roc_auc", "f1"],
    chunk_size=5000
)
estimator.fit(reference)
results = estimator.estimate(analysis)
# Check for performance drops.
# to_df() is assumed to return two-level columns such as ("roc_auc", "alert"),
# as in recent NannyML releases; verify the layout for your installed version.
# The chunk start_date is only populated when timestamp_column_name is
# passed to CBPE.
analysis_df = results.filter(period="analysis").to_df()
for _, chunk in analysis_df.iterrows():
    if chunk[("roc_auc", "alert")]:
        print(
            f"Chunk {chunk[('chunk', 'start_date')]}: "
            f"estimated ROC AUC dropped to {chunk[('roc_auc', 'value')]:.3f}"
        )
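The intuition behind CBPE can be shown in a few lines. This is a simplified illustration of the idea for accuracy only, not NannyML's implementation: it assumes the predicted probabilities are already well calibrated, and the function name estimate_accuracy is hypothetical.

```python
import numpy as np

def estimate_accuracy(y_pred_proba, threshold: float = 0.5) -> float:
    """Expected accuracy of thresholded predictions, assuming calibrated scores."""
    proba = np.asarray(y_pred_proba, dtype=float)
    y_pred = proba >= threshold
    # With calibrated scores, a positive prediction is correct with
    # probability p and a negative prediction with probability 1 - p.
    prob_correct = np.where(y_pred, proba, 1.0 - proba)
    return float(prob_correct.mean())

scores = np.array([0.9, 0.8, 0.3, 0.6])
print(f"Estimated accuracy: {estimate_accuracy(scores):.2f}")  # 0.75
```

No labels are involved: the estimate depends only on how confident the model is. That is also the method's main limitation, because if the relationship between features and target changes (concept drift), the scores stop being calibrated and the estimate can be misleading.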
Building a Monitoring Pipeline
Architecture
Production API → Prediction Logger → Message Queue (Kafka)
                                               ↓
                                  Drift Analyzer (scheduled)
                                               ↓
                                  Metrics Store (Prometheus)
                                               ↓
                                 Dashboard (Grafana) + Alerts
Prediction Logger
import json
import time
from kafka import KafkaProducer
producer = KafkaProducer(
    bootstrap_servers="kafka:9092",
    value_serializer=lambda v: json.dumps(v).encode("utf-8")
)
def log_prediction(features: dict, prediction: float, model_version: str):
    """Log every prediction for downstream monitoring."""
    producer.send("ml-predictions", {
        "timestamp": time.time(),
        "model_version": model_version,
        "features": features,
        "prediction": prediction
    })
Scheduled Drift Analyzer
from datetime import datetime, timedelta, timezone
import pandas as pd
# Assumed to exist elsewhere in the project: load_production_data,
# send_alert, and push_metric. calculate_psi is the function defined earlier.
def run_drift_check(
    reference_path: str,
    lookback_hours: int = 24,
    alert_threshold: float = 0.25
):
    """Periodic drift check comparing recent production data to reference."""
    reference = pd.read_parquet(reference_path)
    # Load recent production data from warehouse (timezone-aware UTC cutoff)
    cutoff = datetime.now(timezone.utc) - timedelta(hours=lookback_hours)
    current = load_production_data(since=cutoff)
    if len(current) < 100:
        return {"status": "insufficient_data", "sample_size": len(current)}
    # Only compare numeric columns that are present in both frames
    numerical_cols = [
        col for col in reference.select_dtypes(include="number").columns
        if col in current.columns
    ]
    drift_results = {}
    for col in numerical_cols:
        # Drop nulls so NaN does not corrupt the percentile-based bin edges
        psi = calculate_psi(
            reference[col].dropna().values,
            current[col].dropna().values
        )
        drift_results[col] = {
            "psi": psi,
            "alert": psi >= alert_threshold
        }
    alerted_features = [c for c, r in drift_results.items() if r["alert"]]
    if alerted_features:
        send_alert(
            severity="warning",
            message=f"Drift detected in {len(alerted_features)} features: "
                    f"{', '.join(alerted_features[:5])}"
        )
    # Push metrics to Prometheus
    for col, result in drift_results.items():
        push_metric(f"ml_drift_psi_{col}", result["psi"])
    return drift_results
Automated Retraining Triggers
When drift is confirmed, the system can trigger retraining automatically:
def evaluate_retraining_need(drift_results: dict, config: dict) -> str:
    """Decide retraining action based on drift severity."""
    alerted = [c for c, r in drift_results.items() if r["alert"]]
    total_features = len(drift_results)
    drift_ratio = len(alerted) / total_features if total_features > 0 else 0
    if drift_ratio >= config["critical_threshold"]:  # e.g., 0.5
        return "retrain_immediately"
    elif drift_ratio >= config["warning_threshold"]:  # e.g., 0.2
        return "schedule_retraining"
    else:
        return "no_action"
The key tradeoff: retraining too aggressively wastes compute and risks training on noisy, unrepresentative data. Retraining too rarely lets degradation accumulate. A common compromise is to retrain only when drift crosses a threshold and estimated performance also drops below an acceptable level.
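The rule of requiring both signals before retraining can be sketched as a small policy check. The class name, field names, and default thresholds here are assumptions chosen for illustration; the estimated AUC would come from a label-free estimator such as CBPE.

```python
from dataclasses import dataclass

@dataclass
class RetrainPolicy:
    """Hypothetical thresholds; tune these per model."""
    drift_ratio_threshold: float = 0.2  # share of features in drift alert
    min_estimated_auc: float = 0.80     # lowest acceptable estimated ROC AUC

def should_retrain(drift_ratio: float, estimated_auc: float, policy: RetrainPolicy) -> bool:
    """Retrain only when drift is high AND estimated performance is low."""
    drift_exceeded = drift_ratio >= policy.drift_ratio_threshold
    performance_degraded = estimated_auc < policy.min_estimated_auc
    return drift_exceeded and performance_degraded

policy = RetrainPolicy()
print(should_retrain(0.35, 0.84, policy))  # False: drift, but performance holds
print(should_retrain(0.35, 0.76, policy))  # True: drift and degraded performance
print(should_retrain(0.05, 0.76, policy))  # False: no drift, so investigate other causes first
```

The third case is worth an alert even though it does not trigger retraining: a performance drop without input drift often points to a pipeline or labeling problem that new training data would not fix.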
Monitoring Checklist for Production Models
- Log everything — inputs, outputs, latency, model version on every prediction
- Set reference baselines — save training data statistics as the comparison point
- Run drift checks daily — hourly for high-stakes models (fraud, safety)
- Track business metrics — technical metrics alone miss problems that affect users
- Alert on missing data spikes — upstream pipeline failures show up as nulls before they show up as drift
- Version your monitors — drift thresholds and alert rules should be in source control
- Simulate drift — periodically inject synthetic drift to verify your monitoring catches it
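The last item, simulating drift, can be sketched as a self-test: shift a healthy production sample by a known amount and confirm that the drift score crosses the alert threshold. The helper names and the one-standard-deviation shift are assumptions for illustration; psi is a compact version of the PSI calculation using reference-quantile bins and add-one smoothing.

```python
import numpy as np

def psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Compact PSI with reference-quantile bins and add-one smoothing."""
    edges = np.percentile(reference, np.linspace(0, 100, bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf
    ref_pct = (np.histogram(reference, bins=edges)[0] + 1) / (len(reference) + bins)
    cur_pct = (np.histogram(current, bins=edges)[0] + 1) / (len(current) + bins)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

def inject_mean_shift(values: np.ndarray, n_std: float) -> np.ndarray:
    """Return a copy of values shifted by n_std of its own standard deviations."""
    return values + n_std * values.std()

rng = np.random.default_rng(seed=42)
reference = rng.normal(0.0, 1.0, 5000)
production = rng.normal(0.0, 1.0, 5000)  # healthy traffic, same distribution

clean_psi = psi(reference, production)
drifted_psi = psi(reference, inject_mean_shift(production, n_std=1.0))

print(f"PSI without injected drift: {clean_psi:.4f}")
print(f"PSI with 1-std mean shift:  {drifted_psi:.4f}")

# The self-test fails loudly if the monitor would miss the injected drift
assert drifted_psi >= 0.25, "monitor failed to catch injected drift"
```

Running a check like this on a schedule, with smaller shifts as well, gives a rough idea of the smallest drift the current thresholds can detect.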
One thing to remember: The best monitoring systems detect problems from the input side (data drift) because waiting for output degradation means the model has already been making bad predictions for hours or days.