Scikit-Learn Learning Curves — Deep Dive
Technical foundation
Learning curves are rooted in the bias-variance decomposition of generalization error. As training set size increases, bias typically stays constant (determined by model capacity) while variance decreases (more data constrains the hypothesis space). The learning curve visualizes this tradeoff empirically.
Formally, for a model class H and training set size n:
- Expected training error generally increases with n (harder to memorize more examples)
- Expected test error generally decreases with n (better generalization from more evidence)
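As n → ∞, the two curves converge toward the lowest error achievable within H. That error is the irreducible (Bayes) error plus the approximation error of H, so the curves meet at the Bayes error only if H contains the Bayes-optimal predictor. The rate of convergence depends on model complexity.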
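A compact way to state these claims, as a sketch under assumptions the text above does not spell out: $\hat{h}_n$ is an empirical risk minimizer over H trained on n samples, $R$ is the true (test) risk, $\hat{R}_n$ is the empirical (training) risk, and uniform convergence holds for H (for example, finite VC dimension).

$$
\begin{aligned}
\varepsilon_H &= \min_{h \in H} R(h) = \varepsilon_{\text{Bayes}} + \underbrace{(\varepsilon_H - \varepsilon_{\text{Bayes}})}_{\text{approximation error}} \\
\mathbb{E}\big[\hat{R}_n(\hat{h}_n)\big] &\le \varepsilon_H \le \mathbb{E}\big[R(\hat{h}_n)\big] \\
\lim_{n \to \infty} \mathbb{E}\big[\hat{R}_n(\hat{h}_n)\big] &= \lim_{n \to \infty} \mathbb{E}\big[R(\hat{h}_n)\big] = \varepsilon_H
\end{aligned}
$$

The first inequality holds because $\hat{h}_n$ minimizes training error, so $\hat{R}_n(\hat{h}_n) \le \hat{R}_n(h^*)$ for the best-in-class $h^*$, whose expected training error is exactly $\varepsilon_H$. The second holds because no member of H does better than $\varepsilon_H$. The training and test curves are therefore squeezed toward $\varepsilon_H$ from below and above, and the vertical gap between them is at least the estimation (variance-driven) error. Monotonicity in n is not guaranteed by this argument, which is why the statements above say "generally".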
Implementation with scikit-learn
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Generate synthetic dataset
X, y = make_classification(
    n_samples=5000, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)

# Compute learning curve
train_sizes, train_scores, val_scores = learning_curve(
    estimator=RandomForestClassifier(n_estimators=100, random_state=42),
    X=X, y=y,
    train_sizes=np.linspace(0.1, 1.0, 10),
    cv=5,
    scoring='f1_weighted',
    n_jobs=-1,
    shuffle=True,
    random_state=42
)

# Plot with confidence bands (mean ± 1 std across CV folds)
train_mean = train_scores.mean(axis=1)
train_std = train_scores.std(axis=1)
val_mean = val_scores.mean(axis=1)
val_std = val_scores.std(axis=1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color='blue')
ax.fill_between(train_sizes, val_mean - val_std, val_mean + val_std, alpha=0.1, color='orange')
ax.plot(train_sizes, train_mean, 'o-', color='blue', label='Training score')
ax.plot(train_sizes, val_mean, 'o-', color='orange', label='Validation score')
ax.set_xlabel('Training Set Size')
ax.set_ylabel('F1 Score (weighted)')
ax.set_title('Learning Curve: Random Forest')
ax.legend(loc='lower right')
ax.set_ylim(0, 1.05)
plt.tight_layout()
plt.savefig('learning_curve.png', dpi=150)
```
Key parameters and their effects
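train_sizes: np.linspace(0.1, 1.0, 10) gives ten evenly spaced fractions of the maximum training-set size, which is usually enough for a smooth curve. For large datasets, consider logarithmic spacing, e.g. np.logspace(-2, 0, 10) for fractions from 1% to 100%, to focus resolution on the critical low-data regime where curves change fastest.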
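A quick look at the two spacing options in plain NumPy. The 4000-sample training fold is an assumption matching the 5000-sample, 5-fold example above, and the printed counts are approximate because learning_curve does its own conversion from fractions to counts.

```python
import numpy as np

# Ten evenly spaced fractions of the training fold: 0.1, 0.2, ..., 1.0
linear_fracs = np.linspace(0.1, 1.0, 10)

# Ten log-spaced fractions from 1% to 100%. np.logspace takes base-10 exponents,
# np.geomspace takes the endpoints themselves; both produce the same values here
log_fracs = np.logspace(-2, 0, 10)
geom_fracs = np.geomspace(0.01, 1.0, 10)

# Approximate absolute sizes for an assumed 4000-sample training fold
n_train = 4000
linear_sizes = np.round(linear_fracs * n_train).astype(int)
log_sizes = np.round(log_fracs * n_train).astype(int)
print(linear_sizes)
print(log_sizes)
```

With log spacing, half of the ten points fall below 10% of the training fold, where the curve usually moves most.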
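cv: With an integer cv and a classifier, scikit-learn uses Stratified K-Fold, which preserves the class distribution in each split. For time series, use TimeSeriesSplit so that no fold trains on data from after its validation period, which would be data leakage. With small datasets, increase the number of folds (e.g., 10) so each training split keeps more of the data, and consider repeated cross-validation to reduce variance in the score estimates.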
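A small sketch, on a hypothetical 100-row toy array, of how TimeSeriesSplit builds expanding training windows that always end before the validation window starts:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Toy stand-in for 100 time-ordered observations (row i happens before row i + 1)
X_ts = np.arange(100).reshape(-1, 1)

tscv = TimeSeriesSplit(n_splits=5)
train_lengths = []
for train_idx, test_idx in tscv.split(X_ts):
    # Every training index precedes every validation index: no future rows in training
    assert train_idx.max() < test_idx.min()
    train_lengths.append(len(train_idx))

print(train_lengths)  # [20, 36, 52, 68, 84]: each split trains on more history
```

Passing tscv as cv= to learning_curve applies this ordering to every split. One caveat: learning_curve computes its absolute training sizes from the first split's training set, which for TimeSeriesSplit is the smallest one (20 rows here), so the longer histories in later splits are not fully used.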
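scoring: Match this to your business metric. On imbalanced data, accuracy is dominated by the majority class, so the curves can look misleadingly high and flat. Prefer F1 on the minority class ('f1'), precision-recall AUC (summarized by 'average_precision'), or balanced accuracy ('balanced_accuracy') when class distributions are skewed.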
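A toy illustration, using plain metric calls on hypothetical labels, of why accuracy misleads on skewed classes: a "model" that always predicts the majority class looks excellent on accuracy and useless on the other metrics.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Hypothetical imbalanced labels: 95 negatives, 5 positives
y_true = np.array([0] * 95 + [1] * 5)
# Always predict the majority class
y_pred = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_pred)                # 0.95: looks excellent
bal = balanced_accuracy_score(y_true, y_pred)       # 0.5: chance level
f1_pos = f1_score(y_true, y_pred, zero_division=0)  # 0.0: no positives found
f1_w = f1_score(y_true, y_pred, average="weighted", zero_division=0)  # about 0.926
print(acc, bal, f1_pos, round(f1_w, 3))
```

Weighted F1, which the examples in this article use, weights each class's F1 by its support, so it also looks healthy here (about 0.93). On skewed data, 'f1' (positive class) or 'f1_macro' shows minority-class performance more honestly.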
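n_jobs=-1: Learning curve computation is embarrassingly parallel: each (training size, CV fold) pair is an independent fit, so 10 sizes with 5 folds means 50 fits to spread across cores. On a 16-core machine, this can reduce wall time by roughly 10-14x, depending on the estimator. Monitor memory: each worker fits its own clone of the estimator on its own data subset.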
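shuffle: Set shuffle=True with a fixed random_state. learning_curve builds each training subset from a prefix of the fold's training indices, so without shuffling, data sorted by class, date, or source yields unrepresentative small subsets and jagged or biased curves; the fixed random_state keeps the shuffled result reproducible. The exception is time-ordered data split with TimeSeriesSplit, where leaving shuffle=False keeps each subset in chronological order.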
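To see why ordering matters, here is a toy illustration in plain NumPy (not a call into learning_curve) of taking the first 10% of a class-sorted label array versus a shuffled one. The sorted labels are hypothetical.

```python
import numpy as np

# Hypothetical labels stored sorted by class, e.g. exported one class at a time
y_sorted = np.array([0] * 800 + [1] * 200)

# Without shuffling, a small training subset is a prefix: it contains class 0 only
prefix = y_sorted[:100]

# With shuffling and a fixed seed, the subset mixes both classes and is reproducible
rng = np.random.default_rng(42)
shuffled_subset = rng.permutation(y_sorted)[:100]

print(np.unique(prefix), np.unique(shuffled_subset))
```

A classifier fit on the single-class prefix would either fail or produce a meaningless first point on the curve.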
Advanced patterns
Comparing model complexity
Plot learning curves for several models side by side, on a shared y-axis, to visualize the bias-variance tradeoff across model families:
```python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier

# Reuses np, plt, learning_curve, X and y from the previous example
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'SVM (RBF)': SVC(kernel='rbf', gamma='scale'),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100),
}

fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
for ax, (name, model) in zip(axes, models.items()):
    train_sizes, train_scores, val_scores = learning_curve(
        model, X, y, train_sizes=np.linspace(0.1, 1.0, 8),
        cv=5, scoring='f1_weighted', n_jobs=-1,
        shuffle=True, random_state=42,  # random_state only takes effect when shuffle=True
    )
    ax.plot(train_sizes, train_scores.mean(axis=1), 'o-', label='Train')
    ax.plot(train_sizes, val_scores.mean(axis=1), 'o-', label='Validation')
    ax.set_title(name)
    ax.set_xlabel('Training Size')
    ax.legend()

axes[0].set_ylabel('F1 Score')
plt.tight_layout()
```
This comparison reveals which model benefits most from additional data, namely the one whose validation curve is still rising at the largest training size. That is critical for deciding where to invest data collection effort.
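Reading the curves is usually done by eye. The hypothetical helper below turns two common rules of thumb into code: a large train/validation gap with a still-rising validation curve suggests more data will help, while a small gap on a flat curve suggests it will not. The thresholds are illustrative assumptions, not scikit-learn defaults.

```python
import numpy as np

def read_curve(train_mean, val_mean, gap_tol=0.05, gain_tol=0.005):
    """Rough heuristic reading of one learning curve.

    gap_tol and gain_tol are hypothetical thresholds; tune them for your metric.
    """
    train_mean = np.asarray(train_mean, dtype=float)
    val_mean = np.asarray(val_mean, dtype=float)
    gap = train_mean[-1] - val_mean[-1]        # train/validation gap at the largest size
    recent_gain = val_mean[-1] - val_mean[-2]  # validation change over the last step
    if gap > gap_tol and recent_gain > gain_tol:
        return "high variance: still improving, more data likely helps"
    if gap <= gap_tol and recent_gain <= gain_tol:
        return "plateaued: more data unlikely to help; if the score is too low, add capacity"
    return "inconclusive: extend the curve or inspect fold-to-fold spread"

# Hypothetical mean scores at the last three training sizes
print(read_curve([1.0, 0.995, 0.99], [0.74, 0.79, 0.82]))
print(read_curve([0.80, 0.79, 0.785], [0.76, 0.775, 0.777]))
```

Inside the comparison loop, read_curve(train_scores.mean(axis=1), val_scores.mean(axis=1)) would give one verdict per model.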
Learning curves with pipelines
Since learning_curve accepts any estimator, you can pass full pipelines:
```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Reuses np, learning_curve, RandomForestClassifier, X and y from the first example
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('pca', PCA(n_components=10)),
    ('clf', RandomForestClassifier(n_estimators=100, random_state=42)),
])

train_sizes, train_scores, val_scores = learning_curve(
    pipe, X, y, train_sizes=np.linspace(0.1, 1.0, 8),
    cv=5, scoring='f1_weighted', n_jobs=-1,
    shuffle=True, random_state=42,
)
```
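Because the whole pipeline is cloned and refit for every training size and CV fold, the scaler and PCA are fit only on each training subset. Fitting them on the full dataset beforehand would leak statistics from the validation folds into training.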
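Leakage from a scaler is usually too small to see, so this sketch swaps in a supervised feature-selection step (SelectKBest), where the same mistake of fitting preprocessing on the full dataset is easy to measure. The data is pure noise, so an honest estimate should sit near chance; cross_val_score is used instead of learning_curve only to keep the sketch short.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

# Pure noise: 5000 random features and random labels, so honest accuracy is about 0.5
rng = np.random.default_rng(0)
X_noise = rng.standard_normal((100, 5000))
y_noise = rng.integers(0, 2, size=100)

# Leaky: pick the 20 "best" features using all rows, then cross-validate
X_leaky = SelectKBest(f_classif, k=20).fit_transform(X_noise, y_noise)
leaky = cross_val_score(LogisticRegression(max_iter=1000), X_leaky, y_noise, cv=5).mean()

# Honest: the pipeline refits the selection inside each training fold
pipe_noise = make_pipeline(SelectKBest(f_classif, k=20), LogisticRegression(max_iter=1000))
honest = cross_val_score(pipe_noise, X_noise, y_noise, cv=5).mean()

print(f"leaky: {leaky:.2f}  honest: {honest:.2f}")
```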
Extrapolating data requirements
Fit a power-law model to the validation curve to estimate how much more data you’d need to reach a target score:
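```python
from scipy.optimize import curve_fit

# Saturating power law: the score approaches the asymptote a as n grows
def power_law(n, a, b, c):
    return a - b * n ** (-c)

# Fit to the mean validation scores from the learning curve above
popt, _ = curve_fit(power_law, train_sizes, val_scores.mean(axis=1), p0=[0.95, 1.0, 0.5])
a, b, c = popt

target_score = 0.92
if target_score >= a or b <= 0:
    # Target at or above the fitted asymptote, or a non-rising fit: no finite estimate
    print(f"No finite estimate for {target_score} F1 (fitted asymptote a = {a:.3f})")
else:
    # Solve: target = a - b * n^(-c)  ->  n = (b / (a - target))^(1/c)
    estimated_n = (b / (a - target_score)) ** (1 / c)
    print(f"Estimated training-set size for {target_score} F1: {estimated_n:.0f}")
```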
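This is a rough extrapolation: it assumes the power law keeps holding beyond the largest size you measured, and estimated_n is a per-fold training-set size rather than a total dataset size. It is still useful for planning data annotation budgets.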
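A sanity check of the fitting step on synthetic data: the scores below come from a hypothetical power law with known parameters, so the fitted extrapolation can be compared with the exact answer, (2.0 / 0.03)^2 ≈ 4444. The last lines convert that per-fold training size into a total labeled-set size, assuming you keep evaluating with k-fold CV.

```python
import numpy as np
from scipy.optimize import curve_fit

def power_law(n, a, b, c):
    # Same saturating form as above: the score approaches a as n grows
    return a - b * n ** (-c)

# Hypothetical ground-truth curve (illustration only, not real results)
true_a, true_b, true_c = 0.95, 2.0, 0.5
sizes = np.linspace(400, 4000, 8)  # per-fold training-set sizes
scores = power_law(sizes, true_a, true_b, true_c)

# On noiseless synthetic scores the fit should recover the known parameters
(a, b, c), _ = curve_fit(power_law, sizes, scores, p0=[0.95, 1.0, 0.5])

target = 0.92  # above the best score observed at 4000 samples (about 0.918)
n_train_needed = (b / (a - target)) ** (1 / c)

# With k-fold CV each fold trains on (k - 1)/k of the data, so the total
# labeled set must be larger than the per-fold training size
k = 5
n_total_needed = n_train_needed * k / (k - 1)
print(round(n_train_needed), round(n_total_needed))
```

On real, noisy validation scores the fitted exponent c is often poorly determined, so treat the estimate as an order of magnitude.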
Subtle curve behaviors
Dipping validation curve: Sometimes the validation score temporarily drops before recovering. One explanation is that at intermediate sizes the model starts capturing noise before the signal dominates at larger sizes; with few folds, a dip can also be plain sampling noise. Don't panic at a temporary dip: check the fold-to-fold standard deviation and extend the curve.
Oscillating training score: A large spread in training scores across CV folds can indicate sensitivity to model initialization or features whose relevance varies across data subsets. Consider more folds or repeated cross-validation.
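A minimal sketch of repeated cross-validation with learning_curve: passing a RepeatedStratifiedKFold splitter as cv yields n_splits × n_repeats scores per training size, so the fold-to-fold spread is estimated from more splits. The small dataset and LogisticRegression are illustrative choices.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, learning_curve

X_small, y_small = make_classification(n_samples=500, n_features=20, random_state=0)

# 5 folds repeated 3 times with different shuffles: 15 scores per training size
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=0)
sizes, train_scores, val_scores = learning_curve(
    LogisticRegression(max_iter=1000), X_small, y_small,
    train_sizes=np.linspace(0.1, 1.0, 5), cv=cv,
    scoring="f1_weighted", shuffle=True, random_state=0,
)

# Spread across the 15 splits at each size; large values flag unstable estimates
print(train_scores.shape)  # (5, 15)
print(train_scores.std(axis=1).round(3))
```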
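Crossing curves: If the validation score exceeds the training score at very small sizes, it is typically an artifact of strong regularization combined with sampling noise: a heavily regularized model fits its tiny training subset little better than the validation fold, and a small subset can happen to be harder than the validation data. Note that learning_curve keeps the validation fold the same size at every step; only the training subset is tiny. This usually resolves at larger sizes.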
Performance considerations for large datasets
For datasets with millions of rows, learning_curve can be prohibitively slow. Strategies:
- Subsample first: Run the curve on a random 10-20% subset, then validate key points on the full set
- Reduce CV folds: 3-fold instead of 5-fold means 40% fewer fits per training size, and each fit trains on a smaller fold (about 67% instead of 80% of the data), so compute time drops by roughly 40% or more
- Use faster estimators: Train with HistGradientBoostingClassifier (histogram-based, multi-threaded on CPU, built-in early stopping) instead of GradientBoostingClassifier
- Coarse-to-fine: Start with 4-5 training sizes to identify the interesting region, then add resolution only where the curve is changing
Real-world example: fraud detection
A payment processor trained a gradient boosting model on 500K labeled transactions. The learning curve showed:
- Training F1: 0.99 (near-perfect)
- Validation F1: 0.82, still climbing at 500K samples
The gap and upward slope indicated high variance: the model would benefit from more labeled data. They invested in labeling 200K additional transactions, and validation F1 reached 0.89, narrowing the gap significantly. Without the learning curve, they might have switched models instead of collecting data, which would likely have been a more expensive and less effective strategy.
Tradeoffs
| Approach | Pros | Cons |
|---|---|---|
| Learning curve | Visual, intuitive, catches bias/variance | Computationally expensive for large data |
| Validation curve | Shows hyperparameter sensitivity | Single hyperparameter at a time |
| Cross-validation score alone | Fast, single number | No insight into data sufficiency |
| Hold-out test set | Simple, unbiased final eval | Wastes data, no trend information |
One thing to remember: Compared with collecting data or rebuilding a model, a learning curve is a cheap experiment. It tells you whether to invest in more data or a better model before you commit resources to either path.