Confusion Matrix in Python — Deep Dive
Building a Confusion Matrix in Scikit-Learn
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
print(cm)
The output is a NumPy array where cm[i][j] is the count of samples with true label i predicted as j.
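A quick way to confirm the row/column convention is to build a matrix from a handful of hand-written labels. The toy label lists below are made up for illustration only:

```python
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels for three classes (0, 1, 2)
y_true_toy = [0, 1, 1, 2]
y_pred_toy = [0, 2, 1, 2]

cm_toy = confusion_matrix(y_true_toy, y_pred_toy)
print(cm_toy)
# [[1 0 0]
#  [0 1 1]
#  [0 0 1]]
# Row 1, column 2 holds the single sample whose true label is 1 but was predicted as 2.
```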
Visualization with Matplotlib
Basic Heatmap
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=["Setosa", "Versicolor", "Virginica"],
)
disp.plot(cmap="Blues", ax=ax, values_format="d")
ax.set_title("Iris Classification — Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png", dpi=150)
Normalized Heatmap
Normalizing by rows shows recall per class; normalizing by columns shows precision per class:
cm_normalized = confusion_matrix(y_test, y_pred, normalize="true")
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ConfusionMatrixDisplay(cm, display_labels=["Setosa", "Versicolor", "Virginica"]).plot(
    ax=axes[0], cmap="Blues", values_format="d"
)
axes[0].set_title("Raw Counts")
ConfusionMatrixDisplay(cm_normalized, display_labels=["Setosa", "Versicolor", "Virginica"]).plot(
    ax=axes[1], cmap="Oranges", values_format=".2%"
)
axes[1].set_title("Normalized (Row = Recall)")
plt.tight_layout()
plt.savefig("confusion_matrix_comparison.png", dpi=150)
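The comparison above only shows row normalization. A minimal sketch of the column-normalized (precision) view uses normalize="pred" and checks both diagonals against scikit-learn's own per-class scores. The labels are hypothetical and chosen so that every class receives at least one prediction:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical labels; every class is predicted at least once
y_true_toy = [0, 0, 1, 1, 1, 2, 2, 2]
y_pred_toy = [0, 1, 1, 1, 2, 2, 2, 0]

# normalize="pred": each column sums to 1, so the diagonal is per-class precision
cm_by_pred = confusion_matrix(y_true_toy, y_pred_toy, normalize="pred")
# normalize="true": each row sums to 1, so the diagonal is per-class recall
cm_by_true = confusion_matrix(y_true_toy, y_pred_toy, normalize="true")

precision = precision_score(y_true_toy, y_pred_toy, average=None)
recall = recall_score(y_true_toy, y_pred_toy, average=None)

print(np.allclose(np.diag(cm_by_pred), precision))  # True
print(np.allclose(np.diag(cm_by_true), recall))     # True
```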
Extracting Metrics from the Matrix
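The standard threshold-based metrics (precision, recall, F1, accuracy) can be derived directly from the confusion matrix. Ranking metrics such as ROC AUC cannot, because they need scores rather than hard predictions: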
import numpy as np
def metrics_from_cm(cm):
    """Extract per-class and aggregate metrics from a confusion matrix."""
    tp = np.diag(cm)                # correct predictions per class
    fp = cm.sum(axis=0) - tp        # predicted as the class, but wrong
    fn = cm.sum(axis=1) - tp        # belonged to the class, but missed
    tn = cm.sum() - (tp + fp + fn)  # everything else
    # The small epsilon avoids division by zero for classes with no predictions
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    specificity = tn / (tn + fp + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return {
        "per_class_precision": precision,
        "per_class_recall": recall,
        "per_class_specificity": specificity,
        "per_class_f1": f1,
        "macro_f1": f1.mean(),
        "accuracy": tp.sum() / cm.sum(),
    }
metrics = metrics_from_cm(cm)
for key, val in metrics.items():
    print(f"{key}: {val}")
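The 1e-10 terms avoid division by zero but slightly bias every result. One alternative, sketched below, returns 0 wherever a denominator is 0 (the same choice scikit-learn makes with zero_division=0) and cross-checks the numbers against precision_recall_fscore_support. The helper name safe_divide and the toy labels are illustrative assumptions:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

def safe_divide(num, den):
    """Element-wise num / den, returning 0.0 wherever den == 0."""
    num = np.asarray(num, dtype=float)
    den = np.asarray(den, dtype=float)
    out = np.zeros_like(num)
    np.divide(num, den, out=out, where=den != 0)
    return out

# Hypothetical labels: class 2 is never predicted, so its precision is 0/0
y_true_toy = [0, 0, 1, 1, 2, 2]
y_pred_toy = [0, 1, 1, 1, 0, 1]

cm_toy = confusion_matrix(y_true_toy, y_pred_toy, labels=[0, 1, 2])
tp = np.diag(cm_toy)
fp = cm_toy.sum(axis=0) - tp
fn = cm_toy.sum(axis=1) - tp

precision = safe_divide(tp, tp + fp)
recall = safe_divide(tp, tp + fn)
f1 = safe_divide(2 * precision * recall, precision + recall)

# Reference values from scikit-learn, using the same zero-division convention
p_ref, r_ref, f_ref, _ = precision_recall_fscore_support(
    y_true_toy, y_pred_toy, labels=[0, 1, 2], zero_division=0
)
print(np.allclose(precision, p_ref), np.allclose(recall, r_ref), np.allclose(f1, f_ref))
```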
Multi-Label Confusion Matrices
For multi-label classification where each sample can have multiple labels:
from sklearn.metrics import multilabel_confusion_matrix
y_true_ml = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
y_pred_ml = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]
mcm = multilabel_confusion_matrix(y_true_ml, y_pred_ml)
# Returns one 2x2 matrix per label
for i, label_cm in enumerate(mcm):
    print(f"Label {i}:\n{label_cm}\n")
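Each 2x2 block uses scikit-learn's binary layout [[TN, FP], [FN, TP]], so the per-label counts can be sliced out of the 3-D array directly. A short sketch reusing the toy labels above:

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true_ml = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
y_pred_ml = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]
mcm = multilabel_confusion_matrix(y_true_ml, y_pred_ml)

# mcm has shape (n_labels, 2, 2); each block is [[TN, FP], [FN, TP]]
tn = mcm[:, 0, 0]
fp = mcm[:, 0, 1]
fn = mcm[:, 1, 0]
tp = mcm[:, 1, 1]

recall_per_label = tp / (tp + fn)
print(recall_per_label)  # [1.  0.5 0.5]
```

Label 0 is predicted perfectly, while labels 1 and 2 each miss one positive, which is exactly what the FN column of each block shows.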
Cost-Sensitive Analysis
Not all errors are equal. In fraud detection, missing fraud (FN) costs far more than a false alarm (FP). The confusion matrix enables cost-weighted evaluation:
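# Assumes a binary problem (e.g. fraud detection): y_true holds 0/1 labels and
# y_proba holds predicted probabilities for class 1, for example
# y_proba = fraud_model.predict_proba(X_val)[:, 1] (fraud_model is hypothetical).
# The 3x3 Iris matrix above would not broadcast against this 2x2 cost matrix.
cm_binary = confusion_matrix(y_true, (y_proba >= 0.5).astype(int), labels=[0, 1])
# Cost matrix: cost[i][j] = cost of predicting j when true label is i
cost_matrix = np.array([
    [0, 10],   # True Negative costs 0, False Positive costs $10
    [500, 0],  # False Negative costs $500, True Positive costs 0
])
total_cost = np.sum(cm_binary * cost_matrix)
cost_per_sample = total_cost / cm_binary.sum()
print(f"Total cost: ${total_cost:,.0f}")
print(f"Cost per prediction: ${cost_per_sample:.2f}")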
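Because the cost matrix has zeros on the diagonal, the element-wise product reduces to 10 * FP + 500 * FN. A worked check with made-up binary counts (not taken from any real model):

```python
import numpy as np

cost_matrix = np.array([
    [0, 10],   # true negative class: TN costs 0, FP costs $10
    [500, 0],  # true positive class: FN costs $500, TP costs 0
])

# Hypothetical binary confusion matrix laid out as [[TN, FP], [FN, TP]]
cm_example = np.array([
    [900, 40],
    [6, 54],
])

total_cost = np.sum(cm_example * cost_matrix)  # 40*10 + 6*500 = 3400
print(total_cost)                     # 3400
print(total_cost / cm_example.sum())  # 3.4
```

Six missed frauds cost more than seven times as much as forty false alarms, which is why the threshold search that follows tends to favor catching positives.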
This transforms the confusion matrix from a statistical tool into a business decision tool. You can now optimize the classification threshold to minimize total cost:
from sklearn.metrics import confusion_matrix
best_cost = float("inf")
best_threshold = 0.5
for threshold in np.arange(0.1, 0.95, 0.01):
    y_pred_t = (y_proba >= threshold).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix that lines up with cost_matrix
    cm_t = confusion_matrix(y_true, y_pred_t, labels=[0, 1])
    cost_t = np.sum(cm_t * cost_matrix)
    if cost_t < best_cost:
        best_cost = cost_t
        best_threshold = threshold
print(f"Optimal threshold: {best_threshold:.2f}, Cost: ${best_cost:,.0f}")
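The loop above assumes y_true and y_proba already exist. Below is a self-contained sketch in which synthetic, imbalanced data from make_classification stands in for a real fraud dataset; the data, the LogisticRegression model, and the helper cost_at are assumptions, not part of the original example. The threshold is tuned on a held-out validation split so that the final test set stays untouched:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in for a fraud dataset (about 5% positives)
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_proba = clf.predict_proba(X_val)[:, 1]  # probability of the positive class

cost_matrix = np.array([[0, 10], [500, 0]])

def cost_at(threshold):
    """Total validation cost when y_proba is thresholded at the given value."""
    y_pred_t = (y_proba >= threshold).astype(int)
    cm_t = confusion_matrix(y_val, y_pred_t, labels=[0, 1])
    return np.sum(cm_t * cost_matrix)

# Rounded grid from 0.05 to 0.95 so that 0.5 is represented exactly
thresholds = np.round(np.linspace(0.05, 0.95, 19), 2)
costs = np.array([cost_at(t) for t in thresholds])
best_threshold = thresholds[costs.argmin()]

print(f"Cost at 0.50: ${cost_at(0.5):,.0f}")
print(f"Best threshold: {best_threshold:.2f}, cost: ${costs.min():,.0f}")
```

A coarse 0.05 grid keeps the sweep cheap. With costs this lopsided the best threshold will often land below 0.5, but that depends on the data, so treat the chosen value as something to re-validate.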
Error Analysis Workflow
The confusion matrix is the starting point for systematic error analysis:
Step 1: Identify Problem Pairs
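# Find the largest off-diagonal value (the most frequent true -> predicted error)
cm_no_diag = cm.copy()
np.fill_diagonal(cm_no_diag, 0)
worst_pair = np.unravel_index(cm_no_diag.argmax(), cm_no_diag.shape)
if cm_no_diag[worst_pair] == 0:
    # argmax of an all-zero matrix returns (0, 0), which is not a real error
    print("No misclassifications: every off-diagonal cell is 0")
else:
    print(f"Most confused: true={worst_pair[0]}, predicted={worst_pair[1]}")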
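With many classes, it helps to rank several problem pairs rather than only the single worst one. A minimal sketch; the function name top_confusions and the example matrix are hypothetical:

```python
import numpy as np

def top_confusions(cm, class_names, k=3):
    """Return up to k (true, predicted, count) tuples for the largest off-diagonal cells."""
    off_diag = np.array(cm, copy=True)
    np.fill_diagonal(off_diag, 0)
    flat_order = np.argsort(off_diag, axis=None)[::-1]  # largest cell first
    pairs = []
    for flat_idx in flat_order[:k]:
        i, j = np.unravel_index(flat_idx, off_diag.shape)
        if off_diag[i, j] == 0:
            break  # every remaining cell is zero
        pairs.append((class_names[i], class_names[j], int(off_diag[i, j])))
    return pairs

# Hypothetical 3-class matrix (rows = true, columns = predicted)
cm_example = np.array([
    [50, 0, 0],
    [0, 45, 5],
    [0, 2, 48],
])
print(top_confusions(cm_example, ["Setosa", "Versicolor", "Virginica"]))
# [('Versicolor', 'Virginica', 5), ('Virginica', 'Versicolor', 2)]
```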
Step 2: Inspect Misclassified Samples
# Get indices of test samples with true class worst_pair[0] predicted as worst_pair[1]
misclassified = np.where((y_test == worst_pair[0]) & (y_pred == worst_pair[1]))[0]
print(f"Number of errors: {len(misclassified)}")
print(f"Sample indices: {misclassified[:10]}")
# Inspect feature distributions of misclassified vs. correctly classified
import pandas as pd
feature_names = load_iris().feature_names
df_test = pd.DataFrame(X_test, columns=feature_names)
df_test["true"] = y_test
df_test["pred"] = y_pred
df_test["correct"] = y_test == y_pred
# Compare feature means within the true class, split by correct vs. incorrect
print(df_test[df_test["true"] == worst_pair[0]].groupby("correct")[feature_names].mean())
Step 3: Build Targeted Fixes
Once you know which features differ between correctly and incorrectly classified samples, you can:
- Engineer features that distinguish the confused classes.
- Collect more training data for the problem pair.
- Use class-specific thresholds.
- Train a specialized sub-model for the ambiguous pair.
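The last option can be sketched as a two-stage predictor: a general model handles every class, and a specialist trained only on the confused pair re-decides whenever the general model predicts one of those two classes. This sketch assumes the ambiguous pair is Versicolor (1) vs. Virginica (2) on Iris; PAIR, general, specialist, and predict_two_stage are illustrative names, and whether it helps has to be measured on your own data:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Assumed ambiguous pair: Versicolor (1) vs. Virginica (2)
PAIR = (1, 2)

# Stage 1: general model trained on all classes
general = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

# Stage 2: specialist trained only on samples from the confused pair
pair_mask = np.isin(y_train, PAIR)
specialist = LogisticRegression(max_iter=1000).fit(X_train[pair_mask], y_train[pair_mask])

def predict_two_stage(X_new):
    """Predict with the general model, then let the specialist re-decide within PAIR."""
    y_hat = general.predict(X_new)
    in_pair = np.isin(y_hat, PAIR)
    if in_pair.any():
        y_hat[in_pair] = specialist.predict(X_new[in_pair])
    return y_hat

y_pred_two_stage = predict_two_stage(X_test)
print(f"Accuracy: {(y_pred_two_stage == y_test).mean():.3f}")
```

Routing only on the general model's prediction keeps Setosa predictions untouched; the specialist never sees samples the general model placed outside the pair.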
Confusion Matrix Over Time
In production, track the confusion matrix on a rolling window to detect drift:
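from collections import deque
import numpy as np
from sklearn.metrics import confusion_matrix
class RollingConfusionMatrix:
    """Confusion matrix over the most recent window_size (true, predicted) pairs."""
    def __init__(self, window_size=1000, n_classes=2):
        self.buffer = deque(maxlen=window_size)  # oldest pairs drop off automatically
        self.n_classes = n_classes
    def update(self, y_true_batch, y_pred_batch):
        for yt, yp in zip(y_true_batch, y_pred_batch):
            self.buffer.append((yt, yp))
    def get_matrix(self):
        if not self.buffer:
            return np.zeros((self.n_classes, self.n_classes), dtype=int)
        yt, yp = zip(*self.buffer)
        # Fixed labels keep the shape constant even if a class is absent from the window
        return confusion_matrix(yt, yp, labels=list(range(self.n_classes)))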
Alert when off-diagonal proportions exceed historical baselines.
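One way to implement that alert is to row-normalize both matrices (so each cell becomes "fraction of true class i predicted as j") and flag off-diagonal cells that grew by more than a tolerance. The function names, the 0.05 tolerance, and the example matrices are illustrative assumptions; in practice current_cm would come from RollingConfusionMatrix.get_matrix() and baseline_cm from a trusted validation period:

```python
import numpy as np

def off_diagonal_rates(cm):
    """Row-normalize a confusion matrix and zero the diagonal (per-class error rates)."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    rates = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
    np.fill_diagonal(rates, 0.0)
    return rates

def drift_alerts(current_cm, baseline_cm, tolerance=0.05):
    """Return (true, predicted, baseline_rate, current_rate) for cells that grew by more than tolerance."""
    base = off_diagonal_rates(baseline_cm)
    curr = off_diagonal_rates(current_cm)
    alerts = []
    for i, j in zip(*np.where(curr - base > tolerance)):
        alerts.append((int(i), int(j), round(float(base[i, j]), 3), round(float(curr[i, j]), 3)))
    return alerts

# Hypothetical binary matrices: baseline from validation, current from the rolling window
baseline = np.array([[950, 50], [20, 80]])
current = np.array([[940, 60], [45, 55]])
print(drift_alerts(current, baseline))  # [(1, 0, 0.2, 0.45)]
```

Row normalization matters here: a rolling window rarely has the same class mix as the baseline, and raw counts would confuse a shift in class frequency with a shift in model errors.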
Advanced Visualization: Sankey Diagrams
For large multi-class problems (10+ classes), heatmaps get crowded. A Sankey diagram shows flows from true labels to predicted labels, making major confusion patterns obvious:
import plotly.graph_objects as go
class_names = ["Setosa", "Versicolor", "Virginica"]  # replace with your own class names
n_classes = len(class_names)
# Nodes 0..n-1 are true labels; nodes n..2n-1 are predicted labels
labels = class_names + [f"Pred: {c}" for c in class_names]
source, target, value = [], [], []
for i in range(n_classes):
    for j in range(n_classes):
        if cm[i][j] > 0:
            source.append(i)
            target.append(n_classes + j)
            value.append(int(cm[i][j]))
fig = go.Figure(go.Sankey(
    node=dict(label=labels),
    link=dict(source=source, target=target, value=value),
))
fig.update_layout(title="Classification Flow: True → Predicted")
fig.show()
Common Pitfalls
- Transposing rows and columns: Scikit-learn convention is rows = true, columns = predicted. Some libraries use the opposite. Always verify.
- Ignoring normalization: Raw counts mislead when class sizes differ. Normalize before drawing conclusions about per-class performance.
- Only looking at the diagonal: The diagonal tells you what went right; the off-diagonal tells you what needs fixing.
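- Forgetting to include all labels: If a class appears in neither y_true nor y_pred (common in small or filtered evaluation sets), scikit-learn silently drops its row and column. Pass labels= explicitly to keep the matrix shape fixed.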
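The last pitfall is easy to reproduce. In this hypothetical batch, class 2 never appears, so the default matrix shrinks to 2x2:

```python
from sklearn.metrics import confusion_matrix

# Hypothetical evaluation batch in which class 2 never appears (true or predicted)
y_true_batch = [0, 0, 1, 1]
y_pred_batch = [0, 1, 1, 1]

print(confusion_matrix(y_true_batch, y_pred_batch).shape)                    # (2, 2)
print(confusion_matrix(y_true_batch, y_pred_batch, labels=[0, 1, 2]).shape)  # (3, 3)
```

Fixing the labels also keeps matrices from different batches or time windows directly comparable cell by cell.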
One thing to remember: The confusion matrix is not just a reporting tool — it is a diagnostic instrument. The errors it reveals are the roadmap for your next round of improvements.