PyTorch Gradient Checkpointing — Deep Dive

The Activation Memory Bottleneck

During a standard forward pass, PyTorch’s autograd engine stores every intermediate tensor needed for backward computation. For a Transformer with L layers, hidden dimension H, sequence length S, and batch size B, activation memory grows roughly as O(L × B × S × H), plus a term proportional to L × B × S² for the attention score matrices when they are materialized. A 24-layer Transformer with H=1024, S=512, B=32 in FP32 stores well over 10 GB of activations alone, before parameters or optimizer states; the exact figure depends on which intermediates each layer keeps.
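A back-of-the-envelope estimator makes the scaling concrete. The per-layer tensor counts and the head count (16) are assumptions, since how many B × S × H tensors and per-head S × S score tensors a layer keeps varies by implementation; `activation_gib` is a hypothetical helper, so treat its output as an order-of-magnitude sketch.

```python
def activation_gib(layers, batch, seq, hidden, heads,
                   bsh_tensors_per_layer, score_tensors_per_layer,
                   bytes_per_elem=4):
    """Rough activation memory in GiB.

    The per-layer tensor counts are assumptions supplied by the caller:
    how many B x S x H tensors and how many per-head S x S score tensors
    each layer keeps for backward depends on the implementation.
    """
    bsh = batch * seq * hidden           # elements in one B x S x H activation
    scores = batch * heads * seq * seq   # elements in one set of attention scores
    per_layer = bsh_tensors_per_layer * bsh + score_tensors_per_layer * scores
    return layers * per_layer * bytes_per_elem / 1024**3


# One B x S x H tensor per layer, FP32, for the configuration above
print(activation_gib(24, 32, 512, 1024, heads=16,
                     bsh_tensors_per_layer=1, score_tensors_per_layer=0))  # 1.5
# One FP32 attention-score tensor per layer, assuming 16 heads
print(activation_gib(24, 32, 512, 1024, heads=16,
                     bsh_tensors_per_layer=0, score_tensors_per_layer=1))  # 12.0
```

A single stored tensor of each kind already costs 1.5 GiB and 12 GiB respectively across 24 layers; a real layer keeps several B × S × H tensors (projections, feed-forward intermediates), so the total grows as a multiple of these figures.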

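Gradient checkpointing attacks this cost directly: it keeps only the inputs of selected segments, discards the activations inside them, and recomputes those activations during the backward pass, trading extra compute for lower peak memory.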
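To make the mechanism concrete, here is an educational sketch of the idea written as a custom `torch.autograd.Function`. `ToyCheckpoint` is a hypothetical name and this is not PyTorch's implementation (it only handles a single tensor input); it keeps just the segment input in forward, then rebuilds the discarded graph in backward and backpropagates through it.

```python
import torch


class ToyCheckpoint(torch.autograd.Function):
    """Educational checkpoint sketch (hypothetical, not torch.utils.checkpoint)."""

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)  # the only tensor kept for backward
        # autograd.Function.forward runs with grad disabled, so fn's
        # intermediate activations are not recorded and can be freed.
        return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.fn(x)  # recompute: rebuilds the discarded activations
        # Backprop through the recomputed graph; this also accumulates
        # gradients into any parameters fn uses (here, an nn.Linear).
        torch.autograd.backward(out, grad_out)
        return None, x.grad  # no gradient for the fn argument


torch.manual_seed(0)
lin = torch.nn.Linear(8, 8)


def segment(t):
    return torch.tanh(lin(t)) * 2.0


x = torch.randn(4, 8, requires_grad=True)
ToyCheckpoint.apply(segment, x).pow(2).sum().backward()
ckpt_x_grad, ckpt_w_grad = x.grad.clone(), lin.weight.grad.clone()

# Reference run without checkpointing: gradients should match
x.grad = None
lin.zero_grad()
segment(x).pow(2).sum().backward()
print(torch.allclose(x.grad, ckpt_x_grad), torch.allclose(lin.weight.grad, ckpt_w_grad))
```

Calling `torch.autograd.backward` from inside `backward` mirrors the original reentrant design, which is also why that style cannot serve `torch.autograd.grad` calls that ask for parameter gradients.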

PyTorch’s Checkpoint API

The core function is torch.utils.checkpoint.checkpoint:

import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(torch.nn.Module):
    def __init__(self, d_model, nhead, dim_ff):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_ff),
            torch.nn.GELU(),
            torch.nn.Linear(dim_ff, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def _forward_block(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x

    def forward(self, x):
        # Wrap the computation in checkpoint
        return checkpoint(self._forward_block, x, use_reentrant=False)

Key details about checkpoint():

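  • use_reentrant=False selects the non-reentrant implementation, which the PyTorch documentation recommends (recent releases warn if use_reentrant is not passed explicitly). Unlike the reentrant variant, it works with torch.autograd.grad and supports nested checkpointing
  • The wrapped function must not have side effects or depend on global mutable state, because it runs twice (once in the forward pass, once during backward recomputation)
  • Argument handling differs between the variants: the reentrant variant only tracks gradients through tensors passed directly as positional arguments, rejects keyword arguments, and needs at least one input with requires_grad=True, while the non-reentrant variant also handles tensors nested in lists, dicts, or custom objects and accepts keyword arguments. Non-tensor arguments are passed through unchanged to the recomputation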
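The "runs twice" behavior is easy to observe directly. This sketch uses a hypothetical call counter, a deliberate side effect of exactly the kind real checkpointed code should avoid, to show that non-reentrant `checkpoint` re-executes the wrapped function during backward.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # hypothetical counter, for illustration only


def block(x):
    calls["n"] += 1            # side effect: visible evidence of each execution
    return torch.sin(x) * 2.0  # sin saves its input for backward


x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
print(calls["n"])  # 1: the original forward pass
y.sum().backward()
print(calls["n"])  # 2: the block was re-run during backward
```

Any real side effect (appending to a log, updating a buffer, incrementing a step counter) would likewise happen twice per training step.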

Checkpoint Sequential for CNNs

For sequential models, checkpoint_sequential is more convenient:

from torch.utils.checkpoint import checkpoint_sequential

class DeepCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            *[self._make_block(64, 64) for _ in range(20)]
        )
        self.classifier = torch.nn.Linear(64, 10)

    def _make_block(self, in_ch, out_ch):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=False),  # avoid in-place ops inside checkpointed segments
        )

    def forward(self, x):
        # Split the 20 blocks into 4 segments; all but the last are checkpointed
        x = checkpoint_sequential(self.features, segments=4, input=x,
                                  use_reentrant=False)
        return self.classifier(x.mean(dim=[2, 3]))

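The segments parameter controls granularity. checkpoint_sequential keeps only the input of each segment and runs the final segment without checkpointing (so segments=1 disables it entirely). More segments means more stored segment boundaries but a smaller segment to rematerialize at any one time during backward, so memory is not monotonic in segments: the classic rule of thumb is about √n segments for n layers, which brings activation memory down to O(√n) for roughly one extra forward pass.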
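A simplified cost model shows why the number of segments has a sweet spot. It assumes (a rough approximation that ignores the uncheckpointed last segment and unequal layer sizes) that memory is one stored boundary per segment plus one segment's worth of layer activations during rematerialization; `activation_units` is a hypothetical name.

```python
import math


def activation_units(n_layers, segments):
    """Simplified cost model (an assumption, not PyTorch's accounting):
    one stored boundary activation per segment, plus the activations of
    one segment (n_layers / segments layers) while it is recomputed."""
    return segments + n_layers / segments


n = 20  # the 20 blocks in DeepCNN
costs = {k: activation_units(n, k) for k in range(1, n + 1)}
best = min(costs, key=costs.get)  # first k with the lowest cost
print(best, costs[best], round(math.sqrt(n), 2))  # 4 9.0 4.47
```

Under this model 4 and 5 segments tie for the 20-block DeepCNN (√20 ≈ 4.47), so segments=4 is a reasonable choice; real blocks differ in activation size, so measure before tuning further.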

Memory Profiling

Measuring the actual impact requires careful profiling:

import torch
from torch.profiler import profile, ProfilerActivity

model = build_model().cuda()  # build_model() stands in for your own model constructor
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = torch.randn(16, 512, 1024, device="cuda")

with profile(
    activities=[ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    output = model(data)
    loss = output.sum()
    loss.backward()
    optimizer.step()

print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=20))

Compare peak memory with and without checkpointing. You can also use torch.cuda.max_memory_allocated() for a quick check:

torch.cuda.reset_peak_memory_stats()
# ... run training step ...
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
print(f"Peak GPU memory: {peak_mb:.0f} MB")
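For a direct with/without comparison, the reset-and-measure pattern can be wrapped in a small helper. `MLPStack` and `peak_step_memory_mb` are hypothetical names for this sketch; the helper needs a CUDA device to report numbers and returns None otherwise.

```python
import torch
from torch.utils.checkpoint import checkpoint


class MLPStack(torch.nn.Module):
    """Hypothetical test model: residual MLP blocks, optionally checkpointed."""

    def __init__(self, width=1024, depth=8, use_checkpoint=False):
        super().__init__()
        self.blocks = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(width, 4 * width),
                torch.nn.GELU(),
                torch.nn.Linear(4 * width, width),
            )
            for _ in range(depth)
        ])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x


def peak_step_memory_mb(model, batch):
    """Peak CUDA memory (MB) of one forward+backward pass, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    model(batch).float().pow(2).mean().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2


if torch.cuda.is_available():
    batch = torch.randn(16, 512, 1024, device="cuda")
    for flag in (False, True):
        torch.manual_seed(0)
        model = MLPStack(use_checkpoint=flag).cuda()
        print(f"checkpoint={flag}: {peak_step_memory_mb(model, batch):.0f} MB")
```

Seeding before each construction gives both variants identical weights, so any difference in the printed peaks comes from checkpointing rather than the model.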

Custom Checkpointing Strategies

Not all layers are equal. Attention layers that materialize their score matrices store O(B × heads × S²) activations, while feedforward layers store O(B × S × H). A smart strategy spends checkpoints where they buy the most memory; the simplest version checkpoints every k-th layer:

class SelectiveCheckpointTransformer(torch.nn.Module):
    def __init__(self, layers, checkpoint_every=2):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if self.training and i % self.checkpoint_every == 0:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

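Checkpointing every k-th layer recomputes only about 1/k of the forward pass, so it also saves only about 1/k of what full checkpointing saves. Under the common rule of thumb that backward costs about twice the forward pass, one extra forward pass adds up to about a third more compute for full checkpointing, about a sixth when checkpointing every 2nd layer, and about a ninth for every 3rd layer. The knob trades memory for speed roughly linearly rather than capturing most of the savings for free.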
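Checkpointing attention-heavy parts more aggressively can also be done inside a block. The sketch below is a hypothetical variant of CheckpointedTransformerBlock (`AttnCheckpointBlock` is an illustrative name) that checkpoints only the attention sublayer, where the S × S scores live, and lets the feed-forward activations be stored normally.

```python
import torch
from torch.utils.checkpoint import checkpoint


class AttnCheckpointBlock(torch.nn.Module):
    """Hypothetical block: checkpoint the attention sublayer only."""

    def __init__(self, d_model, nhead, dim_ff):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_ff),
            torch.nn.GELU(),
            torch.nn.Linear(dim_ff, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def _attn_sublayer(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.norm1(x + attn_out)

    def forward(self, x):
        if self.training:
            # Attention activations are discarded and recomputed in backward
            x = checkpoint(self._attn_sublayer, x, use_reentrant=False)
        else:
            x = self._attn_sublayer(x)
        # Feed-forward activations are stored as usual
        return self.norm2(x + self.ff(x))


torch.manual_seed(0)
blk = AttnCheckpointBlock(d_model=16, nhead=4, dim_ff=32).train()
x = torch.randn(2, 5, 16, requires_grad=True)
blk(x).pow(2).sum().backward()
```

Whether this beats whole-block checkpointing depends on how much of each block's activation memory sits in attention versus the feed-forward layers, so it is worth profiling both.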

Interaction with Mixed Precision

Gradient checkpointing combines naturally with torch.amp:

scaler = torch.amp.GradScaler("cuda")

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):
        output = checkpointed_model(inputs)
        loss = criterion(output, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

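When autocast is active during the original forward pass, the recomputed forward pass inside backward also runs under autocast, because checkpoint records the autocast state and restores it for recomputation. Outputs of autocast-eligible ops such as matmuls and convolutions are stored in FP16, so the savings stack: FP16 halves the per-element cost of those activations, and checkpointing reduces how many are kept.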
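The autocast behavior can be checked without a GPU. This sketch substitutes CPU autocast with bfloat16 for the CUDA float16 setup above (an assumption made so it runs anywhere) and records the dtype the segment sees each time it executes; `seen_dtypes` is an illustrative name.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
lin = torch.nn.Linear(8, 8)
seen_dtypes = []  # dtype observed on each execution of the segment


def segment(x):
    h = torch.relu(lin(x))       # nn.Linear is autocast-eligible
    seen_dtypes.append(h.dtype)
    return h * h                 # mul saves h, so recomputation must reach here


x = torch.randn(4, 8, requires_grad=True)
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = checkpoint(segment, x, use_reentrant=False)
out.float().sum().backward()     # backward outside autocast, as usual

print(seen_dtypes)   # original forward and recomputation, both bfloat16
print(x.grad.dtype)  # gradients for the FP32 input stay FP32
```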

Gotchas and Debugging

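In-place operations need care. If a checkpointed function modifies one of its own inputs in place (for example, a ReLU(inplace=True) applied directly to the segment input), the recomputation in backward starts from the modified values instead of the originals and produces different activations. The simplest rule is to use inplace=False inside checkpointed regions.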

Dropout behavior. Random operations like dropout must produce the same values during recomputation. Both checkpoint implementations handle this by default: with preserve_rng_state=True, checkpoint stashes the CPU and CUDA RNG state before the forward pass and restores it before recomputing, so dropout draws the same mask. Passing preserve_rng_state=False skips that bookkeeping and is only safe when the checkpointed region has no randomness.
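A small experiment shows what RNG preservation buys. It compares gradients of a dropout segment with no checkpointing, with the default `preserve_rng_state=True`, and with `preserve_rng_state=False` (a real `checkpoint` argument); `segment` and `grad_of` are illustrative names.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def segment(t):
    # Dropout makes the segment random; sin saves its input for backward
    return F.dropout(torch.sin(t), p=0.5, training=True)


def grad_of(x, run):
    x.grad = None
    run(x).pow(2).sum().backward()
    return x.grad.clone()


x = torch.randn(8, 16, requires_grad=True)

torch.manual_seed(123)
reference = grad_of(x, segment)  # no checkpointing

torch.manual_seed(123)
preserved = grad_of(x, lambda t: checkpoint(segment, t, use_reentrant=False))

torch.manual_seed(123)
not_preserved = grad_of(
    x, lambda t: checkpoint(segment, t, use_reentrant=False, preserve_rng_state=False)
)

print(torch.allclose(reference, preserved))      # recompute reused the same mask
print(torch.allclose(reference, not_preserved))  # recompute drew a new mask
```

Without RNG preservation the backward pass uses a freshly drawn dropout mask that no longer matches the forward output, so the gradients are silently wrong rather than raising an error.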

BatchNorm running statistics. BatchNorm updates its running mean and variance on every training-mode forward call, so recomputation applies a second update and skews the statistics. Solutions: suppress the update during recomputation (for example, by temporarily setting momentum to 0 for the recomputation pass), switch to LayerNorm, or freeze BatchNorm (keep it in eval mode) during checkpointed training.

Debugging gradient issues. If gradients look wrong after adding checkpointing, use torch.autograd.gradcheck on a small model to verify numerical correctness:

from torch.autograd import gradcheck

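# SmallCheckpointedModel is a placeholder for a tiny version of your own model
small_model = SmallCheckpointedModel().double().cuda()  # gradcheck needs float64
x = torch.randn(2, 4, requires_grad=True, dtype=torch.float64, device="cuda")
# Checks gradients w.r.t. x only; parameter gradients are not compared
assert gradcheck(small_model, (x,), eps=1e-6, atol=1e-4)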
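A self-contained version that runs on CPU might look like this. `TinyCheckpointedModel` is a hypothetical stand-in for SmallCheckpointedModel; it checkpoints its first half with the non-reentrant variant, which supports the `torch.autograd.grad` calls gradcheck makes internally.

```python
import torch
from torch.autograd import gradcheck
from torch.utils.checkpoint import checkpoint


class TinyCheckpointedModel(torch.nn.Module):
    """Hypothetical stand-in: the first half is checkpointed."""

    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Tanh())
        self.head = torch.nn.Linear(8, 3)

    def forward(self, x):
        h = checkpoint(self.inner, x, use_reentrant=False)
        return self.head(h)


torch.manual_seed(0)
model = TinyCheckpointedModel().double()  # gradcheck needs float64 precision
x = torch.randn(2, 4, dtype=torch.float64, requires_grad=True)
# Returns True on success and raises with details on a mismatch
print(gradcheck(model, (x,), eps=1e-6, atol=1e-4))
```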

Real-World Impact

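Meta’s LLaMA training relied on a selective form of activation checkpointing (saving expensive activations such as linear-layer outputs and recomputing the rest), alongside model and sequence parallelism, to train models up to 65B parameters. In Hugging Face Transformers the technique is one flag away (gradient_checkpointing=True in TrainingArguments, or model.gradient_checkpointing_enable()), and Hugging Face reports that it reduces memory by 50–60% for typical fine-tuning workflows. Activations are only part of the budget, though: full fine-tuning of a 7B parameter model with AdamW in a standard mixed-precision setup needs roughly 16 bytes per parameter (about 112 GB) for weights, gradients, and optimizer states, so fitting it on a single A100-40GB also requires techniques such as parameter-efficient fine-tuning or optimizer-state quantization.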

The technique’s simplicity is its strength: no algorithmic changes, no approximations, and at most one knob (how many layers or segments to checkpoint). Toggle it on, accept the 20–30% slowdown, and get dramatically more headroom.

The one thing to remember: gradient checkpointing is one of the most effective memory optimizations for activation-heavy training. It is mathematically exact, easy to implement, and composes cleanly with mixed precision and distributed training.
