Model Monitoring and Drift Detection in Python — Deep Dive

Drift Detection with Evidently

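Evidently is a widely used open-source library for ML monitoring in Python. It generates drift reports, data quality checks, and model performance dashboards. The examples below use the Report and TestSuite API from Evidently releases before 0.7; newer releases reorganize these imports, so check the version you have installed.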

Basic Drift Report

import pandas as pd
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset

# Reference: training data distribution
reference_data = pd.read_parquet("data/training_sample.parquet")
# Current: recent production data
current_data = pd.read_parquet("data/production_last_week.parquet")

report = Report(metrics=[DataDriftPreset()])
report.run(reference_data=reference_data, current_data=current_data)
report.save_html("drift_report.html")

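# Programmatic access to results. DataDriftPreset expands into two metrics:
# index 0 holds the dataset-level summary, index 1 the per-column drift table
# (layout of the pre-0.7 API; verify against your installed version).
result = report.as_dict()
drift_detected = result["metrics"][0]["result"]["dataset_drift"]
drifted_features = [
    col["column_name"]
    for col in result["metrics"][1]["result"]["drift_by_columns"].values()
    if col["drift_detected"]
]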
print(f"Drift detected: {drift_detected}")
print(f"Drifted features: {drifted_features}")

Evidently Test Suites for CI/CD

from evidently.test_suite import TestSuite
from evidently.tests import (
    TestShareOfDriftedColumns,
    TestColumnDrift,
    TestShareOfMissingValues,
)

suite = TestSuite(tests=[
    TestShareOfDriftedColumns(lt=0.3),  # Less than 30% of columns drifted
    TestColumnDrift(column_name="income"),
    TestShareOfMissingValues(lt=0.05),
])

suite.run(reference_data=reference_data, current_data=current_data)

# Use in CI: fail if tests don't pass
if not suite.as_dict()["summary"]["all_passed"]:
    raise RuntimeError("Data drift tests failed — investigate before retraining")

Statistical Tests for Drift

Population Stability Index (PSI)

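PSI is a long-established metric for monitoring score distributions, particularly in credit-risk modeling. It bins both samples using the reference distribution and sums (c_i - r_i) * ln(c_i / r_i) over the bins, where r_i and c_i are the reference and current share of observations in bin i: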

import numpy as np

def calculate_psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Calculate Population Stability Index between two distributions."""
    # NaNs would corrupt the percentile-based bin edges, so drop them first
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]

    # Create quantile bins from the reference distribution; open-ended outer
    # edges catch production values outside the reference range
    breakpoints = np.percentile(reference, np.linspace(0, 100, bins + 1))
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf

    ref_counts = np.histogram(reference, bins=breakpoints)[0]
    cur_counts = np.histogram(current, bins=breakpoints)[0]

    # Convert to proportions with add-one smoothing so empty bins never yield log(0)
    ref_pct = (ref_counts + 1) / (len(reference) + bins)
    cur_pct = (cur_counts + 1) / (len(current) + bins)

    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)

# Interpretation:
# PSI < 0.1: No significant drift
# 0.1 <= PSI < 0.25: Moderate drift — investigate
# PSI >= 0.25: Significant drift — action required
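To make the interpretation bands concrete, here is a small worked example using only the standard library. It assumes the bin proportions have already been computed; the proportions themselves and the helper names `psi_from_proportions` and `interpret_psi` are illustrative and not part of any library.

```python
import math

def psi_from_proportions(ref_pct, cur_pct):
    """PSI from two lists of bin proportions (each list sums to 1)."""
    return sum((c - r) * math.log(c / r) for r, c in zip(ref_pct, cur_pct))

def interpret_psi(psi):
    """Map a PSI value onto the rule-of-thumb bands listed above."""
    if psi < 0.1:
        return "no significant drift"
    if psi < 0.25:
        return "moderate drift"
    return "significant drift"

# Hypothetical proportions: four equal-mass reference bins, with production
# traffic that has shifted toward the lowest bin
ref_pct = [0.25, 0.25, 0.25, 0.25]
cur_pct = [0.40, 0.30, 0.20, 0.10]

psi = psi_from_proportions(ref_pct, cur_pct)
band = interpret_psi(psi)
print(f"PSI = {psi:.3f} ({band})")
```

With these made-up proportions the PSI comes out at roughly 0.23, which falls in the "moderate drift" band just below the 0.25 action threshold.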

Kolmogorov-Smirnov Test per Feature

from scipy import stats

def detect_drift_ks(
    reference_df: pd.DataFrame,
    current_df: pd.DataFrame,
    numerical_columns: list[str],
    alpha: float = 0.05
) -> dict:
    """Run KS test on each numerical feature."""
    results = {}
    for col in numerical_columns:
        # Drop NaNs first; they would otherwise distort or invalidate the test
        stat, p_value = stats.ks_2samp(reference_df[col].dropna(), current_df[col].dropna())
        results[col] = {
            "statistic": stat,
            "p_value": p_value,
            "drift_detected": p_value < alpha
        }
    return results

Jensen-Shannon Divergence for Categorical Features

from scipy.spatial.distance import jensenshannon
from collections import Counter

def js_divergence_categorical(
    reference: pd.Series,
    current: pd.Series,
    threshold: float = 0.1
) -> dict:
    """Jensen-Shannon divergence for categorical distributions."""
    all_categories = set(reference.unique()) | set(current.unique())

    ref_counts = Counter(reference)
    cur_counts = Counter(current)

    ref_total = sum(ref_counts.values())
    cur_total = sum(cur_counts.values())

    ref_dist = [ref_counts.get(cat, 0) / ref_total for cat in all_categories]
    cur_dist = [cur_counts.get(cat, 0) / cur_total for cat in all_categories]

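    # scipy's jensenshannon returns the JS *distance*, i.e. the square root
    # of the divergence (natural log by default)
    js_dist = jensenshannon(ref_dist, cur_dist)

    return {
        "js_distance": js_dist,
        "js_divergence": js_dist ** 2,
        # The threshold is compared against the distance, not the divergence
        "drift_detected": js_dist > threshold
    }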
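One detail worth spelling out: `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon distance, which is the square root of the divergence, so a threshold tuned for one does not transfer to the other. The dependency-free sketch below computes both for a pair of made-up category shares; the function name `js_divergence` and the example distributions are illustrative assumptions.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    # Mixture distribution halfway between p and q
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        # Terms with a_i == 0 contribute nothing to the KL sum
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Hypothetical category shares for a three-level categorical feature
reference = [0.5, 0.3, 0.2]
current = [0.2, 0.3, 0.5]

divergence = js_divergence(reference, current)
distance = math.sqrt(divergence)
print(f"divergence = {divergence:.3f}, distance = {distance:.3f}")
```

For these shares the divergence is about 0.066 while the distance is about 0.258, so a threshold of 0.1 would flag drift on the distance but not on the divergence.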

NannyML: Monitoring Without Ground Truth

NannyML estimates model performance without requiring labels using Confidence-Based Performance Estimation (CBPE):

import nannyml as nml

# Reference data: period where you know the ground truth
reference = pd.read_parquet("data/reference_with_labels.parquet")
# Analysis data: production data without labels
analysis = pd.read_parquet("data/production_no_labels.parquet")

estimator = nml.CBPE(
    y_pred_proba="predicted_probability",
    y_pred="predicted_class",
    y_true="actual_class",  # Only needed for reference
    problem_type="classification_binary",
    metrics=["roc_auc", "f1"],
    chunk_size=5000
)

estimator.fit(reference)
results = estimator.estimate(analysis)

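# Check for performance drops. to_df() returns MultiIndex columns by default;
# multilevel=False flattens them to names such as "roc_auc_value" and
# "roc_auc_alert" (verify the names against your NannyML version).
for chunk in results.filter(period="analysis").to_df(multilevel=False).itertuples():
    if chunk.roc_auc_alert:
        print(f"Chunk {chunk.chunk_key}: estimated ROC AUC dropped to {chunk.roc_auc_value:.3f}")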

Building a Monitoring Pipeline

Architecture

Production API → Prediction Logger → Message Queue (Kafka)
                                            ↓
                              Drift Analyzer (scheduled)
                                            ↓
                              Metrics Store (Prometheus)
                                            ↓
                              Dashboard (Grafana) + Alerts

Prediction Logger

import json
import time
from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers="kafka:9092",
    value_serializer=lambda v: json.dumps(v).encode("utf-8")
)

def log_prediction(features: dict, prediction: float, model_version: str):
    """Log every prediction for downstream monitoring."""
    producer.send("ml-predictions", {
        "timestamp": time.time(),
        "model_version": model_version,
        "features": features,
        "prediction": prediction
    })

Scheduled Drift Analyzer

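from datetime import datetime, timedelta, timezone

# load_production_data, send_alert, and push_metric are placeholders for your
# own warehouse query, alerting, and metrics-export helpers.

def run_drift_check(
    reference_path: str,
    lookback_hours: int = 24,
    alert_threshold: float = 0.25
):
    """Periodic drift check comparing recent production data to reference."""
    reference = pd.read_parquet(reference_path)

    # Load recent production data from warehouse
    # (datetime.utcnow() is deprecated as of Python 3.12; use an aware UTC timestamp)
    cutoff = datetime.now(timezone.utc) - timedelta(hours=lookback_hours)
    current = load_production_data(since=cutoff)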

    if len(current) < 100:
        return {"status": "insufficient_data", "sample_size": len(current)}

    numerical_cols = reference.select_dtypes(include="number").columns.tolist()
    drift_results = {}

    for col in numerical_cols:
        psi = calculate_psi(reference[col].values, current[col].values)
        drift_results[col] = {
            "psi": psi,
            "alert": psi >= alert_threshold
        }

    alerted_features = [c for c, r in drift_results.items() if r["alert"]]

    if alerted_features:
        send_alert(
            severity="warning",
            message=f"Drift detected in {len(alerted_features)} features: "
                    f"{', '.join(alerted_features[:5])}"
        )

    # Push metrics to Prometheus
    for col, result in drift_results.items():
        push_metric(f"ml_drift_psi_{col}", result["psi"])

    return drift_results

Automated Retraining Triggers

When drift is confirmed, the system can trigger retraining automatically:

def evaluate_retraining_need(drift_results: dict, config: dict) -> str:
    """Decide retraining action based on drift severity."""
    alerted = [c for c, r in drift_results.items() if r["alert"]]
    total_features = len(drift_results)
    drift_ratio = len(alerted) / total_features if total_features > 0 else 0

    if drift_ratio >= config["critical_threshold"]:  # e.g., 0.5
        return "retrain_immediately"
    elif drift_ratio >= config["warning_threshold"]:  # e.g., 0.2
        return "schedule_retraining"
    else:
        return "no_action"

The key tradeoff: retraining too aggressively wastes compute and risks fitting to noisy or transient data, while retraining too rarely lets degradation accumulate. A common compromise is to retrain only when drift crosses a threshold and estimated performance also drops below an acceptable level.
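The policy described above, acting only when input drift and an estimated performance drop coincide, can be sketched as a small decision function. This is a minimal illustration rather than a recommended configuration: the function name `decide_retraining`, the intermediate "investigate" action, and both default thresholds are assumptions chosen for the example.

```python
def decide_retraining(
    drift_ratio: float,
    estimated_auc: float,
    drift_threshold: float = 0.2,  # hypothetical: share of features with drift alerts
    min_auc: float = 0.75,         # hypothetical: lowest acceptable estimated ROC AUC
) -> str:
    """Combine input drift and estimated performance into a single action."""
    drifted = drift_ratio >= drift_threshold
    degraded = estimated_auc < min_auc

    if drifted and degraded:
        # Both signals agree: the inputs moved and the model appears to be suffering
        return "retrain"
    if drifted or degraded:
        # Only one signal fired: worth a look, but not yet worth the compute
        return "investigate"
    return "no_action"

decisions = {
    "drift_only": decide_retraining(drift_ratio=0.4, estimated_auc=0.82),
    "drift_and_drop": decide_retraining(drift_ratio=0.4, estimated_auc=0.70),
    "stable": decide_retraining(drift_ratio=0.05, estimated_auc=0.82),
}
print(decisions)
```

Here `drift_ratio` plays the same role as in `evaluate_retraining_need`, and `estimated_auc` could come from a label-free estimator such as CBPE. Requiring both signals guards against retraining on drift that does not hurt the model.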

Monitoring Checklist for Production Models

  1. Log everything — inputs, outputs, latency, model version on every prediction
  2. Set reference baselines — save training data statistics as the comparison point
  3. Run drift checks daily — hourly for high-stakes models (fraud, safety)
  4. Track business metrics — technical metrics alone miss problems that affect users
  5. Alert on missing data spikes — upstream pipeline failures show up as nulls before they show up as drift
  6. Version your monitors — drift thresholds and alert rules should be in source control
  7. Simulate drift — periodically inject synthetic drift to verify your monitoring catches it
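As an illustration of item 7, the sketch below injects a synthetic mean shift into a sample and checks that a PSI-based monitor reacts to it while staying quiet on unshifted data. It uses only the standard library so it can run anywhere; the pure-Python `psi` helper mirrors the binned, add-one-smoothed approach from earlier in this article, and the sample sizes, seed, and shift magnitude are arbitrary choices for the example.

```python
import bisect
import math
import random

def psi(reference, current, bins=10):
    """Binned PSI with add-one smoothing, using reference quantiles as bin edges."""
    ref_sorted = sorted(reference)
    n = len(ref_sorted)
    # Interior cut points taken from the reference sample's quantiles
    edges = [ref_sorted[(n * i) // bins] for i in range(1, bins)]

    def proportions(values):
        counts = [0] * bins
        for v in values:
            counts[bisect.bisect_right(edges, v)] += 1
        return [(c + 1) / (len(values) + bins) for c in counts]

    ref_pct = proportions(reference)
    cur_pct = proportions(current)
    return sum((c - r) * math.log(c / r) for r, c in zip(ref_pct, cur_pct))

def inject_mean_shift(values, shift):
    """Simulate covariate drift by adding a constant offset to every value."""
    return [v + shift for v in values]

rng = random.Random(42)
reference = [rng.gauss(0.0, 1.0) for _ in range(5000)]
production = [rng.gauss(0.0, 1.0) for _ in range(5000)]

baseline_psi = psi(reference, production)
drifted_psi = psi(reference, inject_mean_shift(production, 1.0))

# The monitor passes the drill if it stays quiet on clean data and fires on the shift
monitor_ok = baseline_psi < 0.1 and drifted_psi >= 0.25
print(f"baseline PSI = {baseline_psi:.4f}, drifted PSI = {drifted_psi:.4f}, ok = {monitor_ok}")
```

Running a drill like this on a schedule, against the same thresholds and alert rules used in production, helps confirm that a silent dashboard means "no drift" rather than "broken monitor".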

One thing to remember: The best monitoring systems detect problems from the input side (data drift) because waiting for output degradation means the model has already been making bad predictions for hours or days.
