PyTorch Transfer Learning — Deep Dive
Loading Pretrained Models in PyTorch
PyTorch’s torchvision.models and the Hugging Face Hub provide pretrained weights. Since torchvision 0.13, the API selects weights through enums (replacing the deprecated pretrained=True flag):
import torch
from torchvision.models import resnet50, ResNet50_Weights
# DEFAULT resolves to the best available ImageNet-1K weights (IMAGENET1K_V2 for ResNet-50)
model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Access the preprocessing transform that matches the weights
preprocess = ResNet50_Weights.DEFAULT.transforms()
For Hugging Face models:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
Feature Extraction Implementation
Replace the final classification head and freeze everything else:
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
def build_feature_extractor(num_classes: int) -> nn.Module:
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    # Freeze all pretrained parameters
    for param in model.parameters():
        param.requires_grad = False
    # Replace classifier (EfficientNet's model.classifier is Sequential(Dropout, Linear))
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, num_classes),
    )
    # New classifier params are trainable by default
    return model
Only the new classifier head’s parameters are updated during training. Because no gradients are computed for the frozen backbone, this is fast (typically 5-10x faster than full fine-tuning) and works well with small datasets.
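A quick sanity check before training is to count trainable parameters and confirm that the frozen backbone receives no gradients. This minimal sketch uses a hypothetical toy backbone and head in place of EfficientNet so it runs without downloading weights; count_params is an illustrative helper, not a PyTorch API.

```python
import torch
import torch.nn as nn

def count_params(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# Hypothetical toy model standing in for a pretrained backbone plus a new head
backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
for p in backbone.parameters():
    p.requires_grad = False
head = nn.Linear(16, 3)
model = nn.Sequential(backbone, head)

trainable, total = count_params(model)
print(f"trainable {trainable} / total {total}")  # trainable 51 / total 3203

# Hand only the trainable parameters to the optimizer
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

x, y = torch.randn(8, 32), torch.randint(0, 3, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

# Frozen parameters never receive gradients
print(all(p.grad is None for p in backbone.parameters()))  # True
```

Passing only the trainable parameters is not strictly required (AdamW skips parameters whose .grad is None), but it makes the intent explicit and keeps the optimizer state small.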
Discriminative Learning Rates
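Pretrained layers generally need smaller learning rates than a freshly initialized head, so different layer groups should get different rates. PyTorch optimizers support this through per-parameter-group settings: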
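def get_parameter_groups(model, base_lr=1e-5, head_lr=1e-3,
                         head_names=("classifier", "fc", "head")):
    """Give the new head a higher learning rate than the pretrained backbone."""
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters stay out of the optimizer
        # Match the top-level module name: a substring test would also catch
        # e.g. the "fc1"/"fc2" layers inside EfficientNet's squeeze-excitation blocks
        if name.split(".")[0] in head_names:
            head_params.append(param)
        else:
            backbone_params.append(param)
    groups = [
        {"params": backbone_params, "lr": base_lr},
        {"params": head_params, "lr": head_lr},
    ]
    # Drop empty groups (e.g. when the whole backbone is frozen)
    return [g for g in groups if g["params"]]
optimizer = torch.optim.AdamW(
    get_parameter_groups(model, base_lr=2e-5, head_lr=1e-3),
    weight_decay=0.01,
)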
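The list returned by get_parameter_groups plugs into PyTorch's param-group mechanism: each dict can override any optimizer option, and keys it omits fall back to the defaults passed to the constructor. A minimal sketch with a hypothetical two-layer model:

```python
import torch
import torch.nn as nn

# Hypothetical two-part model: a "pretrained" backbone and a new head
backbone = nn.Linear(10, 10)
head = nn.Linear(10, 2)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone.parameters(), "lr": 2e-5},  # explicit per-group LR
        {"params": head.parameters()},                  # no "lr": uses the default below
    ],
    lr=1e-3,
    weight_decay=0.01,
)

for i, group in enumerate(optimizer.param_groups):
    print(i, group["lr"], group["weight_decay"])
```

Multiplicative schedulers such as LambdaLR or StepLR scale each group relative to its own initial rate, so the backbone-to-head ratio is preserved while the rates anneal.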
For more granularity, split the backbone into layer groups with geometrically increasing rates:
def layerwise_lr(model, base_lr=1e-6, max_lr=1e-3, num_groups=4):
    """Geometrically increasing LR from early to late layers."""
    # named_parameters() follows registration order, which usually tracks depth
    named_params = list(model.named_parameters())
    n = len(named_params)
    group_size = n // num_groups
    groups = []
    for i in range(num_groups):
        start = i * group_size
        end = start + group_size if i < num_groups - 1 else n  # last group takes the remainder
        # Interpolate in log space; avoid dividing by zero when num_groups == 1
        frac = i / (num_groups - 1) if num_groups > 1 else 1.0
        lr = base_lr * (max_lr / base_lr) ** frac
        params = [p for _, p in named_params[start:end] if p.requires_grad]
        if params:
            groups.append({"params": params, "lr": lr})
    return groups
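The rate for group i interpolates in log space, lr_i = base_lr · (max_lr / base_lr)^(i / (num_groups - 1)), so with the defaults (1e-6 to 1e-3 over four groups) consecutive groups differ by a factor of 10. The hypothetical helper below isolates that formula so a schedule can be checked before building an optimizer; it returns [max_lr] when num_groups is 1, where the exponent would otherwise divide by zero.

```python
def geometric_lrs(base_lr: float, max_lr: float, num_groups: int) -> list[float]:
    """Learning rates spaced evenly in log space from base_lr to max_lr."""
    if num_groups == 1:
        return [max_lr]
    return [base_lr * (max_lr / base_lr) ** (i / (num_groups - 1))
            for i in range(num_groups)]

lrs = geometric_lrs(1e-6, 1e-3, 4)
print(lrs)  # roughly [1e-06, 1e-05, 1e-04, 1e-03], each 10x the previous
```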
Progressive Unfreezing
Instead of unfreezing all layers at once, gradually unfreeze them from the head down toward the input:
class ProgressiveUnfreezer:
    """Unfreeze one layer group per epoch, starting from the head."""
    def __init__(self, model: nn.Module, layer_groups: list[list[str]]):
        self.model = model
        self.layer_groups = layer_groups  # module-name prefixes, ordered from head to early layers
        self.current = 0
        # Freeze everything initially
        for param in model.parameters():
            param.requires_grad = False
    def step(self):
        """Call once per epoch to unfreeze the next group."""
        if self.current >= len(self.layer_groups):
            return
        prefixes = self.layer_groups[self.current]
        for name, param in self.model.named_parameters():
            # Prefix match avoids accidental substring hits (e.g. "fc" inside "fc1")
            if any(name == p or name.startswith(p + ".") for p in prefixes):
                param.requires_grad = True
        self.current += 1
# Usage for torchvision's ResNet (num_epochs, loader and train_one_epoch are placeholders)
unfreezer = ProgressiveUnfreezer(model, [
    ["fc"],                                # Epoch 0: train only the classifier
    ["layer4"],                            # Epoch 1: add the last residual stage
    ["layer3"],                            # Epoch 2: add the third stage
    ["layer2", "layer1", "conv1", "bn1"],  # Epoch 3: unfreeze everything, including the stem
])
for epoch in range(num_epochs):
    unfreezer.step()
    # Rebuild the optimizer over the current trainable params
    # (note: this discards AdamW's moment estimates every epoch)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4,
    )
    train_one_epoch(model, loader, optimizer)
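This reduces the risk of catastrophic forgetting: early layers retain their pretrained knowledge while later layers adapt to the new task, and by the time the early layers are unfrozen, the head is no longer sending them large, noisy gradients from its random initialization.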
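Rebuilding AdamW every epoch, as in the loop above, resets its moment estimates for parameters that were already training. One alternative, sketched here under the assumption that each stage maps to a submodule, is to keep a single optimizer and register newly unfrozen parameters with optimizer.add_param_group, which also lets each stage get its own learning rate. unfreeze_into_optimizer and the toy model are hypothetical.

```python
import torch
import torch.nn as nn

def unfreeze_into_optimizer(optimizer: torch.optim.Optimizer, module: nn.Module, lr: float) -> None:
    """Unfreeze `module` and register its newly trainable params as a new param group.

    Existing groups and the optimizer state of params that were already training
    are left untouched. Params that were already trainable are skipped, since
    adding a param to two groups raises an error.
    """
    new_params = []
    for p in module.parameters():
        if not p.requires_grad:
            p.requires_grad = True
            new_params.append(p)
    if new_params:
        optimizer.add_param_group({"params": new_params, "lr": lr})

# Hypothetical toy network: two "pretrained" blocks plus a new head
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 2))
for p in model.parameters():
    p.requires_grad = False
for p in model[2].parameters():
    p.requires_grad = True

optimizer = torch.optim.AdamW(model[2].parameters(), lr=1e-3)
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()  # creates AdamW state for the head

# Later epochs: unfreeze deeper blocks with progressively smaller LRs
unfreeze_into_optimizer(optimizer, model[1], lr=1e-4)
unfreeze_into_optimizer(optimizer, model[0], lr=1e-5)
print([g["lr"] for g in optimizer.param_groups])  # [0.001, 0.0001, 1e-05]
```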
Domain Adaptation with Feature Alignment
When source and target domains differ significantly, vanilla fine-tuning may not transfer well. Domain adaptation techniques align feature distributions:
class DomainAdaptiveModel(nn.Module):
    """Uses Maximum Mean Discrepancy (MMD) to align source and target features."""
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        # Assumes the backbone exposes its feature size as `output_dim`
        # (not a standard nn.Module attribute; set it on your own backbone)
        self.classifier = nn.Linear(backbone.output_dim, num_classes)
    def mmd_loss(self, source_features, target_features):
        """Linear-kernel MMD^2: squared distance between the mean feature vectors."""
        delta = source_features.mean(dim=0) - target_features.mean(dim=0)
        return (delta ** 2).sum()
    def forward(self, source_x, target_x=None):
        source_feat = self.backbone(source_x)
        logits = self.classifier(source_feat)
        if target_x is not None and self.training:
            # Target samples are unlabeled; they only feed the alignment loss
            target_feat = self.backbone(target_x)
            adaptation_loss = self.mmd_loss(source_feat, target_feat)
            return logits, adaptation_loss
        return logits
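mmd_loss above is MMD with a linear kernel, so it only penalizes a difference in mean features; two domains with equal means but different spreads look identical to it. A common variant, swapped in here, uses a Gaussian (RBF) kernel, which compares the distributions more fully. This is a minimal sketch of the biased estimator; rbf_mmd2 and the bandwidth sigma=1.0 are assumptions (in practice the bandwidth is often set from the median pairwise distance).

```python
import torch

def rbf_mmd2(source: torch.Tensor, target: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased estimate of squared MMD with a Gaussian (RBF) kernel.

    source: (n, d) features, target: (m, d) features.
    """
    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise squared distances via broadcasting (exactly 0 on identical rows,
        # which keeps gradients finite, unlike differentiating through a sqrt)
        sq_dists = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
        return torch.exp(-sq_dists / (2 * sigma ** 2))

    k_ss = kernel(source, source).mean()
    k_tt = kernel(target, target).mean()
    k_st = kernel(source, target).mean()
    return k_ss + k_tt - 2 * k_st

torch.manual_seed(0)
src = torch.randn(128, 2)
same_dist = torch.randn(128, 2)       # drawn from the same distribution
shifted = torch.randn(128, 2) + 3.0   # mean shifted by 3 in every dimension

print(rbf_mmd2(src, src).item())        # 0.0
print(rbf_mmd2(src, same_dist).item())  # small
print(rbf_mmd2(src, shifted).item())    # much larger
```

During training the adaptation term is added to the classification loss, e.g. loss = F.cross_entropy(logits, y) + lam * adaptation_loss, where the weight lam is a hyperparameter to tune.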
Handling Class Imbalance in Transfer
Pretraining datasets such as ImageNet are roughly class-balanced, but real target datasets often are not. One fix is to weight the loss by inverse class frequency:
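# Compute inverse-frequency class weights from the training set
# (assumes dataset.labels is a list of integer class indices; num_classes and device are defined)
from collections import Counter
counts = Counter(dataset.labels)
total = sum(counts.values())
# max(..., 1) avoids division by zero for classes absent from the training set
weights = torch.tensor([total / max(counts[i], 1) for i in range(num_classes)])
weights = weights / weights.sum() * num_classes  # rescale so the weights average to 1
criterion = nn.CrossEntropyLoss(weight=weights.to(device))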
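A worked example of the same inverse-frequency weighting, using a hypothetical 900/100 label split, shows what the normalization does: the weights average to 1, and the minority class gets 9x the weight of the majority class, the inverse of their frequency ratio.

```python
from collections import Counter

import torch

# Hypothetical label list: 900 samples of class 0, 100 of class 1
labels = [0] * 900 + [1] * 100
num_classes = 2

counts = Counter(labels)
total = sum(counts.values())
# Raw inverse frequencies: [1000/900, 1000/100] = [1.111..., 10.0]
weights = torch.tensor([total / max(counts[i], 1) for i in range(num_classes)])
# Rescale so the weights sum to num_classes (i.e. average to 1)
weights = weights / weights.sum() * num_classes

print(weights)  # tensor([0.2000, 1.8000])
```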
Practical Evaluation Framework
Overall accuracy can hide poor performance on rare classes, so measure per-class accuracy as well:
def evaluate_transfer(model, val_loader, device):
    """Return overall accuracy and a {class_index: accuracy} dict."""
    model.eval()
    correct = 0
    total = 0
    class_correct = {}
    class_total = {}
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            for label, pred in zip(labels.cpu(), predicted.cpu()):
                cls = label.item()
                class_total[cls] = class_total.get(cls, 0) + 1
                if cls == pred.item():
                    class_correct[cls] = class_correct.get(cls, 0) + 1
    overall_acc = correct / total
    per_class = {c: class_correct.get(c, 0) / class_total[c]
                 for c in class_total}
    return overall_acc, per_class
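The per-sample Python loop in evaluate_transfer is easy to read but slow on large validation sets. A vectorized sketch with torch.bincount follows; it also reports balanced accuracy (the mean of per-class recall), which drops when a model neglects minority classes even if overall accuracy looks fine. per_class_accuracy is a hypothetical helper that expects predicted and true labels as 1-D integer tensors, for example concatenated across batches.

```python
import torch

def per_class_accuracy(preds: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Vectorized per-class accuracy (recall) and balanced accuracy.

    Classes with no samples in `labels` get NaN and are ignored by the balanced mean.
    """
    correct = torch.bincount(labels[preds == labels], minlength=num_classes).float()
    support = torch.bincount(labels, minlength=num_classes).float()
    per_class = correct / support  # 0/0 gives NaN for absent classes
    balanced = per_class[~per_class.isnan()].mean()
    return per_class, balanced

labels = torch.tensor([0, 0, 1, 1, 2])
preds = torch.tensor([0, 1, 1, 1, 0])
per_class, balanced = per_class_accuracy(preds, labels, num_classes=3)
print(per_class)       # tensor([0.5000, 1.0000, 0.0000])
print(balanced.item())  # 0.5
overall = (preds == labels).float().mean().item()
print(overall)
```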
Real-World Benchmarks
Transfer learning performance on medical imaging (Stanford CheXpert dataset):
- From scratch: 78% AUC after 200 epochs
- ImageNet pretrained, feature extraction: 83% AUC after 10 epochs
- ImageNet pretrained, fine-tuned: 87% AUC after 30 epochs
- RadImageNet pretrained, fine-tuned: 90% AUC after 20 epochs
The domain-specific pretraining (RadImageNet, a medical imaging dataset) outperformed ImageNet pretraining despite being 10x smaller, suggesting that relevance can matter more than volume.
Anti-Patterns
Learning rate too high for pretrained layers. This is the most common failure. Signs: validation accuracy is worse than feature extraction alone. Fix: reduce backbone learning rate by 10-100x relative to the head.
Not normalizing inputs correctly. Pretrained models expect the same preprocessing used during pretraining (ImageNet mean/std for most vision models, the matching tokenizer for NLP models). Using raw data or a different normalization silently degrades performance.
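For torchvision's ImageNet models the expected preprocessing is per-channel normalization with the ImageNet statistics, applied to pixels scaled to [0, 1]. The sketch below (normalize_imagenet is a hypothetical helper; in practice the weights' own transforms() handles this) shows how forgetting the /255 scaling pushes inputs far outside the range the model saw during pretraining.

```python
import torch

# Standard ImageNet channel statistics (RGB) used by most torchvision models
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize_imagenet(images: torch.Tensor) -> torch.Tensor:
    """Normalize float images in [0, 1] with shape (..., 3, H, W)."""
    return (images - IMAGENET_MEAN) / IMAGENET_STD

pixels_uint8 = torch.full((3, 2, 2), 128, dtype=torch.uint8)

# Common mistake: normalizing 0-255 values directly gives values above 550
wrong = normalize_imagenet(pixels_uint8.float())
# Correct: scale to [0, 1] first; values land between roughly 0.07 and 0.43
right = normalize_imagenet(pixels_uint8.float() / 255.0)

print(wrong.min().item(), right.max().item())
```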
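Freezing batch normalization incorrectly. When fine-tuning with small batches, BatchNorm statistics become noisy. Put the BN layers themselves in eval mode (call .eval() on each BatchNorm module after model.train()) so they use and keep their pretrained running mean and variance. Setting requires_grad=False on BN parameters additionally freezes the learned scale and shift, but on its own it does not stop the running statistics from updating.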
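A minimal sketch of that fix, assuming the standard torch.nn BatchNorm classes (freeze_batchnorm is a hypothetical helper). Because model.train() switches every submodule back to training mode, the helper has to be called again after each model.train().

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def freeze_batchnorm(model: nn.Module) -> None:
    """Stop BN running-stat updates and freeze BN affine parameters."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()  # normalize with stored running_mean/running_var and stop updating them
            for p in module.parameters():
                p.requires_grad = False  # freeze the scale (weight) and shift (bias)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
freeze_batchnorm(model)  # re-apply after every model.train()

before = model[1].running_mean.clone()
model(torch.randn(4, 3, 16, 16))  # forward pass in training mode with a tiny batch
print(torch.equal(before, model[1].running_mean))  # True: statistics were not updated
```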
The one thing to remember: Effective transfer learning is about controlled adaptation — discriminative learning rates, progressive unfreezing, and domain-appropriate pretraining determine whether you get a 5% or a 50% improvement over training from scratch.
See Also
- Python PyTorch Gradient Checkpointing: How PyTorch trades a little extra time for massive memory savings when training huge neural networks.
- Activation Functions: Why neural networks need these tiny mathematical functions, and how ReLU's simplicity accidentally made deep learning possible.
- AI Agents Architecture: How AI systems go from answering questions to actually doing things, and the design patterns that turn language models into autonomous agents that browse, code, and plan.
- AI Agents: ChatGPT answers questions. AI agents actually do things (browse the web, write code, send emails, and keep going until the job is done). Here's the difference.
- AI Ethics: Why building AI fairly is harder than it sounds (bias, accountability, privacy, and who gets to decide what AI is allowed to do).