PyTorch Distributed Training — Deep Dive

DistributedDataParallel (DDP) from Scratch

DDP is the standard approach to multi-GPU data-parallel training. Each process owns one GPU, runs a full model replica, and averages gradients across processes with all-reduce during the backward pass, overlapping communication with computation.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

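def setup(rank, world_size):
    # Defaults for a manual single-node launch. torchrun already exports
    # these, and setdefault leaves its values untouched.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # single node: global rank == local GPU index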

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, dataset, model_cls, epochs=10):
    setup(rank, world_size)

    model = model_cls().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                        num_workers=4, pin_memory=True)

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Critical: ensures different shuffle per epoch
        ddp_model.train()

        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()  # Gradients are all-reduced automatically
            optimizer.step()

        if rank == 0:
            print(f"Epoch {epoch}: complete")

    cleanup()
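
To make the "all-reduced automatically" comment concrete, here is a minimal pure-Python sketch of what the gradient all-reduce computes: every replica ends up holding the element-wise mean of all replicas' gradients. The name all_reduce_mean and the plain lists standing in for gradient tensors are illustrative assumptions; real DDP does this with bucketed NCCL collectives rather than Python loops.

```python
def all_reduce_mean(grads_per_rank):
    """Simulate DDP's gradient averaging across replicas.

    grads_per_rank: one list of floats per rank (a hypothetical stand-in
    for each replica's flattened gradient).
    Returns the list every rank would hold after the all-reduce.
    """
    world_size = len(grads_per_rank)
    # Element-wise sum across ranks, then divide by the number of ranks.
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    return [s / world_size for s in summed]


# Two ranks computed different gradients on their own mini-batches.
local_grads = [[1.0, 4.0], [3.0, 0.0]]
synced = all_reduce_mean(local_grads)
print(synced)  # [2.0, 2.0]
```

Because every rank applies the same averaged gradient, the replicas stay identical after each optimizer step.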

Launch with torchrun:

torchrun --nproc_per_node=4 train_script.py

torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for every worker process. Your script should read them from the environment rather than hardcoding them, and should use LOCAL_RANK (not RANK) to select the GPU.
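
Reading those variables can be sketched as below. The DistEnv dataclass and read_dist_env helper are hypothetical names introduced for illustration. The single-process fallbacks (rank 0, world size 1) are an assumption that makes the same script runnable without torchrun.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes


def read_dist_env(env=None):
    """Read the variables torchrun exports; fall back to a single process."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


# Example: what one worker of an 8-process job might see.
cfg = read_dist_env({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"})
print(cfg)
```

In a training script, cfg.local_rank would go to torch.cuda.set_device, and cfg.rank and cfg.world_size to the sampler. dist.init_process_group("nccl") can then be called without explicit rank and world_size, since its default env:// initialization reads the same variables.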

The DistributedSampler Contract

DistributedSampler is essential and easy to misuse. Key rules:

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                              shuffle=True, drop_last=True)
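  • Call sampler.set_epoch(epoch) at the start of every epoch. Without it, every epoch uses the same shuffle order, which can hurt convergence.
  • Use drop_last=True to drop the tail of the dataset when it does not divide evenly across ranks. With the default drop_last=False, the sampler pads with repeated samples instead. Either way every rank receives the same number of samples, which keeps the ranks in step for collective calls.
  • Don’t pass shuffle=True to the DataLoader when using DistributedSampler; the sampler handles shuffling.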
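
The rules above can be checked on a toy dataset without starting any processes, because passing num_replicas and rank explicitly means the sampler never touches a process group. This is a small sketch; the helper indices_for and the choice of range objects as datasets are illustrative assumptions.

```python
from torch.utils.data import DistributedSampler

dataset = range(10)  # the sampler only needs len(dataset)
world_size = 4


def indices_for(rank, drop_last):
    """Indices one rank would see, with shuffling off for readability."""
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False, drop_last=drop_last)
    return list(sampler)


padded = [indices_for(r, drop_last=False) for r in range(world_size)]
dropped = [indices_for(r, drop_last=True) for r in range(world_size)]
print(padded)   # 3 indices per rank; indices 0 and 1 are repeated as padding
print(dropped)  # 2 indices per rank; indices 8 and 9 are dropped

# With shuffle=True, the order changes only when set_epoch is called.
shuffled = DistributedSampler(range(100), num_replicas=world_size, rank=0,
                              shuffle=True)
epoch0_a, epoch0_b = list(shuffled), list(shuffled)
shuffled.set_epoch(1)
epoch1 = list(shuffled)
print(epoch0_a == epoch0_b, epoch0_a == epoch1)
```

The first comparison shows that iterating twice without set_epoch repeats the same order, which is the failure mode described in the first rule.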

FullyShardedDataParallel (FSDP)

FSDP shards model parameters, gradients, and optimizer states across GPUs. With N GPUs, each one holds roughly 1/N of that state between uses. During forward and backward, the parameters a layer needs are gathered just in time and released after use.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial

# Define wrapping policy — shard at transformer block level
wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Mixed precision config
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = MyTransformer().to(rank)
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=rank,
)

Sharding strategies control the memory-communication tradeoff:

Strategy       | Memory savings                                      | Communication
FULL_SHARD     | Maximum (parameters + gradients + optimizer state)  | Higher: parameters are all-gathered before each forward and backward
SHARD_GRAD_OP  | Moderate (gradients + optimizer state)              | Lower: parameters stay unsharded between forward and backward
HYBRID_SHARD   | Full shard within a node, replicas across nodes     | Balanced for multi-node
NO_SHARD       | None (similar to DDP)                               | Minimum
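
A back-of-the-envelope estimate makes the memory column more tangible. The sketch below is a rough model, not a library calculation. It assumes fp32 parameters and gradients (4 bytes each) plus two fp32 Adam moments (8 bytes) per parameter, and it ignores activations and temporary all-gather buffers. The function name model_state_gb is hypothetical.

```python
def model_state_gb(num_params, num_gpus, strategy):
    """Rough per-GPU memory for model state, in GB (1e9 bytes).

    Assumed sizes: 4 B parameter, 4 B gradient, 8 B Adam state per parameter.
    """
    param_b, grad_b, optim_b = 4, 4, 8
    if strategy == "NO_SHARD":
        per_param = param_b + grad_b + optim_b
    elif strategy == "SHARD_GRAD_OP":
        # Peak during forward/backward: full parameters, everything else sharded.
        per_param = param_b + (grad_b + optim_b) / num_gpus
    elif strategy == "FULL_SHARD":
        per_param = (param_b + grad_b + optim_b) / num_gpus
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return num_params * per_param / 1e9


# A 7-billion-parameter model on 8 GPUs.
for name in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    print(name, model_state_gb(7e9, 8, name))
```

Under these assumptions the three strategies come out at about 112 GB, 38.5 GB, and 14 GB per GPU, which is why FULL_SHARD is often the only option that fits for large models.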

Multi-Node Training

For training across multiple machines, the configuration requires a network-accessible master:

# Node 0 (master)
torchrun --nnodes=2 --nproc_per_node=8 \
    --node_rank=0 --master_addr=10.0.0.1 --master_port=29500 \
    train_script.py

# Node 1
torchrun --nnodes=2 --nproc_per_node=8 \
    --node_rank=1 --master_addr=10.0.0.1 --master_port=29500 \
    train_script.py

Network bandwidth becomes critical. NCCL uses GPUDirect RDMA when available (InfiniBand), falling back to TCP sockets. On cloud instances, use compute-optimized instances with high-bandwidth interconnects (AWS EFA, GCP GPUDirect).
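
For the two-node launch above, it helps to see how global ranks relate to the node and the local GPU. The sketch below assumes a static launch where every node runs the same number of workers; elastic rendezvous setups may assign node ranks differently. The helper global_rank is a hypothetical name.

```python
def global_rank(node_rank, local_rank, nproc_per_node):
    """Global rank when every node runs nproc_per_node workers."""
    return node_rank * nproc_per_node + local_rank


nnodes, nproc_per_node = 2, 8
world_size = nnodes * nproc_per_node

# Map (node, local GPU) to the global rank a worker would report.
layout = {
    (node, local): global_rank(node, local, nproc_per_node)
    for node in range(nnodes)
    for local in range(nproc_per_node)
}
print(world_size)      # 16
print(layout[(1, 0)])  # 8: first GPU on node 1
```

This is also why the GPU should be selected with the local rank: global rank 8 exists, but GPU index 8 does not on an 8-GPU node.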

Gradient Accumulation with DDP

Simulate larger batch sizes without more GPUs by accumulating gradients:

from contextlib import nullcontext

accumulation_steps = 4

optimizer.zero_grad()
for batch_idx, (data, target) in enumerate(loader):
    data, target = data.to(rank), target.to(rank)
    is_sync_step = (batch_idx + 1) % accumulation_steps == 0

    # Skip the all-reduce on accumulation steps; sync on the last micro-batch
    context = nullcontext() if is_sync_step else ddp_model.no_sync()

    with context:
        output = ddp_model(data)
        # Scale so the accumulated gradient matches one large-batch average
        loss = criterion(output, target) / accumulation_steps
        loss.backward()

    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()

The no_sync() context manager skips the all-reduce during accumulation steps, so gradients are synchronized once per accumulation_steps micro-batches instead of on every backward pass.
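
One detail the loop above leaves open is what happens when the number of batches is not a multiple of accumulation_steps: the last few micro-batches accumulate gradients that are never applied. A possible way to handle it, sketched here with a hypothetical helper is_sync_step, is to also treat the final batch as a sync step.

```python
def is_sync_step(batch_idx, num_batches, accumulation_steps):
    """True when this micro-batch should all-reduce and step the optimizer."""
    window_full = (batch_idx + 1) % accumulation_steps == 0
    last_batch = batch_idx + 1 == num_batches
    return window_full or last_batch


num_batches, accumulation_steps = 10, 4
sync_points = [i for i in range(num_batches)
               if is_sync_step(i, num_batches, accumulation_steps)]
print(sync_points)  # [3, 7, 9]
```

The final window here holds only two micro-batches, so dividing the loss by accumulation_steps under-weights it slightly. Dropping the remainder instead is the other common choice.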

Checkpointing and Fault Tolerance

In long-running distributed training, failures are expected. Save checkpoints that can restart from any point:

def save_checkpoint(model, optimizer, epoch, path):
    if dist.get_rank() == 0:  # Only rank 0 saves
        state = {
            "model": model.module.state_dict(),  # .module unwraps DDP
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }
        torch.save(state, path)
    dist.barrier()  # Wait for save to complete

def load_checkpoint(model, optimizer, path, device):
    # Map saved tensors onto this process's device (e.g. f"cuda:{local_rank}")
    # rather than onto the GPU they were saved from.
    state = torch.load(path, map_location=device)
    model.module.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"]
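
A crash in the middle of a save can leave a truncated file, which defeats the purpose of checkpointing. One common safeguard, sketched below under stated assumptions, is to write to a temporary file in the same directory and rename it over the target only once the write has finished. The helper atomic_save and its save_fn hook are hypothetical. pickle is used only to keep the sketch standalone; in a training script torch.save could be passed as save_fn.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path, save_fn=None):
    """Write to a temp file next to `path`, then rename it into place.

    save_fn(obj, file_path) is a hypothetical hook with the same argument
    order as torch.save; the pickle default is for illustration only.
    """
    if save_fn is None:
        def save_fn(o, p):
            with open(p, "wb") as f:
                pickle.dump(o, f)

    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        save_fn(obj, tmp_path)
        os.replace(tmp_path, path)  # atomic when both are on one filesystem
    except BaseException:
        # Never leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "checkpoint.pt")
    atomic_save({"epoch": 3}, ckpt)
    with open(ckpt, "rb") as f:
        restored = pickle.load(f)
    leftovers = [n for n in os.listdir(d) if n.endswith(".tmp")]
print(restored, leftovers)
```

With this pattern, a reader of the checkpoint path sees either the previous complete file or the new complete file, never a partial one.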

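For FSDP, choose a state-dict type: StateDictType.FULL_STATE_DICT gathers the whole model into one checkpoint, while SHARDED_STATE_DICT keeps each rank's shard separate:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather the full state dict on rank 0 only, offloaded to CPU, so that
# no GPU has to hold an unsharded copy of the whole model.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = fsdp_model.state_dict()  # collective: every rank must call it
    if rank == 0:
        torch.save(state_dict, "checkpoint.pt")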

Profiling Distributed Performance

Identify communication bottlenecks:

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./dist_profile"),
) as prof:
    for step, (data, target) in enumerate(loader):
        if step >= 5:
            break
        output = ddp_model(data.to(rank))
        loss = criterion(output, target.to(rank))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()

Look for: NCCL all-reduce taking more than roughly 30% of step time (a sign of a communication bottleneck), uneven GPU utilization across ranks (load imbalance), and long idle gaps on the GPU timeline (ranks waiting on a straggler or on the data loader).
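
The 30% rule of thumb can be applied mechanically once kernel timings are exported from the trace. The sketch below assumes a hypothetical list of (name, duration in microseconds) pairs, and the kernel names in it are made up for illustration. Summing durations ignores overlap between communication and compute, so treat the result as an upper bound on exposed communication time.

```python
def communication_fraction(events):
    """Share of total event time spent in kernels whose name mentions NCCL.

    events: iterable of (name, duration_us) pairs, a hypothetical export
    of kernel timings from a profiler trace.
    """
    events = list(events)
    total = sum(d for _, d in events)
    comm = sum(d for name, d in events if "nccl" in name.lower())
    return comm / total if total else 0.0


# Illustrative numbers only, not measurements.
trace = [
    ("aten::mm", 4200),
    ("ncclDevKernel_AllReduce_Sum_f32", 3000),
    ("aten::add_", 800),
    ("ncclDevKernel_AllReduce_Sum_f32", 2000),
]
frac = communication_fraction(trace)
print(f"{frac:.0%}", "communication-bound" if frac > 0.30 else "ok")
```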

Common Pitfalls

Deadlocks from conditional execution. If only some ranks execute a collective operation (like all-reduce), the others hang forever. Ensure all ranks follow identical control flow for distributed operations.
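
A typical source of this deadlock is a rank-local decision, such as one rank seeing a NaN loss and leaving the loop while the others wait in the next all-reduce. One way to avoid it, sketched here, is to have every rank agree on the decision through a collective before acting on it. The helper any_rank_flagged is a hypothetical name, and the fallback to the local value when no process group is initialized is an assumption that keeps the function usable in single-process runs.

```python
import torch
import torch.distributed as dist


def any_rank_flagged(local_flag, device="cpu"):
    """Return True on every rank if at least one rank raised the flag.

    Every rank must call this at the same point in the loop, so the
    all-reduce below is never executed by only a subset of ranks.
    """
    flag = torch.tensor([1 if local_flag else 0], dtype=torch.int32,
                        device=device)
    if dist.is_available() and dist.is_initialized():
        # MAX over 0/1 flags acts as a logical OR across ranks.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


# Single-process run: no process group, so the local value is returned.
print(any_rank_flagged(True), any_rank_flagged(False))
```

In a training loop this could look like `if any_rank_flagged(bool(torch.isnan(loss)), device=loss.device): break`, so that all ranks leave the loop on the same step. With the NCCL backend the flag tensor has to live on the GPU, which is why the device is a parameter.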

Forgetting sampler.set_epoch(). Silently degrades convergence — every epoch sees identical data order, effectively reducing dataset diversity.

Mixed requires_grad across ranks. If one rank freezes a parameter that another doesn’t, the sets of gradients being reduced no longer match and training can crash or hang. Ensure identical model architecture and frozen parameters across all ranks.
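
A cheap guard against this is to fingerprint the parameter list on each rank and compare fingerprints before training starts. The sketch below is one possible approach; trainable_signature is a hypothetical helper, and the SimpleNamespace objects are stand-ins for parameters so the example runs without a model.

```python
import hashlib
from types import SimpleNamespace


def trainable_signature(named_params):
    """Hash of (name, shape, requires_grad) for every parameter.

    Accepts model.named_parameters(); any object exposing .shape and
    .requires_grad works.
    """
    h = hashlib.sha256()
    for name, p in named_params:
        h.update(f"{name}|{tuple(p.shape)}|{bool(p.requires_grad)}\n".encode())
    return h.hexdigest()


def fake(shape, requires_grad):
    """Stand-in for a parameter, for illustration only."""
    return SimpleNamespace(shape=shape, requires_grad=requires_grad)


rank0 = [("encoder.weight", fake((4, 4), True)),
         ("head.weight", fake((2, 4), True))]
rank1 = [("encoder.weight", fake((4, 4), False)),  # frozen on this rank only
         ("head.weight", fake((2, 4), True))]
print(trainable_signature(rank0) == trainable_signature(rank1))  # False
```

In a real job, each rank could compute the signature before wrapping the model in DDP, share it with dist.all_gather_object, and raise an error if the values differ.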

The one thing to remember: Distributed training is reliable when you treat it as a systems engineering problem — deterministic samplers, fault-tolerant checkpointing, and communication-aware batching matter as much as the model architecture itself.
