PyTorch Distributed Training — Deep Dive
DistributedDataParallel (DDP) from Scratch
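DDP is the standard approach to data-parallel multi-GPU training in PyTorch. Each process owns one GPU and runs a full model replica. During each backward pass, DDP all-reduces (averages) gradients across processes and overlaps that communication with gradient computation, so every replica applies the same update.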
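This is a toy, pure-Python simulation of why gradient averaging keeps replicas in lockstep. It uses no PyTorch, and the helpers `local_gradient` and `all_reduce_mean` are hypothetical stand-ins. With equal-sized shards, the all-reduced average of the per-rank gradients equals the gradient over the combined batch. Every replica therefore applies the update a single GPU would apply if it processed that whole batch.

```python
# Toy model: y_hat = w * x with mean-squared-error loss and one float parameter.
def local_gradient(w, batch):
    """Gradient of the mean squared error over one rank's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Simulated all-reduce with averaging: every rank receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2),
        (5.0, 9.8), (6.0, 12.1), (7.0, 14.0), (8.0, 16.3)]
world_size, w = 4, 1.5

# Each simulated rank gets an equal-sized, disjoint slice of the data
shards = [data[rank::world_size] for rank in range(world_size)]
synced = all_reduce_mean([local_gradient(w, shard) for shard in shards])

full_batch = local_gradient(w, data)
print(synced[0], full_batch)  # equal up to floating-point rounding
```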
```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


def setup(rank, world_size):
    # Defaults for a single-node launch (e.g. torch.multiprocessing.spawn);
    # setdefault leaves any values exported by torchrun untouched.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Bind this process to its GPU before creating the NCCL process group.
    # On a single node, the global rank equals the local GPU index.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size, dataset, model_cls, epochs=10):
    setup(rank, world_size)

    model = model_cls().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Each rank iterates over its own 1/world_size slice of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                        num_workers=4, pin_memory=True)

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Critical: ensures a different shuffle per epoch
        ddp_model.train()
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()  # DDP all-reduces (averages) gradients during backward
            optimizer.step()
        if rank == 0:
            print(f"Epoch {epoch}: complete")

    cleanup()
```
Launch with torchrun:

```bash
torchrun --nproc_per_node=4 train_script.py
```
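torchrun starts one process per GPU and sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT in each process's environment. A script launched with torchrun should read these values from the environment instead of hardcoding them. On multi-node jobs, it should select its GPU with LOCAL_RANK, not RANK.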
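Here is a minimal sketch of reading those variables. The helper name `read_torchrun_env` is hypothetical. It accepts a plain dict so you can try it without torchrun, and the example values assume two 8-GPU nodes.

```python
import os


def read_torchrun_env(env=None):
    """Collect the per-process settings torchrun exports.

    `env` defaults to os.environ; pass a dict to try it without torchrun.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env["RANK"]),              # global rank across all nodes
        "local_rank": int(env["LOCAL_RANK"]),  # GPU index on this node
        "world_size": int(env["WORLD_SIZE"]),  # total number of processes
        "master_addr": env["MASTER_ADDR"],
        "master_port": int(env["MASTER_PORT"]),
    }


# Values torchrun might set for GPU 1 on the second of two 8-GPU nodes
example = {"RANK": "9", "LOCAL_RANK": "1", "WORLD_SIZE": "16",
           "MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}
cfg = read_torchrun_env(example)
print(cfg["rank"], cfg["local_rank"], cfg["world_size"])  # 9 1 16
```

In a real script, you would pass `cfg["local_rank"]` to `torch.cuda.set_device()`. You would then call `dist.init_process_group("nccl")` without `rank` or `world_size`, because the default `env://` initialization reads the same variables.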
The DistributedSampler Contract
DistributedSampler is essential and easy to misuse. Key rules:
```python
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                             shuffle=True, drop_last=True)
```
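- Call `sampler.set_epoch(epoch)` at the start of every epoch. The shuffle is seeded with `seed + epoch`, so without this call every epoch reuses the same order, which degrades convergence.
- Use `drop_last=True` to drop the tail of the dataset so it divides evenly across ranks. With the default `drop_last=False`, the sampler instead pads with repeated samples. Either way, every rank gets the same number of samples and runs the same number of iterations, so the per-step all-reduces stay matched.
- Don't pass `shuffle=True` to the `DataLoader` when using `DistributedSampler`. The sampler handles shuffling, and `DataLoader` rejects a `sampler` combined with `shuffle=True`.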
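The rules above follow from how the sampler partitions indices. Below is a simplified, pure-Python mimic of that logic, using the hypothetical helper `distributed_indices`. It shuffles with a shared `seed + epoch`, pads or truncates to a multiple of `num_replicas`, then gives rank r every R-th index starting at r. It uses Python's `random` module in place of `torch.randperm`, so the actual orders differ from PyTorch's.

```python
import math
import random


def distributed_indices(n, num_replicas, rank, epoch, seed=0,
                        shuffle=True, drop_last=False):
    """Indices that `rank` would visit in `epoch` (simplified mimic)."""
    indices = list(range(n))
    if shuffle:
        # Every rank seeds with seed + epoch, so all ranks agree on the
        # permutation, and a new epoch (via set_epoch) gives a new one.
        random.Random(seed + epoch).shuffle(indices)
    if drop_last:
        num_samples = n // num_replicas            # drop the uneven tail
    else:
        num_samples = math.ceil(n / num_replicas)  # pad instead
    total = num_samples * num_replicas
    if drop_last:
        indices = indices[:total]
    else:
        while len(indices) < total:  # repeat leading indices as padding
            indices += indices[: total - len(indices)]
    # Interleaved split: rank r takes positions r, r + R, r + 2R, ...
    return indices[rank:total:num_replicas]


padded = [distributed_indices(10, 4, r, epoch=0) for r in range(4)]
dropped = [distributed_indices(10, 4, r, epoch=0, drop_last=True) for r in range(4)]
print([len(p) for p in padded], [len(d) for d in dropped])  # [3, 3, 3, 3] [2, 2, 2, 2]
```

Both modes give every rank the same count. Padding covers all 10 samples with two duplicates, while `drop_last=True` keeps 8 distinct samples.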
FullyShardedDataParallel (FSDP)
FSDP shards model parameters, gradients, and optimizer states across GPUs. With N GPUs and the default FULL_SHARD strategy, each GPU persistently stores only 1/N of this state. During forward and backward, each wrapped unit's parameters are all-gathered just in time and freed again after use.
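The shard-then-gather cycle can be sketched in pure Python. This is roughly how FSDP treats a wrapped unit: its parameters are flattened, padded to divide evenly, and split so each rank keeps one slice. The helpers `shard_flat_param` and `all_gather_unpad` are hypothetical, and the concatenation only stands in for a real all-gather collective.

```python
def shard_flat_param(flat, world_size):
    """Pad a flattened parameter to a multiple of world_size and split it evenly."""
    pad = (-len(flat)) % world_size
    padded = list(flat) + [0.0] * pad
    shard_len = len(padded) // world_size
    shards = [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
    return shards, pad


def all_gather_unpad(shards, pad):
    """Rebuild the full parameter just in time (stand-in for an all-gather)."""
    full = [x for shard in shards for x in shard]
    return full[:len(full) - pad]


flat_weights = [float(i) for i in range(10)]  # 10 parameters in one wrapped unit
shards, pad = shard_flat_param(flat_weights, world_size=4)
print([len(s) for s in shards], pad)  # [3, 3, 3, 3] 2: each rank keeps 3 of 12 slots

full = all_gather_unpad(shards, pad)  # materialized for forward/backward...
assert full == flat_weights
del full                              # ...then freed, leaving only the local shard
```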
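```python
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes the process group is initialized (see setup() above), rank is this
# process's local GPU index, and TransformerBlock / MyTransformer are your own
# nn.Module classes.

# Define wrapping policy: shard at transformer block level
wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Mixed precision config: compute, gradient reduction, and buffers in bfloat16
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = MyTransformer().to(rank)
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=rank,
)

# Create the optimizer after wrapping so it holds FSDP's sharded parameters
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3)
```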
Sharding strategies control the memory-communication tradeoff:
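| Strategy | Memory Savings | Communication |
|---|---|---|
| FULL_SHARD | Maximum: parameters, gradients, and optimizer states are all sharded | Highest: parameters are all-gathered before forward and again before backward |
| SHARD_GRAD_OP | Moderate: gradients and optimizer states are sharded; parameters stay unsharded from forward through backward | Lower: no second all-gather in backward |
| HYBRID_SHARD | FULL_SHARD within each node, replicated across nodes | Balanced for multi-node: parameter gathers stay on fast intra-node links |
| NO_SHARD | None (equivalent to DDP) | Minimum: one gradient all-reduce per step |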
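To put rough numbers on the table, here is a back-of-the-envelope estimator, the hypothetical `per_gpu_state_bytes`. It assumes fp32 parameters and gradients plus AdamW's two fp32 moment buffers (16 bytes per parameter). It ignores activations, temporary all-gather buffers, and mixed-precision copies, so treat its output as a lower bound on peak memory during a step, not a measurement.

```python
def per_gpu_state_bytes(num_params, strategy, world_size, gpus_per_node,
                        param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Rough per-GPU bytes for parameters, gradients, and optimizer state."""
    p = num_params * param_bytes
    g = num_params * grad_bytes
    o = num_params * optim_bytes
    if strategy == "FULL_SHARD":
        return (p + g + o) / world_size
    if strategy == "SHARD_GRAD_OP":
        # Parameters are held unsharded from forward through backward
        return p + (g + o) / world_size
    if strategy == "HYBRID_SHARD":
        # FULL_SHARD inside each node; every node holds a full replica
        return (p + g + o) / gpus_per_node
    if strategy == "NO_SHARD":
        return p + g + o
    raise ValueError(f"unknown strategy: {strategy}")


# 1B parameters on 2 nodes x 8 GPUs
for s in ["FULL_SHARD", "SHARD_GRAD_OP", "HYBRID_SHARD", "NO_SHARD"]:
    gb = per_gpu_state_bytes(1_000_000_000, s, world_size=16, gpus_per_node=8) / 1e9
    print(f"{s:14s} {gb:5.2f} GB")
```

For this configuration the estimates are 1.00, 4.75, 2.00, and 16.00 GB per GPU, in table order.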
Multi-Node Training
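For training across multiple machines, run torchrun on every node with the same --nnodes, --nproc_per_node, --master_addr, and --master_port, plus a distinct --node_rank. The master address must be reachable from every node: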
```bash
# Node 0 (master)
torchrun --nnodes=2 --nproc_per_node=8 \
  --node_rank=0 --master_addr=10.0.0.1 --master_port=29500 \
  train_script.py

# Node 1
torchrun --nnodes=2 --nproc_per_node=8 \
  --node_rank=1 --master_addr=10.0.0.1 --master_port=29500 \
  train_script.py
```
Network bandwidth becomes critical because gradient all-reduces now cross the network between nodes. NCCL uses GPUDirect RDMA when available (e.g. over InfiniBand) and falls back to TCP sockets otherwise. On cloud platforms, choose GPU instance types with high-bandwidth interconnects (for example, AWS EFA or GCP GPUDirect networking).
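A standard back-of-the-envelope model shows why. In a ring all-reduce, each rank sends and receives about 2(N−1)/N times the gradient payload, so the slowest link in the ring sets the pace. The sketch below, using the hypothetical `ring_allreduce_seconds`, computes only that bandwidth term. It ignores latency, NCCL's choice of algorithm, faster intra-node links, and overlap with the backward pass, and the link speeds are illustrative.

```python
def ring_allreduce_seconds(payload_bytes, world_size, link_gbit_per_s):
    """Bandwidth-only estimate for a ring all-reduce over identical links."""
    # Each rank sends (and receives) 2 * (N - 1) / N times the payload
    bytes_per_rank = 2 * (world_size - 1) / world_size * payload_bytes
    link_bytes_per_s = link_gbit_per_s * 1e9 / 8  # Gbit/s -> bytes/s
    return bytes_per_rank / link_bytes_per_s


grads = 1_000_000_000 * 2  # 1B parameters with bf16 gradients = 2 GB
for gbit in (100, 400):
    print(gbit, "Gbit/s:", round(ring_allreduce_seconds(grads, 16, gbit), 3), "s")
```

Across 16 GPUs, that is roughly 0.3 s per step at 100 Gbit/s versus 0.075 s at 400 Gbit/s, and any portion not hidden behind backward adds directly to step time.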
Gradient Accumulation with DDP
Simulate larger batch sizes without more GPUs by accumulating gradients:
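```python
from contextlib import nullcontext

accumulation_steps = 4

# Assumes len(loader) is a multiple of accumulation_steps; otherwise leftover
# gradients carry over into the next window.
for batch_idx, (data, target) in enumerate(loader):
    data, target = data.to(rank), target.to(rank)
    # Disable gradient sync on all but the last micro-batch of each window
    is_sync_step = (batch_idx + 1) % accumulation_steps == 0
    context = nullcontext() if is_sync_step else ddp_model.no_sync()
    with context:
        output = ddp_model(data)
        # Scale so the accumulated gradient is the average over the window
        loss = criterion(output, target) / accumulation_steps
        loss.backward()  # all-reduce fires only outside no_sync()
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()
```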
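The no_sync() context manager skips the gradient all-reduce during accumulation steps, so gradients are synchronized only once per optimizer step. This cuts the number of all-reduces by a factor of accumulation_steps. The effective global batch size becomes the per-GPU batch size × accumulation_steps × world size.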
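The loss scaling can be checked with a toy, pure-Python model (hypothetical helpers `grad` and `accumulate`, one float weight, MSE loss). Summing micro-batch gradients divided by `accumulation_steps` reproduces the gradient of one large batch, and a sync happens only once per window.

```python
def grad(w, batch):
    """Gradient of mean squared error for y_hat = w * x over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulate(w, micro_batches, accumulation_steps):
    """Return the gradient used at each optimizer step and how many syncs ran."""
    accumulated, step_grads, syncs = 0.0, [], 0
    for i, batch in enumerate(micro_batches):
        # Dividing by accumulation_steps mirrors `loss / accumulation_steps`
        accumulated += grad(w, batch) / accumulation_steps
        if (i + 1) % accumulation_steps == 0:
            syncs += 1  # the only point where DDP would all-reduce
            step_grads.append(accumulated)
            accumulated = 0.0
    return step_grads, syncs


data = [(float(x), 2.0 * x + 0.1) for x in range(1, 17)]
micro_batches = [data[i:i + 2] for i in range(0, 16, 2)]  # 8 micro-batches of 2
step_grads, syncs = accumulate(1.5, micro_batches, accumulation_steps=4)
print(syncs)  # 2 syncs instead of 8
```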
Checkpointing and Fault Tolerance
In long-running distributed training, failures are expected. Save checkpoints regularly so a job can resume from the most recent one instead of starting over:
```python
import torch
import torch.distributed as dist


def save_checkpoint(model, optimizer, epoch, path):
    if dist.get_rank() == 0:  # Only rank 0 saves; DDP replicas are identical
        state = {
            "model": model.module.state_dict(),  # .module unwraps DDP
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }
        torch.save(state, path)
    dist.barrier()  # Other ranks wait until rank 0 has finished writing


def load_checkpoint(model, optimizer, path, device):
    # Remap tensors saved from rank 0's GPU onto this process's device,
    # e.g. torch.device("cuda", local_rank). Using the global rank here would
    # name a nonexistent GPU on multi-node jobs.
    state = torch.load(path, map_location=device)
    model.module.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"]
```
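A crash while a checkpoint is being written can leave a truncated file where the last good checkpoint used to be. One common safeguard is to write to a temporary file in the same directory and then swap it in with `os.replace`, which is atomic on POSIX filesystems. This is a sketch with a hypothetical `atomic_save` helper that uses `pickle` in place of `torch.save` so it runs without PyTorch.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path):
    """Write obj to path so readers see either the old file or the new one."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file must be on the same filesystem for os.replace to be atomic
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before the swap
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # never leave a half-written file behind
        raise


with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "checkpoint.pkl")
    atomic_save({"epoch": 3}, ckpt)
    atomic_save({"epoch": 4}, ckpt)  # replaces the previous checkpoint in one step
    with open(ckpt, "rb") as f:
        loaded = pickle.load(f)
    leftover = sorted(os.listdir(d))
print(loaded, leftover)  # {'epoch': 4} ['checkpoint.pkl']
```

With PyTorch, `pickle.dump(obj, f)` could become `torch.save(obj, f)`, since `torch.save` accepts an open file object. The rank 0 guard and the barrier from `save_checkpoint` would stay as they are.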
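For FSDP, choose a state-dict type explicitly. StateDictType.FULL_STATE_DICT gathers the full, unsharded model (convenient for a single file). SHARDED_STATE_DICT keeps each rank's shard and avoids the gather: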
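```python
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather the full (unsharded) weights onto rank 0 only, offloaded to CPU
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = fsdp_model.state_dict()  # collective: every rank must call it
if rank == 0:
    torch.save(state_dict, "checkpoint.pt")
```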
Profiling Distributed Performance
Identify communication bottlenecks:
```python
import torch

# Skip 1 step, warm up for 1, then record 3: 5 steps in total
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./dist_profile"),
) as prof:
    for step, (data, target) in enumerate(loader):
        if step >= 5:  # wait + warmup + active
            break
        output = ddp_model(data.to(rank))
        loss = criterion(output, target.to(rank))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()  # advance the profiler schedule
```
Look for three signs: NCCL all-reduce kernels taking more than 30% of step time (a communication bottleneck), uneven GPU utilization across ranks (load imbalance), and long idle gaps on the GPU timeline (stalls, such as ranks waiting on data loading or on a slower peer).
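The 30% check can be applied mechanically once you have per-kernel durations for one step, however you extract them from the trace. The sketch below uses a hypothetical `communication_share` helper, treats any kernel whose name contains "nccl" as communication (a heuristic), and runs on made-up numbers. DDP overlaps all-reduce with backward, so raw NCCL kernel time can overstate the communication that actually delays the step.

```python
def communication_share(kernels, step_time_us, threshold=0.30):
    """Share of one step spent in NCCL kernels, and whether it exceeds threshold.

    `kernels` is a list of (kernel_name, duration_us) pairs for that step.
    """
    comm_us = sum(d for name, d in kernels if "nccl" in name.lower())
    share = comm_us / step_time_us
    return share, share > threshold


# Made-up kernel names and durations for illustration, not a real measurement
step = [("ncclDevKernel_AllReduce_Sum_bf16_RING_LL", 40.0),
        ("gemm_bf16_kernel", 45.0),
        ("elementwise_kernel", 10.0)]
print(communication_share(step, step_time_us=100.0))  # (0.4, True)
```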
Common Pitfalls
Deadlocks from conditional execution. If only some ranks execute a collective operation (like all-reduce), the others hang forever. Ensure all ranks follow identical control flow for distributed operations.
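The failure mode can be reproduced without GPUs. In this analogy, a `threading.Barrier` stands in for a collective because it completes only when every participant reaches it. The hypothetical `run_ranks` helper runs one thread per simulated rank and adds a short timeout so the demo finishes instead of hanging.

```python
import threading


def run_ranks(world_size, calls_collective, timeout=0.2):
    """Simulate ranks entering a collective; a Barrier stands in for it."""
    barrier = threading.Barrier(world_size)  # releases only when all ranks arrive
    results = [None] * world_size

    def worker(rank):
        if not calls_collective(rank):
            results[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)
            results[rank] = "completed"
        except threading.BrokenBarrierError:
            # The demo gives up after `timeout`; a real collective keeps blocking
            results[rank] = "hung"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Buggy: only rank 0 enters the collective (e.g. it sits inside `if rank == 0:`)
print(run_ranks(2, lambda rank: rank == 0))  # ['hung', 'skipped']
# Fixed: every rank enters it; only side effects such as logging are rank-gated
print(run_ranks(2, lambda rank: True))       # ['completed', 'completed']
```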
Forgetting sampler.set_epoch(). This silently degrades convergence: every epoch sees the same data order, so training loses the per-epoch reshuffling it relies on.
Mixed requires_grad across ranks. If one rank freezes a parameter that another doesn't, DDP builds different gradient buckets on each rank, so the all-reduces no longer line up and training either errors out or hangs. Freeze the same parameters on every rank, and do it before wrapping the model in DDP.
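One way to catch this early is to compute a per-rank signature of parameter names, shapes, and `requires_grad` flags and check that all ranks agree. This is a minimal pure-Python sketch. The helper `trainable_signature` is hypothetical, and it takes plain `(name, shape, requires_grad)` tuples so it runs without PyTorch.

```python
import hashlib


def trainable_signature(params):
    """Order-independent hash of (name, shape, requires_grad) for each parameter."""
    h = hashlib.sha256()
    for name, shape, requires_grad in sorted(params, key=lambda p: p[0]):
        h.update(f"{name}|{tuple(shape)}|{requires_grad}\n".encode())
    return h.hexdigest()


rank0 = [("encoder.weight", (128, 64), False), ("head.weight", (10, 128), True)]
rank1 = [("head.weight", (10, 128), True), ("encoder.weight", (128, 64), False)]
rank2 = [("encoder.weight", (128, 64), True), ("head.weight", (10, 128), True)]

print(trainable_signature(rank0) == trainable_signature(rank1))  # True
print(trainable_signature(rank0) == trainable_signature(rank2))  # False
```

With PyTorch, you could build the tuples as `[(n, tuple(p.shape), p.requires_grad) for n, p in model.named_parameters()]`. You would then collect each rank's signature with `dist.all_gather_object` and assert that they all match before constructing DDP.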
The one thing to remember: Distributed training is reliable when you treat it as a systems engineering problem — deterministic samplers, fault-tolerant checkpointing, and communication-aware batching matter as much as the model architecture itself.