PyTorch Gradient Checkpointing — Deep Dive
The Activation Memory Bottleneck
During a standard forward pass, PyTorch’s autograd engine stores every intermediate tensor needed for the backward computation. For a Transformer with L layers, hidden dimension H, sequence length S, and batch size B, this activation memory scales as O(L × B × S × H), plus O(L × B × n_heads × S²) for the attention score matrices. A 24-layer Transformer with H=1024, S=512, B=32 in FP32 stores roughly 12 GB of activations alone, before counting parameters or optimizer states.
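For a sense of scale, the individual activation tensors for this configuration can be sized directly. This is a back-of-the-envelope sketch. The head count n_heads = 16 is an assumption, because the text does not give it:

```python
# Back-of-the-envelope sizes for L=24, H=1024, S=512, B=32 in FP32.
L, H, S, B = 24, 1024, 512, 32
n_heads = 16          # assumption: the head count is not given in the text
BYTES_FP32 = 4
MiB = 2**20

hidden_tensor = B * S * H * BYTES_FP32             # one B x S x H activation
score_tensor = B * n_heads * S * S * BYTES_FP32    # one layer's attention scores
budget_per_layer = 12 * 2**30 / L                  # 12 GiB spread over L layers

print(f"one B x S x H tensor: {hidden_tensor / MiB:.0f} MiB")     # 64 MiB
print(f"one score tensor:     {score_tensor / MiB:.0f} MiB")      # 512 MiB
print(f"12 GiB / {L} layers:  {budget_per_layer / MiB:.0f} MiB")  # 512 MiB
```

Eight hidden-size tensors per layer, or a single FP32 score tensor per layer, already reach the quoted total. A Transformer layer saves several tensors of both kinds: inputs to each linear layer, normalization and activation inputs, and the attention probabilities. The 12 GB figure is therefore best read as a conservative order of magnitude.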
Gradient checkpointing attacks this specific cost by selectively discarding activations and recomputing them during the backward pass.
PyTorch’s Checkpoint API
The core function is torch.utils.checkpoint.checkpoint:
```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(torch.nn.Module):
    def __init__(self, d_model, nhead, dim_ff):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_ff),
            torch.nn.GELU(),
            torch.nn.Linear(dim_ff, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def _forward_block(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x

    def forward(self, x):
        # Wrap the computation in checkpoint
        return checkpoint(self._forward_block, x, use_reentrant=False)
```
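Key details about checkpoint():
- use_reentrant=False selects the newer, non-reentrant implementation, which the PyTorch documentation recommends. Recent releases warn if use_reentrant is not passed explicitly. It works with torch.autograd.grad (the reentrant variant supports only backward() without an inputs argument) and with nested checkpointing.
- The wrapped function must not have side effects or use global mutable state, because it runs twice: once in the forward pass and again during backward recomputation.
- Non-tensor arguments are allowed and are passed through unchanged when the function is replayed. The reentrant variant, however, ignores tensors nested inside lists, dicts, or custom objects for autograd purposes. It also needs at least one input with requires_grad=True, or the checkpointed region gets no gradients. The non-reentrant variant has neither restriction.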
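You can observe the run-twice rule with a deliberately side-effecting function that counts its own calls. This is a minimal sketch. region and calls are hypothetical names, and the global counter is exactly the kind of state the rule warns against:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # global mutable state: what checkpointed code should avoid

def region(x):
    calls["n"] += 1          # side effect that fires on every (re)execution
    y = torch.sin(x)         # intermediate activation, discarded after forward
    return torch.cos(y)      # cos saves y for backward, forcing a recompute

x = torch.randn(8, requires_grad=True)
out = checkpoint(region, x, use_reentrant=False)
print(calls["n"])  # 1 after the forward pass
out.sum().backward()
print(calls["n"])  # 2: backward re-ran region() to rebuild y
```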
Checkpoint Sequential for CNNs
For sequential models, checkpoint_sequential is more convenient:
```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

class DeepCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 20 identical conv blocks, split into segments in forward().
        self.features = torch.nn.Sequential(
            *[self._make_block(64, 64) for _ in range(20)]
        )
        self.classifier = torch.nn.Linear(64, 10)

    def _make_block(self, in_ch, out_ch):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            # BatchNorm in a checkpointed segment updates its running stats twice (see Gotchas).
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=False),  # the safe choice in checkpointed code
        )

    def forward(self, x):
        # Split into 4 segments; the first 3 are checkpointed, the last runs normally.
        x = checkpoint_sequential(self.features, segments=4, input=x,
                                  use_reentrant=False)
        # Global average pool over height and width, then classify.
        return self.classifier(x.mean(dim=[2, 3]))
```
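The segments parameter controls granularity. checkpoint_sequential stores only the tensor entering each checkpointed segment and runs the final segment without checkpointing. During backward, it rematerializes one segment at a time. Peak activation memory is therefore roughly the stored segment inputs plus one segment's worth of activations. Too few segments leaves each segment large, and too many piles up stored inputs, so the minimum sits near √N segments for N equal-sized blocks. More segments also means more recomputation, since a larger share of the blocks is checkpointed.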
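A toy cost model of that segment layout makes the trade-off concrete. This is a sketch under simplifying assumptions, not PyTorch's actual memory accounting. Each block's saved activations count as one unit, and each checkpointed segment keeps one unit for its input. The helper name is hypothetical:

```python
def toy_segment_cost(n_blocks: int, segments: int) -> tuple[int, int]:
    """Return (peak_units, recomputed_blocks) for a checkpoint_sequential-style split.

    Toy assumptions: each block's saved activations cost 1 unit, and each
    checkpointed segment keeps 1 unit for its input tensor.
    """
    k = n_blocks // segments            # blocks per checkpointed segment
    n_ckpt = segments - 1               # the final chunk is not checkpointed
    last = n_blocks - k * n_ckpt        # blocks in the final chunk
    # Peak: all checkpointed-segment inputs held, plus the larger of the final
    # chunk's activations or one checkpointed segment rematerialized in backward.
    peak = n_ckpt + max(last, k)
    recomputed = k * n_ckpt             # every checkpointed block runs forward twice
    return peak, recomputed


for s in (1, 2, 4, 5, 10, 20):
    peak, extra = toy_segment_cost(20, s)
    print(f"segments={s:2d}  peak={peak:2d} units  recomputed={extra:2d} blocks")
```

For the 20-block DeepCNN, peak memory in this model falls from 20 units (1 segment) to 8 units (4 or 5 segments), then climbs back to 20 units at 20 segments. Recomputation grows the whole time, from 0 to 19 blocks. That U-shape is why a segment count near √N is a sensible starting point.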
Memory Profiling
Measuring the actual impact requires careful profiling:
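```python
import torch
from torch.profiler import profile, ProfilerActivity

model = build_model().cuda()  # build_model() stands in for your own model constructor
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = torch.randn(16, 512, 1024, device="cuda")  # (batch, seq_len, hidden)

with profile(
    # CPU activity records the operators that memory events are attributed to.
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,   # track tensor allocations and frees
    record_shapes=True,
) as prof:
    output = model(data)
    loss = output.sum()
    loss.backward()
    optimizer.step()

print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=20))
```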
Compare peak memory with and without checkpointing. You can also use torch.cuda.max_memory_allocated() for a quick check:
```python
torch.cuda.reset_peak_memory_stats()
# ... run one training step here ...
peak_mib = torch.cuda.max_memory_allocated() / 1024**2
print(f"Peak GPU memory: {peak_mib:.0f} MiB")
```
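Sometimes you want to see what autograd itself keeps alive, independent of the CUDA caching allocator, or you have no GPU available. One rough proxy is to count the bytes handed to torch.autograd.graph.saved_tensors_hooks. The sketch relies on one assumption: when saved-tensor hooks are nested, the innermost hooks take precedence. If so, activations saved inside a non-reentrant checkpoint region never reach the outer counter. region and saved_bytes are hypothetical helpers:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks
from torch.utils.checkpoint import checkpoint


def region(x):
    # Three elementwise ops; each hands autograd a tensor shaped like x.
    return torch.cos(torch.sin(torch.exp(x)))


def saved_bytes(fn, x):
    """Sum the bytes of every tensor autograd packs for backward while fn(x) runs.

    Rough proxy: a tensor saved twice is counted twice, and tensors saved inside
    a non-reentrant checkpoint region go to the checkpoint's own (inner) hooks,
    so they are not seen here.
    """
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor unchanged; we only measure it

    with saved_tensors_hooks(pack, lambda t: t):
        out = fn(x)
    out.sum().backward()  # backward still works (and recomputes, if checkpointed)
    return total


x = torch.randn(1024, 1024, requires_grad=True)
plain = saved_bytes(region, x)
ckpt = saved_bytes(lambda t: checkpoint(region, t, use_reentrant=False), x)
print(f"saved for backward: plain {plain / 2**20:.1f} MiB, "
      f"checkpointed {ckpt / 2**20:.1f} MiB")
```

Expect the checkpointed figure to be much smaller. Whatever remains is the region's input, which checkpoint keeps for the replay. This complements the allocator-level peak measurement rather than replacing it.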
Custom Checkpointing Strategies
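Not all layers are equal. An attention layer that materializes its score matrix stores O(B × n_heads × S²) elements for it. Feedforward activations instead grow as O(B × S × H), scaled by the FFN expansion factor. That argues for checkpointing attention-heavy blocks most aggressively. The simplest selective strategy checkpoints only every k-th layer: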
```python
import torch
from torch.utils.checkpoint import checkpoint

class SelectiveCheckpointTransformer(torch.nn.Module):
    def __init__(self, layers, checkpoint_every=2):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            # During training, checkpoint layers 0, k, 2k, ...; run the rest normally.
            if self.training and i % self.checkpoint_every == 0:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x
```
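Each checkpointed layer drops its internal activations and re-runs its forward pass once during backward. Checkpointing a fraction f of identical layers therefore captures roughly a fraction f of the memory savings of checkpointing every layer, at roughly a fraction f of its recomputation cost. With checkpoint_every=2, expect about half of each. A forward pass is roughly a third of a training step's compute, so checkpointing every layer typically adds on the order of 30% to step time. Checkpointing only some layers lets you pick a point between that and no overhead.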
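You can also target attention inside a block rather than across layers: wrap only the attention sublayer in checkpoint and keep the feedforward activations. The sketch below assumes the same post-norm layout as CheckpointedTransformerBlock above. AttnCheckpointedBlock is a hypothetical name. need_weights=False is passed so the module does not also compute and return averaged attention weights:

```python
import torch
from torch.utils.checkpoint import checkpoint


class AttnCheckpointedBlock(torch.nn.Module):
    """Hypothetical block that checkpoints only its attention sublayer."""

    def __init__(self, d_model: int, nhead: int, dim_ff: int):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_ff),
            torch.nn.GELU(),
            torch.nn.Linear(dim_ff, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def _attn(self, x):
        # Everything attention saves for backward (including any S x S score
        # matrix) is created here, so it is freed after forward and recomputed.
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x):
        if self.training:
            attn_out = checkpoint(self._attn, x, use_reentrant=False)
        else:
            attn_out = self._attn(x)
        x = self.norm1(x + attn_out)
        # Feedforward activations are stored as usual (no recomputation).
        return self.norm2(x + self.ff(x))


block = AttnCheckpointedBlock(d_model=64, nhead=4, dim_ff=256)
x = torch.randn(2, 16, 64, requires_grad=True)
block(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 64])
```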
Interaction with Mixed Precision
Gradient checkpointing combines naturally with torch.amp:
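```python
import torch

# dataloader, checkpointed_model, criterion, and optimizer are assumed to exist.
scaler = torch.amp.GradScaler("cuda")

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):
        output = checkpointed_model(inputs)
        loss = criterion(output, targets)
    # Backward runs outside autocast; checkpoint restores the recorded
    # autocast state while it recomputes the forward pass.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```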
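When autocast is active during the original forward pass, checkpoint records the autocast state and re-enters it for the recomputation inside backward. The recomputed activations therefore match the originals, even though backward itself runs outside the autocast context. Activations from autocast-eligible ops such as matmuls and convolutions are stored in FP16. Ops that autocast keeps in FP32, such as softmax, still save FP32 tensors. The memory savings stack: FP16 halves the per-element cost, and checkpointing reduces the element count.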
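You can check the autocast interaction directly. The non-reentrant implementation compares recomputed tensors' shapes and dtypes against the originals by default, so a recomputation that ran without autocast would be caught. This small sketch swaps the CUDA float16 setup above for CPU autocast with bfloat16 so it runs without a GPU. mlp is a hypothetical stand-in model:

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical tiny MLP standing in for a checkpointed model.
mlp = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.GELU(), torch.nn.Linear(64, 32)
)
x = torch.randn(8, 32)


def run(use_checkpoint: bool):
    mlp.zero_grad()
    with torch.autocast("cpu", dtype=torch.bfloat16):
        if use_checkpoint:
            out = checkpoint(mlp, x, use_reentrant=False)
        else:
            out = mlp(x)
    # Backward runs outside autocast; checkpoint re-enters the recorded
    # autocast state while it recomputes the forward pass.
    out.float().sum().backward()
    return out.dtype, mlp[0].weight.grad.clone()


dtype_plain, g_plain = run(use_checkpoint=False)
dtype_ckpt, g_ckpt = run(use_checkpoint=True)
print(dtype_plain, dtype_ckpt)          # torch.bfloat16 torch.bfloat16
print(torch.allclose(g_plain, g_ckpt))  # True
```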
Gotchas and Debugging
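In-place operations need care. A checkpointed region is replayed from its inputs during backward. Suppose an in-place op (such as a layer built with inplace=True, like ReLU(inplace=True)) mutates a tensor the region takes as input. The recomputation then sees the modified values rather than the originals and can produce wrong activations and gradients. The safe default is inplace=False in and around checkpointed regions.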
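Dropout behavior. Random operations like dropout must produce the same values during recomputation. Both checkpoint implementations handle this by stashing the RNG state before the checkpointed forward and restoring it for the recomputation. This is controlled by preserve_rng_state, which defaults to True. Passing preserve_rng_state=False skips that bookkeeping, but dropout then draws different masks in the two passes and the gradients are silently wrong.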
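To verify the RNG handling, compare input gradients with and without checkpointing for a region that applies dropout. This is a minimal sketch. region and input_grad are hypothetical helpers. The preserve_rng_state=False run is included only to show the failure mode:

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def region(x):
    # sin() saves the dropout output, so backward needs the same mask again.
    return torch.sin(F.dropout(x, p=0.5, training=True))


def input_grad(fn, seed=0):
    torch.manual_seed(seed)                   # same input and same mask draw
    x = torch.randn(1000, requires_grad=True)
    fn(x).sum().backward()
    return x.grad


g_plain = input_grad(region)
g_ckpt = input_grad(lambda x: checkpoint(region, x, use_reentrant=False))
# Recomputation draws a fresh mask here, so the gradient no longer matches.
g_fresh = input_grad(
    lambda x: checkpoint(region, x, use_reentrant=False, preserve_rng_state=False)
)
print(torch.equal(g_plain, g_ckpt))  # True: the RNG state was restored
```

With the RNG state restored, the checkpointed gradient matches the plain one exactly. With preserve_rng_state=False, g_fresh generally differs, and nothing raises an error.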
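BatchNorm running statistics. In training mode, BatchNorm updates its running mean and variance on every forward call. Recomputation is a second forward call on the same batch, so each step applies the update twice (and increments num_batches_tracked twice), skewing the statistics. Solutions: keep BatchNorm layers outside checkpointed regions, switch to LayerNorm or GroupNorm, or freeze BatchNorm (put those layers in eval mode) during checkpointed training.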
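The double update shows up in BatchNorm's num_batches_tracked counter, which increments once per training-mode forward call. This is a minimal sketch with a hypothetical two-layer block:

```python
import torch
from torch.utils.checkpoint import checkpoint

bn = torch.nn.BatchNorm1d(4)
block = torch.nn.Sequential(bn, torch.nn.ReLU())  # hypothetical checkpointed block
x = torch.randn(8, 4, requires_grad=True)

block.train()
out = checkpoint(block, x, use_reentrant=False)
print(int(bn.num_batches_tracked))  # 1: the forward pass updated the stats
out.sum().backward()
print(int(bn.num_batches_tracked))  # 2: recomputation updated them again

# Frozen BatchNorm (eval mode) uses the running stats and stops updating them.
block.eval()
checkpoint(block, x, use_reentrant=False).sum().backward()
print(int(bn.num_batches_tracked))  # still 2
```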
Debugging gradient issues. If gradients look wrong after adding checkpointing, use torch.autograd.gradcheck on a small model to verify numerical correctness:
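```python
import torch
from torch.autograd import gradcheck

# SmallCheckpointedModel stands for a scaled-down copy of your own checkpointed model.
small_model = SmallCheckpointedModel().double().cuda()  # float64 for finite differences
x = torch.randn(2, 4, requires_grad=True, dtype=torch.float64, device="cuda")
assert gradcheck(small_model, (x,), eps=1e-6, atol=1e-4)
```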
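Here is a self-contained version of that check that runs on CPU. It uses a hypothetical TinyCheckpointed module in place of SmallCheckpointedModel. The module avoids dropout, whose random masks would break finite differences. The sketch also cross-checks gradients against the same weights run without checkpointing:

```python
import torch
from torch.autograd import gradcheck
from torch.utils.checkpoint import checkpoint


class TinyCheckpointed(torch.nn.Module):
    """Hypothetical stand-in for SmallCheckpointedModel (no dropout)."""

    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 4)
        )

    def forward(self, x):
        return checkpoint(self.inner, x, use_reentrant=False)


torch.manual_seed(0)
model = TinyCheckpointed().double()  # float64 keeps finite differences accurate
x = torch.randn(2, 4, dtype=torch.float64, requires_grad=True)
ok = gradcheck(model, (x,), eps=1e-6, atol=1e-4)  # raises if gradients disagree
print(ok)

# Cross-check: same weights with and without checkpointing.
x1 = x.detach().clone().requires_grad_(True)
x2 = x.detach().clone().requires_grad_(True)
model(x1).sum().backward()        # checkpointed path
model.inner(x2).sum().backward()  # plain path
print(torch.allclose(x1.grad, x2.grad))  # True
```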
Real-World Impact
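Meta’s LLaMA paper describes a selective form of checkpointing as part of training models up to 65B parameters: activations that are expensive to compute, such as linear-layer outputs, are saved and the rest are recomputed. Hugging Face Transformers exposes the technique as a single switch (gradient_checkpointing=True in TrainingArguments, or model.gradient_checkpointing_enable()). Hugging Face reports that it reduces memory by 50–60% for typical fine-tuning workflows. Checkpointing shrinks only activation memory, though. Fitting a 7B-parameter fine-tune onto a single A100-40GB also requires cutting parameter, gradient, and optimizer-state memory, for example with parameter-efficient fine-tuning.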
The technique’s simplicity is its strength: no algorithmic changes, no approximations, no hyperparameter tuning. Toggle it on, accept the 20–30% slowdown, and get dramatically more headroom.
The one thing to remember: gradient checkpointing is one of the most effective ways to cut activation memory during training. It is mathematically exact, easy to implement, and composes cleanly with mixed precision and distributed training.