Confusion Matrix in Python — Deep Dive

Building a Confusion Matrix in Scikit-Learn

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

cm = confusion_matrix(y_test, y_pred)
print(cm)

The output is a NumPy array where cm[i][j] is the count of samples with true label i predicted as j.
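To make the indexing convention concrete, here is a small sketch with hand-made labels. The toy arrays and the names `y_true_toy`, `y_pred_toy`, and `cm_toy` are invented for illustration and are not part of the iris run above.

```python
from sklearn.metrics import confusion_matrix

# Toy labels, made up for illustration only.
y_true_toy = [0, 0, 1, 1, 2, 2]
y_pred_toy = [0, 1, 1, 1, 0, 2]

cm_toy = confusion_matrix(y_true_toy, y_pred_toy)
print(cm_toy)

# Row i collects samples whose true label is i;
# column j collects samples that were predicted as j.
print("true 0 predicted as 1:", cm_toy[0, 1])
print("samples per true class:", cm_toy.sum(axis=1))
print("predictions per class:", cm_toy.sum(axis=0))
```

Row sums give the class support in the test set, while column sums show how often the model chose each class.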

Visualization with Matplotlib

Basic Heatmap

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=["Setosa", "Versicolor", "Virginica"],
)
disp.plot(cmap="Blues", ax=ax, values_format="d")
ax.set_title("Iris Classification — Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png", dpi=150)

Normalized Heatmap

Normalizing by rows shows recall per class; normalizing by columns shows precision per class:

cm_normalized = confusion_matrix(y_test, y_pred, normalize="true")

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ConfusionMatrixDisplay(cm, display_labels=["Setosa", "Versicolor", "Virginica"]).plot(
    ax=axes[0], cmap="Blues", values_format="d"
)
axes[0].set_title("Raw Counts")

ConfusionMatrixDisplay(cm_normalized, display_labels=["Setosa", "Versicolor", "Virginica"]).plot(
    ax=axes[1], cmap="Oranges", values_format=".2%"
)
axes[1].set_title("Normalized (Row = Recall)")

plt.tight_layout()
plt.savefig("confusion_matrix_comparison.png", dpi=150)
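The plot above only uses row normalization. As a rough check of the row and column claim, the sketch below compares both normalizations against scikit-learn's own per-class scores. The toy labels are invented for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Toy labels, made up for illustration only.
y_true_toy = [0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred_toy = [0, 0, 1, 1, 1, 2, 2, 1, 0]

# normalize="true" divides each row by its sum (diagonal = recall).
cm_by_row = confusion_matrix(y_true_toy, y_pred_toy, normalize="true")
# normalize="pred" divides each column by its sum (diagonal = precision).
cm_by_col = confusion_matrix(y_true_toy, y_pred_toy, normalize="pred")

recall_per_class = recall_score(y_true_toy, y_pred_toy, average=None)
precision_per_class = precision_score(y_true_toy, y_pred_toy, average=None)

print("row diagonal matches recall:",
      np.allclose(np.diag(cm_by_row), recall_per_class))
print("column diagonal matches precision:",
      np.allclose(np.diag(cm_by_col), precision_per_class))
```

A third option, `normalize="all"`, divides every cell by the total sample count, which is useful when you want each cell as a share of all predictions.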

Extracting Metrics from the Matrix

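Most label-based classification metrics (precision, recall, F1, accuracy) can be derived directly from the confusion matrix. Metrics that depend on predicted probabilities, such as ROC AUC or log loss, cannot: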

import numpy as np

def metrics_from_cm(cm):
    """Extract per-class and aggregate metrics from a confusion matrix."""
    eps = 1e-10                     # avoids division by zero for empty rows or columns
    tp = np.diag(cm)                # correct predictions per class
    fp = cm.sum(axis=0) - tp        # column total minus the diagonal
    fn = cm.sum(axis=1) - tp        # row total minus the diagonal
    tn = cm.sum() - (tp + fp + fn)  # not used below; needed for specificity, tn / (tn + fp)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    return {
        "per_class_precision": precision,
        "per_class_recall": recall,
        "per_class_f1": f1,
        "macro_f1": f1.mean(),
        "accuracy": tp.sum() / cm.sum(),
    }

metrics = metrics_from_cm(cm)
for key, val in metrics.items():
    print(f"{key}: {val}")
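One way to gain confidence in hand-derived metrics is to compare them with scikit-learn's reference implementations. The following is a minimal sketch on invented toy labels. The variable names are illustrative, and the formulas are written without an epsilon term because every class here has at least one true and one predicted sample.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
)

# Toy labels, made up for illustration only.
y_true_chk = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred_chk = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 0])

cm_chk = confusion_matrix(y_true_chk, y_pred_chk)
tp = np.diag(cm_chk)
precision = tp / cm_chk.sum(axis=0)   # diagonal over column totals
recall = tp / cm_chk.sum(axis=1)      # diagonal over row totals
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm_chk.sum()

# Reference values; average=None (the default) returns one value per class.
ref_p, ref_r, ref_f1, support = precision_recall_fscore_support(y_true_chk, y_pred_chk)

print("precision matches:", np.allclose(precision, ref_p))
print("recall matches:", np.allclose(recall, ref_r))
print("f1 matches:", np.allclose(f1, ref_f1))
print("accuracy matches:", np.isclose(accuracy, accuracy_score(y_true_chk, y_pred_chk)))
```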

Multi-Label Confusion Matrices

In multi-label classification, where each sample can carry several labels at once, scikit-learn produces one binary confusion matrix per label:

from sklearn.metrics import multilabel_confusion_matrix

y_true_ml = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
y_pred_ml = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]

mcm = multilabel_confusion_matrix(y_true_ml, y_pred_ml)
# Returns one 2x2 matrix per label
for i, label_cm in enumerate(mcm):
    print(f"Label {i}:\n{label_cm}\n")
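Each 2x2 matrix returned by `multilabel_confusion_matrix` is laid out as `[[TN, FP], [FN, TP]]` for its label. As a small sketch using the same toy arrays, the four counts can be sliced out for all labels at once and turned into per-label recall. The names `recall_per_label` and the unpacked count arrays are illustrative.

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true_ml = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0]])
y_pred_ml = np.array([[1, 0, 0], [0, 1, 1], [1, 0, 0]])

mcm = multilabel_confusion_matrix(y_true_ml, y_pred_ml)  # shape: (n_labels, 2, 2)

# Slice the four cells across all labels in one step.
tn, fp = mcm[:, 0, 0], mcm[:, 0, 1]
fn, tp = mcm[:, 1, 0], mcm[:, 1, 1]

# Safe here because every label has at least one positive sample.
recall_per_label = tp / (tp + fn)
print("recall per label:", recall_per_label)
```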

Cost-Sensitive Analysis

Not all errors are equal. In fraud detection, missing fraud (FN) costs far more than a false alarm (FP). The confusion matrix enables cost-weighted evaluation:

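# Cost matrix: cost[i][j] = cost of predicting j when true label is i
cost_matrix = np.array([
    [0, 10],     # True Negative costs 0, False Positive costs $10
    [500, 0],    # False Negative costs $500, True Positive costs 0
])

# The cost matrix is 2x2, so it needs the 2x2 matrix of a binary classifier
# (for example a fraud model), not the 3x3 iris matrix from earlier.
# y_true and y_pred_binary are assumed to hold that model's 0/1 labels.
cm_binary = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])

total_cost = np.sum(cm_binary * cost_matrix)
cost_per_sample = total_cost / cm_binary.sum()
print(f"Total cost: ${total_cost:,.0f}")
print(f"Cost per prediction: ${cost_per_sample:.2f}")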

This transforms the confusion matrix from a statistical tool into a business decision tool. You can now optimize the classification threshold to minimize total cost:

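import numpy as np
from sklearn.metrics import confusion_matrix

# Assumes y_true holds binary ground truth (0/1) and y_proba holds the predicted
# probability of the positive class, e.g. model.predict_proba(X_val)[:, 1].
best_cost = float("inf")
best_threshold = 0.5

for threshold in np.arange(0.1, 0.95, 0.01):
    y_pred_t = (y_proba >= threshold).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix that lines up with cost_matrix
    cm_t = confusion_matrix(y_true, y_pred_t, labels=[0, 1])
    cost_t = np.sum(cm_t * cost_matrix)
    if cost_t < best_cost:
        best_cost = cost_t
        best_threshold = threshold

print(f"Optimal threshold: {best_threshold:.2f}, Cost: ${best_cost:,.0f}")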
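The threshold search depends on a binary model and its predicted probabilities, which the snippets here do not define. The sketch below fills that gap end to end under stated assumptions: the data is synthetic (`make_classification` standing in for a fraud dataset), the model is a plain logistic regression, and the helper `total_cost` is a hypothetical name. The dollar costs are the same illustrative values as in the cost matrix.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced binary data (assumption: roughly 10% positives).
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)
X_tr, X_val, y_tr, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
y_proba = clf.predict_proba(X_val)[:, 1]  # probability of the positive class

# cost[i][j] = cost of predicting j when the true label is i
cost_matrix = np.array([[0, 10], [500, 0]])

def total_cost(threshold):
    """Total cost on the validation split at a given decision threshold."""
    y_hat = (y_proba >= threshold).astype(int)
    cm_t = confusion_matrix(y_val, y_hat, labels=[0, 1])
    return int(np.sum(cm_t * cost_matrix))

thresholds = np.arange(0.05, 0.95, 0.01)
costs = np.array([total_cost(t) for t in thresholds])
best = int(costs.argmin())

print(f"Cost at 0.50: ${total_cost(0.5):,}")
print(f"Best threshold: {thresholds[best]:.2f}, cost: ${costs[best]:,}")
```

Because a false negative is priced at 50 times a false positive, the cheapest threshold will usually sit below 0.5. A threshold tuned on one split tends to look better on that split than on new data, so it is safer to choose it on a validation set and report the cost on a separate test set.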

Error Analysis Workflow

The confusion matrix is the starting point for systematic error analysis:

Step 1: Identify Problem Pairs

# Find the largest off-diagonal values
cm_no_diag = cm.copy()
np.fill_diagonal(cm_no_diag, 0)
worst_pair = np.unravel_index(cm_no_diag.argmax(), cm_no_diag.shape)
print(f"Most confused: true={worst_pair[0]}, predicted={worst_pair[1]}")
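A single worst pair can hide other important confusions. As a possible extension, the sketch below ranks the top k off-diagonal cells. The helper name `top_confused_pairs` and the demo matrix `cm_demo` are hypothetical, with counts invented for illustration.

```python
import numpy as np

def top_confused_pairs(cm, k=3):
    """Return up to k (true, predicted, count) tuples, largest errors first."""
    off_diag = np.asarray(cm).copy()
    np.fill_diagonal(off_diag, 0)                       # ignore correct predictions
    order = np.argsort(off_diag, axis=None)[::-1][:k]   # flat indices, descending
    rows, cols = np.unravel_index(order, off_diag.shape)
    return [
        (int(r), int(c), int(off_diag[r, c]))
        for r, c in zip(rows, cols)
        if off_diag[r, c] > 0                           # skip empty cells
    ]

# Demo matrix with made-up counts.
cm_demo = np.array([
    [50, 2, 0],
    [3, 40, 7],
    [0, 5, 45],
])
print(top_confused_pairs(cm_demo))  # [(1, 2, 7), (2, 1, 5), (1, 0, 3)]
```

Note that (1, 2) and (2, 1) are reported separately: confusing class 1 for class 2 is a different error from the reverse, and the two often have different causes.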

Step 2: Inspect Misclassified Samples

# Get indices of test samples that fall in the most confused (true, predicted) pair
misclassified = np.where((y_test == worst_pair[0]) & (y_pred == worst_pair[1]))[0]
print(f"Number of errors: {len(misclassified)}")
print(f"Sample indices: {misclassified[:10]}")

# Inspect feature distributions of misclassified vs. correctly classified
import pandas as pd
from sklearn.datasets import load_iris

feature_names = load_iris().feature_names  # column names for the iris features
df_test = pd.DataFrame(X_test, columns=feature_names)
df_test["true"] = y_test
df_test["pred"] = y_pred
df_test["correct"] = y_test == y_pred

# Compare feature means within the true class of the worst pair
print(df_test[df_test["true"] == worst_pair[0]].groupby("correct").mean())

Step 3: Build Targeted Fixes

Once you know which features differ between correctly and incorrectly classified samples, you can:

  • Engineer features that distinguish the confused classes.
  • Collect more training data for the problem pair.
  • Use class-specific thresholds.
  • Train a specialized sub-model for the ambiguous pair.

Confusion Matrix Over Time

In production, track the confusion matrix on a rolling window to detect drift:

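from collections import deque

import numpy as np
from sklearn.metrics import confusion_matrix

class RollingConfusionMatrix:
    """Confusion matrix over the most recent window_size (true, predicted) pairs."""

    def __init__(self, window_size=1000, n_classes=2):
        self.buffer = deque(maxlen=window_size)  # oldest pairs drop out automatically
        self.n_classes = n_classes

    def update(self, y_true_batch, y_pred_batch):
        for yt, yp in zip(y_true_batch, y_pred_batch):
            self.buffer.append((yt, yp))

    def get_matrix(self):
        if not self.buffer:
            return np.zeros((self.n_classes, self.n_classes), dtype=int)
        yt, yp = zip(*self.buffer)
        # Fixed labels keep the shape stable even if a class is absent from the window
        return confusion_matrix(yt, yp, labels=list(range(self.n_classes)))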

Alert when off-diagonal proportions exceed historical baselines.
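One way such an alert could look is sketched below. The helper names `off_diagonal_rates` and `drift_alerts`, the `tolerance` value, and both matrices are hypothetical and invented for illustration. Rates are expressed as a share of all samples in the window, which is only one of several reasonable definitions.

```python
import numpy as np

def off_diagonal_rates(cm):
    """Share of all samples that falls in each off-diagonal cell."""
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    rates = cm / total if total > 0 else np.zeros_like(cm)
    np.fill_diagonal(rates, 0.0)  # correct predictions are not errors
    return rates

def drift_alerts(current_cm, baseline_rates, tolerance=0.02):
    """List (true, pred, current_rate, baseline_rate) for cells above baseline + tolerance."""
    current = off_diagonal_rates(current_cm)
    flagged = np.argwhere(current > baseline_rates + tolerance)
    return [
        (int(i), int(j), float(current[i, j]), float(baseline_rates[i, j]))
        for i, j in flagged
    ]

# Made-up matrices: a historical baseline and a recent window.
baseline_cm = np.array([[900, 20], [30, 50]])
current_cm = np.array([[890, 30], [60, 20]])

alerts = drift_alerts(current_cm, off_diagonal_rates(baseline_cm))
for true_cls, pred_cls, now, before in alerts:
    print(f"true={true_cls} pred={pred_cls}: {now:.1%} now vs {before:.1%} baseline")
```

In this example only the false negative cell (true=1, predicted=0) is flagged, since it doubles from 3% to 6% while the false positive rate stays within tolerance. If class proportions themselves shift in production, row-normalized rates may be a better basis for comparison.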

Advanced Visualization: Sankey Diagrams

For large multi-class problems (10+ classes), heatmaps get crowded. A Sankey diagram shows flows from true labels to predicted labels, making major confusion patterns obvious:

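import plotly.graph_objects as go

# class_names must be a list of label names in the same order as the rows of cm;
# the iris names are used here as an example.
class_names = ["Setosa", "Versicolor", "Virginica"]
n_classes = len(class_names)

labels = class_names + [f"Pred: {c}" for c in class_names]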
source, target, value = [], [], []

for i in range(n_classes):
    for j in range(n_classes):
        if cm[i][j] > 0:
            source.append(i)
            target.append(n_classes + j)
            value.append(int(cm[i][j]))

fig = go.Figure(go.Sankey(
    node=dict(label=labels),
    link=dict(source=source, target=target, value=value),
))
fig.update_layout(title="Classification Flow: True → Predicted")
fig.show()

Common Pitfalls

  1. Transposing rows and columns: Scikit-learn convention is rows = true, columns = predicted. Some libraries use the opposite. Always verify.
  2. Ignoring normalization: Raw counts mislead when class sizes differ. Normalize before drawing conclusions about per-class performance.
  3. Only looking at the diagonal: The diagonal tells you what went right; the off-diagonal tells you what needs fixing.
  4. Forgetting to include all labels: If a class has zero test samples, pass labels= explicitly to avoid a missing row/column.
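Pitfalls 1 and 4 are easy to reproduce. The sketch below uses invented toy labels in which class 2 never appears, then shows the effect of passing `labels=` and of transposing the matrix.

```python
from sklearn.metrics import confusion_matrix

# Toy labels, made up for illustration; class 2 is absent from both lists.
y_true_small = [0, 0, 1, 1]
y_pred_small = [0, 1, 1, 1]

# Without labels=, the matrix only covers the classes that actually occur.
cm_inferred = confusion_matrix(y_true_small, y_pred_small)
print(cm_inferred.shape)

# With labels=, the missing class gets an all-zero row and column.
cm_fixed = confusion_matrix(y_true_small, y_pred_small, labels=[0, 1, 2])
print(cm_fixed.shape)
print(cm_fixed)

# Swapping the arguments transposes the matrix (rows become predictions).
cm_swapped = confusion_matrix(y_pred_small, y_true_small, labels=[0, 1, 2])
print((cm_swapped == cm_fixed.T).all())
```

Keeping a fixed `labels=` list across runs also makes matrices from different test sets directly comparable, cell by cell.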

One thing to remember: The confusion matrix is not just a reporting tool — it is a diagnostic instrument. The errors it reveals are the roadmap for your next round of improvements.
