PyTorch Transfer Learning — Deep Dive

Loading Pretrained Models in PyTorch

PyTorch’s torchvision.models and the Hugging Face Hub provide pretrained weights. torchvision’s current API (introduced in v0.13) selects weights through enums:

import torch
from torchvision.models import resnet50, ResNet50_Weights

# Load the IMAGENET1K_V2 weights (what DEFAULT currently points to)
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Access the preprocessing transform that matches the weights
preprocess = ResNet50_Weights.DEFAULT.transforms()

For Hugging Face models:

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
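The loaded encoder returns one vector per token. For feature extraction, one common choice is to mean-pool them over the non-padded positions. The [CLS] position is another. Below is a minimal sketch that stays offline. It assumes a tiny randomly initialized BertConfig in place of bert-base-uncased (whose hidden size is 768) and random IDs in place of tokenizer output. The config sizes are illustrative only.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny illustrative config; in practice use AutoModel.from_pretrained("bert-base-uncased")
config = BertConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=128)
model = BertModel(config).eval()

# Stand-ins for tokenizer output: two sequences, the second padded after 5 tokens
input_ids = torch.randint(0, config.vocab_size, (2, 8))
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
hidden = out.last_hidden_state                    # (batch, seq_len, hidden_size)

# Mean-pool over real tokens only, so padding does not dilute the embedding
mask = attention_mask.unsqueeze(-1).float()       # (batch, seq_len, 1)
sentence_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, hidden_size)
cls_emb = hidden[:, 0]                            # alternative: the [CLS] vector
```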

Feature Extraction Implementation

Replace the final classification head and freeze everything else:

import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

def build_feature_extractor(num_classes: int) -> nn.Module:
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

    # Freeze all pretrained parameters
    for param in model.parameters():
        param.requires_grad = False

    # Replace classifier (EfficientNet uses model.classifier)
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, num_classes),
    )
    # New classifier params are trainable by default

    return model

Only the new classifier head’s parameters are updated during training. No gradients are computed for the frozen backbone, so each step is much cheaper than full fine-tuning (5-10x faster is a typical range). The approach also works well with small datasets.

Discriminative Learning Rates

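Different layer groups need different learning rates. Early layers hold generic features that need only small adjustments, while a freshly initialized head needs larger steps. PyTorch optimizers accept a list of parameter groups, each with its own learning rate: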

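def get_parameter_groups(model, base_lr=1e-5, head_lr=1e-3):
    """Give the classification head a higher learning rate than the backbone."""
    head_names = {"classifier", "fc", "head"}  # common top-level head module names
    backbone_params = []
    head_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Match the top-level module name exactly; a substring test such as
        # "fc" in name would also catch EfficientNet's squeeze-excitation fc1/fc2
        if name.split(".")[0] in head_names:
            head_params.append(param)
        else:
            backbone_params.append(param)

    groups = [
        {"params": backbone_params, "lr": base_lr},
        {"params": head_params, "lr": head_lr},
    ]
    # Drop empty groups (e.g. when the backbone is fully frozen)
    return [g for g in groups if g["params"]]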

optimizer = torch.optim.AdamW(
    get_parameter_groups(model, base_lr=2e-5, head_lr=1e-3),
    weight_decay=0.01,
)
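Learning-rate schedulers respect parameter groups. For example, torch.optim.lr_scheduler.LambdaLR multiplies each group's initial rate by the same factor, so the head/backbone ratio survives warmup. Below is a minimal sketch. It assumes a toy two-module model (hypothetical "backbone" and "head" submodules) and an illustrative 100-step linear warmup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with named "backbone" and "head" submodules
model = nn.Sequential()
model.add_module("backbone", nn.Linear(16, 8))
model.add_module("head", nn.Linear(8, 2))

optimizer = torch.optim.AdamW(
    [
        {"params": model.backbone.parameters(), "lr": 2e-5},
        {"params": model.head.parameters(), "lr": 1e-3},
    ],
    weight_decay=0.01,
)

warmup_steps = 100
# Same multiplier for every group, so the 50x head/backbone ratio is preserved
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

for step in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    print([f"{g['lr']:.2e}" for g in optimizer.param_groups])
```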

For more granularity, split the backbone into layer groups with geometrically increasing rates:

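def layerwise_lr(model, base_lr=1e-6, max_lr=1e-3, num_groups=4):
    """Geometrically increasing LR from early to late layers."""
    # named_parameters() follows registration order, which for most torchvision
    # models runs from input to output. Groups are equal-sized chunks of
    # parameter tensors, not architectural blocks.
    named_params = list(model.named_parameters())
    n = len(named_params)
    group_size = max(n // num_groups, 1)

    groups = []
    for i in range(num_groups):
        start = i * group_size
        end = start + group_size if i < num_groups - 1 else n
        # Log-scale interpolation: base_lr for the first group, max_lr for the last
        lr = base_lr * (max_lr / base_lr) ** (i / max(num_groups - 1, 1))

        params = [p for _, p in named_params[start:end] if p.requires_grad]
        if params:
            groups.append({"params": params, "lr": lr})

    return groups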

Progressive Unfreezing

Instead of unfreezing all layers at once, unfreeze them gradually, from the head back toward the input:

class ProgressiveUnfreezer:
    """Unfreeze one layer group per epoch, starting from the head."""

    def __init__(self, model: nn.Module, layer_groups: list[list[str]]):
        self.model = model
        self.layer_groups = layer_groups  # top-level module names, head first
        self.current = 0

        # Freeze everything initially
        for param in model.parameters():
            param.requires_grad = False

    def step(self):
        """Call once per epoch to unfreeze the next group."""
        if self.current >= len(self.layer_groups):
            return

        group_names = self.layer_groups[self.current]
        for name, param in self.model.named_parameters():
            # Compare the top-level module name exactly
            # ("layer1.0.conv1.weight" -> "layer1") instead of a substring test
            if name.split(".")[0] in group_names:
                param.requires_grad = True

        self.current += 1

# Usage for ResNet
unfreezer = ProgressiveUnfreezer(model, [
    ["fc"],           # Epoch 0: train only classifier
    ["layer4"],       # Epoch 1: add last residual group
    ["layer3"],       # Epoch 2: add third group
    ["layer2", "layer1", "conv1", "bn1"],  # Epoch 3: everything, including the stem
])

# num_epochs, loader and train_one_epoch stand in for your own training setup
for epoch in range(num_epochs):
    unfreezer.step()
    # Rebuild optimizer with current trainable params
    # (this discards AdamW's moment estimates for already-trainable params)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4,
    )
    train_one_epoch(model, loader, optimizer)

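This reduces the risk of catastrophic forgetting. Large, noisy gradients from the freshly initialized head never reach the early layers, so those layers retain their pretrained knowledge while later layers adapt to the new task.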

Domain Adaptation with Feature Alignment

When source and target domains differ significantly, vanilla fine-tuning may not transfer well. Domain adaptation techniques align feature distributions:

class DomainAdaptiveModel(nn.Module):
    """Uses a linear-kernel Maximum Mean Discrepancy (MMD) to align source and target features."""

    def __init__(self, backbone: nn.Module, feature_dim: int, num_classes: int):
        super().__init__()
        # backbone must map inputs to (batch, feature_dim) feature vectors;
        # feature_dim is passed explicitly because nn.Module has no output_dim attribute
        self.backbone = backbone
        self.classifier = nn.Linear(feature_dim, num_classes)

    @staticmethod
    def mmd_loss(source_features, target_features):
        """Linear-kernel MMD: squared distance between the two feature means."""
        delta = source_features.mean(dim=0) - target_features.mean(dim=0)
        return (delta ** 2).sum()

    def forward(self, source_x, target_x=None):
        source_feat = self.backbone(source_x)
        logits = self.classifier(source_feat)

        if target_x is not None and self.training:
            # Unlabeled target batches contribute only the alignment loss
            target_feat = self.backbone(target_x)
            adaptation_loss = self.mmd_loss(source_feat, target_feat)
            return logits, adaptation_loss

        return logits
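Comparing only feature means is MMD with a linear kernel, so it can miss differences in spread or shape. A common stronger variant swaps in Gaussian (RBF) kernels. Below is a minimal sketch of that variant. It assumes a hypothetical helper gaussian_mmd, with bandwidths (sigmas) of 1, 2 and 4 chosen for illustration. In training it would be added to the task loss with a hypothetical weight, e.g. loss = cross_entropy + lam * mmd.

```python
import torch

def gaussian_mmd(x: torch.Tensor, y: torch.Tensor,
                 sigmas: tuple[float, ...] = (1.0, 2.0, 4.0)) -> torch.Tensor:
    """Biased estimate of squared MMD using a sum of Gaussian (RBF) kernels."""
    z = torch.cat([x, y], dim=0)
    # Pairwise squared distances, computed without sqrt so gradients stay finite
    d2 = (z.unsqueeze(1) - z.unsqueeze(0)).pow(2).sum(dim=-1)
    k = sum(torch.exp(-d2 / (2 * s ** 2)) for s in sigmas)
    n = x.size(0)
    k_xx, k_yy, k_xy = k[:n, :n], k[n:, n:], k[:n, n:]
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

torch.manual_seed(0)
source = torch.randn(64, 8, requires_grad=True)   # e.g. backbone features
target_same = torch.randn(64, 8)                  # same distribution
target_shifted = torch.randn(64, 8) + 3.0         # shifted distribution

loss_same = gaussian_mmd(source, target_same)
loss_shifted = gaussian_mmd(source, target_shifted)
loss_shifted.backward()  # differentiable, so it can drive feature alignment
```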

Handling Class Imbalance in Transfer

Pretraining datasets such as ImageNet are roughly class-balanced, but many real downstream datasets are not. Reweight the loss function:

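# Compute class weights from training set
# (dataset.labels is assumed to be a list of Python ints; num_classes and
# device come from your own setup)
from collections import Counter
counts = Counter(dataset.labels)
total = sum(counts.values())
# Inverse-frequency weights; max(..., 1) avoids division by zero for absent classes
weights = torch.tensor([total / max(counts[i], 1) for i in range(num_classes)],
                       dtype=torch.float)
weights = weights / weights.sum() * num_classes  # rescale so weights average to 1

criterion = nn.CrossEntropyLoss(weight=weights.to(device))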
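Another option is to rebalance the data the model sees rather than the loss, using torch.utils.data.WeightedRandomSampler. Here each sample is drawn with probability inversely proportional to its class frequency. Below is a minimal sketch on a hypothetical 90/10 toy dataset. You would usually pick either the sampler or the weighted loss, not both, because combining them corrects the imbalance twice.

```python
import torch
from collections import Counter
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)
labels = torch.tensor([0] * 90 + [1] * 10)   # toy 9:1 imbalance
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

counts = Counter(labels.tolist())
# Per-sample weight = 1 / (number of samples in that sample's class)
sample_weights = torch.tensor([1.0 / counts[y] for y in labels.tolist()])
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

seen = Counter()
for _, y in loader:
    seen.update(y.tolist())
print(seen)  # roughly 50/50 instead of 90/10
```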

Practical Evaluation Framework

Measure transfer effectiveness systematically:

def evaluate_transfer(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0
    class_correct = {}
    class_total = {}

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            for label, pred in zip(labels.cpu(), predicted.cpu()):
                l = label.item()
                class_total[l] = class_total.get(l, 0) + 1
                if l == pred.item():
                    class_correct[l] = class_correct.get(l, 0) + 1

    overall_acc = correct / total
    per_class = {c: class_correct.get(c, 0) / class_total[c]
                 for c in class_total}

    return overall_acc, per_class
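Overall accuracy can hide a weak minority class. Take a 90/10 split with 95% accuracy on the majority class and 50% on the minority class: overall accuracy is 0.9 × 0.95 + 0.1 × 0.50 = 90.5%, but the mean of per-class accuracies (balanced accuracy) is only 72.5%. Below is a minimal sketch of condensing the (overall_acc, per_class) pair into comparable numbers. The helper summarize_transfer is hypothetical and the inputs are illustrative. Note that classes absent from the validation set do not appear in per_class.

```python
def summarize_transfer(overall_acc: float, per_class: dict[int, float]) -> dict:
    """Condense overall and per-class accuracy into a few comparable numbers."""
    balanced_acc = sum(per_class.values()) / len(per_class)  # mean per-class accuracy
    worst = min(per_class, key=per_class.get)
    return {
        "overall_acc": overall_acc,
        "balanced_acc": balanced_acc,
        "worst_class": worst,
        "worst_class_acc": per_class[worst],
    }

# Illustrative numbers for a 90/10 split
report = summarize_transfer(0.905, {0: 0.95, 1: 0.50})
print(report)
```

Running this for each variant (feature extraction, fine-tuning, different pretraining sources) puts them on the same footing.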

Real-World Benchmarks

Transfer learning performance on medical imaging (Stanford CheXpert dataset):

  • From scratch: 78% AUC after 200 epochs
  • ImageNet pretrained, feature extraction: 83% AUC after 10 epochs
  • ImageNet pretrained, fine-tuned: 87% AUC after 30 epochs
  • RadImageNet pretrained, fine-tuned: 90% AUC after 20 epochs

The domain-specific pretraining (RadImageNet, a medical imaging dataset) outperformed ImageNet in this comparison. This suggests that the relevance of the pretraining data can matter as much as, or more than, its volume.

Anti-Patterns

Learning rate too high for pretrained layers. This is the most common failure. Signs: validation accuracy is worse than feature extraction alone. Fix: reduce backbone learning rate by 10-100x relative to the head.

Not normalizing inputs correctly. Pretrained models expect specific normalization (ImageNet mean/std for vision, specific tokenization for NLP). Using raw data or different normalization silently degrades performance.

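Freezing batch normalization incorrectly. When fine-tuning with small batches, BatchNorm batch statistics become noisy and corrupt the running estimates. Put the BN layers themselves in eval mode during training by calling .eval() on each BatchNorm module after model.train(), which switches them back, so they normalize with their stored running statistics. Setting requires_grad=False on BN parameters freezes only the affine weight and bias. The running mean and variance are buffers and keep updating in train mode regardless.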
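Below is a minimal sketch of freezing BatchNorm properly. freeze_batchnorm is a hypothetical helper, and a tiny conv/BN model stands in for a real backbone. The key detail is calling it again after every model.train(), because model.train() switches BN layers back to training mode.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def freeze_batchnorm(model: nn.Module, freeze_affine: bool = True) -> None:
    """Switch every BatchNorm layer to eval mode so its running stats stay fixed."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()  # normalize with stored running_mean / running_var
            if freeze_affine:
                for param in module.parameters():  # weight (gamma) and bias (beta)
                    param.requires_grad = False

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()            # model.train() puts BN back into training mode...
freeze_batchnorm(model)  # ...so re-apply this after every model.train() call
before = model[1].running_mean.clone()
_ = model(torch.randn(2, 3, 16, 16))  # forward pass in "training"
```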

The one thing to remember: effective transfer learning is about controlled adaptation. Discriminative learning rates, progressive unfreezing, and domain-appropriate pretraining largely determine whether transfer gives a marginal or a substantial improvement over training from scratch.

