PyTorch Distributed Training — Core Concepts

Why Single-GPU Training Hits a Wall

Modern deep learning models have billions of parameters. Training them means processing terabytes of data for hundreds of thousands of steps, and every training token costs roughly six floating-point operations per parameter (forward plus backward pass). A single NVIDIA A100 GPU peaks at about 312 TFLOPS in dense FP16. That is impressive, but at that rate a 70B-parameter model trained on a trillion tokens would take decades on one GPU. Distributed training is not optional for state-of-the-art AI; it’s required.
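The arithmetic behind that claim can be sketched with the common 6·N·D rule of thumb (about 6 FLOPs per parameter per training token). The 1-trillion-token dataset and the 40% utilization figure are illustrative assumptions, not numbers from a specific training run.

```python
SECONDS_PER_YEAR = 365 * 24 * 3600
A100_FP16_PEAK = 312e12  # dense FP16 tensor-core peak, FLOP/s


def single_gpu_training_years(params: float, tokens: float,
                              peak_flops: float, utilization: float = 1.0) -> float:
    """Estimate wall-clock years to train on ONE GPU with the 6*N*D rule of thumb."""
    total_flops = 6 * params * tokens        # forward + backward, ~6 FLOPs/param/token
    sustained = peak_flops * utilization     # real jobs rarely sustain peak
    return total_flops / sustained / SECONDS_PER_YEAR


# Assumed scenario: 70B parameters, 1T training tokens (illustrative only).
at_peak = single_gpu_training_years(70e9, 1e12, A100_FP16_PEAK)
realistic = single_gpu_training_years(70e9, 1e12, A100_FP16_PEAK, utilization=0.4)
print(f"{at_peak:.0f} years at peak, {realistic:.0f} years at 40% utilization")
```

Under these assumptions the estimate comes out at roughly 43 years even at peak throughput.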

Data Parallelism: The Workhorse

Data parallelism is the most common distributed strategy. The idea:

  1. Replicate the full model on every GPU
  2. Partition each training batch so each GPU processes a different slice
  3. Compute forward and backward passes independently on each GPU
  4. Synchronize gradients across all GPUs (via all-reduce)
  5. Update weights identically on every GPU

Because every GPU applies the same averaged gradients to the same starting weights, all replicas stay identical after each step. The effective batch size scales linearly: 8 GPUs with batch size 32 each = effective batch size 256.
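The five steps can be simulated in plain Python to check the key property: averaging per-GPU gradients over equal-sized slices gives exactly the full-batch gradient, so every replica applies the same update. The one-weight linear model, the toy data, and the function names are illustrative assumptions. Real frameworks do step 4 with an all-reduce, not a Python loop.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to a scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_step(w, xs, ys, num_gpus, lr):
    """One simulated data-parallel step; returns (weight on each replica, averaged grad)."""
    # 1. Replicate the model (here, a single weight) on every simulated GPU.
    replicas = [w] * num_gpus
    # 2. Partition the batch into equal, disjoint slices.
    shard = len(xs) // num_gpus
    slices = [(xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
              for i in range(num_gpus)]
    # 3. Each GPU computes a gradient on its own slice only.
    local_grads = [grad_mse(replicas[i], *slices[i]) for i in range(num_gpus)]
    # 4. Stand-in for all-reduce: every GPU receives the average gradient.
    avg_grad = sum(local_grads) / num_gpus
    # 5. Every replica applies the identical update.
    return [r - lr * avg_grad for r in replicas], avg_grad


# 8 GPUs x 32 samples per GPU = effective batch of 256.
xs = [float(i) for i in range(256)]
ys = [3.0 * x + 1.0 for x in xs]
weights, g = data_parallel_step(0.5, xs, ys, num_gpus=8, lr=1e-6)
print(len(set(weights)) == 1)                  # True: replicas stay identical
print(math.isclose(g, grad_mse(0.5, xs, ys)))  # True: matches the full-batch gradient
```

The equality relies on equal slice sizes. With uneven slices, a plain average of per-GPU means would weight samples unequally.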

Model Parallelism: When the Model Won’t Fit

Some models are too large for a single GPU’s memory. Model parallelism splits the model itself:

  • Tensor parallelism: Splits individual layers across GPUs. A large matrix multiplication is divided so each GPU computes part of the result
  • Pipeline parallelism: Assigns different layers to different GPUs. Data flows through GPU 1’s layers, then GPU 2’s layers, like an assembly line
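For tensor parallelism, the simplest case is a column split of one matrix multiplication. Each GPU holds a block of the weight's columns and computes its share of the output. The shares are then concatenated, which on real hardware is an all-gather. Below is a pure-Python sketch, with hypothetical helper names and nested lists standing in for GPU tensors:

```python
def matmul(a, b):
    """Plain matrix product of list-of-lists a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_parallel_matmul(x, w, num_gpus):
    cols = len(w[0])
    assert cols % num_gpus == 0, "columns must split evenly across GPUs"
    width = cols // num_gpus
    # Each simulated GPU stores only its block of W's columns.
    shards = [[row[g * width:(g + 1) * width] for row in w] for g in range(num_gpus)]
    # Each GPU computes x @ W_g independently; no communication is needed here.
    partial_outputs = [matmul(x, shard) for shard in shards]
    # Concatenate along the column axis, which is what an all-gather would assemble.
    return [sum((p[i] for p in partial_outputs), []) for i in range(len(x))]


x = [[1, 2], [3, 4]]
w = [[1, 0, 2, -1], [0, 1, 1, 3]]
print(column_parallel_matmul(x, w, num_gpus=2) == matmul(x, w))  # True
```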

Pipeline parallelism introduces “bubble time” — GPUs waiting for their input. Techniques like micro-batching (splitting each batch into smaller chunks) reduce this idle time.
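The effect of micro-batching can be estimated by simulating a simple GPipe-style forward schedule, in which stage s processes micro-batch m at time step s + m. The schedule model is a simplifying assumption: one time unit per stage per micro-batch, forward pass only. For this schedule, the idle fraction works out to (p - 1) / (m + p - 1) for p stages and m micro-batches.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of stage-time slots left idle in a GPipe-style forward pass."""
    # Time until the last micro-batch leaves the last stage.
    total_steps = num_stages + num_microbatches - 1
    busy = set()
    for stage in range(num_stages):
        for mb in range(num_microbatches):
            busy.add((stage, stage + mb))  # stage runs micro-batch mb at t = stage + mb
    total_slots = num_stages * total_steps
    return (total_slots - len(busy)) / total_slots


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:2d} micro-batches: {bubble_fraction(4, m):.0%} idle")
```

With 4 stages, going from 1 to 16 micro-batches cuts idle time from 75% to about 16%.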

Communication: The Hidden Cost

Distributed training’s bottleneck is often communication rather than computation, especially across machines. The main patterns:

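Operation | What It Does | When It’s Used
All-reduce | Every GPU contributes its gradients; every GPU receives the reduced result (a sum, which DDP divides by the number of GPUs to get the average) | Data-parallel gradient sync
All-gather | Every GPU contributes its piece; every GPU receives the complete set | Collecting outputs from all GPUs; reassembling sharded parameters in FSDP
Broadcast | One GPU sends data to all others | Distributing initial weights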

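For data parallelism, all-reduce is the critical operation. With N GPUs and M parameters, a ring all-reduce has each GPU send and receive about 2(N - 1)/N × M values. That is O(M) and nearly independent of N. The NCCL library (NVIDIA Collective Communication Library) implements all-reduce with ring- and tree-based algorithms. DDP launches these collectives during the backward pass so that communication overlaps with computation.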
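The ring pattern can be simulated in pure Python to see why per-GPU traffic stays O(M). Each of N ranks splits its M values into N chunks and runs N - 1 reduce-scatter steps followed by N - 1 all-gather steps, sending one chunk per step. That totals 2(N - 1)/N × M values per rank. This is a teaching sketch of the classic ring algorithm under the assumption that M divides evenly by N; it is not NCCL's actual implementation.

```python
def ring_all_reduce(data):
    """Sum-all-reduce over simulated ranks; returns (results, values_sent_per_rank)."""
    n, m = len(data), len(data[0])
    assert m % n == 0, "sketch assumes the vector splits evenly into n chunks"
    size = m // n
    # buffers[r][c] is rank r's current copy of chunk c.
    buffers = [[list(vec[c * size:(c + 1) * size]) for c in range(n)] for vec in data]
    sent = [0] * n

    # Phase 1, reduce-scatter: afterwards rank r owns the full sum of chunk (r + 1) % n.
    for step in range(n - 1):
        # Snapshot all messages first: every rank sends simultaneously within a step.
        msgs = []
        for r in range(n):
            c = (r - step) % n
            msgs.append(((r + 1) % n, c, list(buffers[r][c])))
            sent[r] += size
        for dst, c, payload in msgs:
            buffers[dst][c] = [a + b for a, b in zip(buffers[dst][c], payload)]

    # Phase 2, all-gather: circulate the fully reduced chunks around the ring.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r + 1 - step) % n
            msgs.append(((r + 1) % n, c, list(buffers[r][c])))
            sent[r] += size
        for dst, c, payload in msgs:
            buffers[dst][c] = payload

    results = [[x for chunk in buf for x in chunk] for buf in buffers]
    return results, sent


ranks = [[r * 10 + i for i in range(8)] for r in range(4)]  # 4 "GPUs", M = 8
reduced, sent = ring_all_reduce(ranks)
print(reduced[0])  # [60, 64, 68, 72, 76, 80, 84, 88] on every rank
print(sent)        # 2 * (4 - 1) / 4 * 8 = 12 values sent by each rank
```

Dividing the result by N would turn this sum into the gradient average used in data parallelism.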

Scaling Efficiency

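Perfect linear scaling means 8 GPUs train 8× faster. In practice, typical efficiencies look roughly like this, though they vary widely with model size, batch size, and interconnect:

  • 2-4 GPUs on one machine: 90-95% efficiency (with a fast NVLink interconnect)
  • 8 GPUs on one machine: 85-90% efficiency
  • Multi-node (across machines): 70-85% efficiency (limited by network bandwidth and latency)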
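Scaling efficiency is usually computed from measured throughput: the N-GPU throughput divided by N times the single-GPU throughput. The throughput numbers below are hypothetical placeholders, not benchmarks.

```python
def scaling_efficiency(single_gpu_throughput: float, multi_gpu_throughput: float,
                       num_gpus: int) -> float:
    """Measured speedup divided by the ideal (linear) speedup."""
    ideal = single_gpu_throughput * num_gpus
    return multi_gpu_throughput / ideal


# Hypothetical measurements in samples/second.
eff = scaling_efficiency(single_gpu_throughput=400.0, multi_gpu_throughput=2800.0, num_gpus=8)
print(f"7.0x speedup out of 8x ideal -> {eff:.1%} efficiency")  # 87.5%
```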

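Efficiency drops because coordination costs grow with GPU count. A ring all-reduce needs 2(N - 1) steps, traffic has to cross slower links between machines, and every step waits for the slowest GPU. Techniques to improve scaling: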

  • Gradient compression: Send approximations instead of full gradients
  • Overlap communication with computation: Start the all-reduce for the last layers’ gradients (which the backward pass produces first) while gradients for earlier layers are still being computed
  • Large batch training: More compute per communication round
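The overlap idea can be illustrated with a toy model of gradient bucketing, loosely based on how DDP groups gradients. During the backward pass, gradients become ready in reverse layer order. They are packed into buckets up to a size cap, and each bucket's all-reduce can start as soon as it fills, while earlier layers are still computing. The layer names, sizes, and 25 MB cap are illustrative assumptions (25 MB matches DDP's default bucket_cap_mb), and DDP's real bucketing has extra details this sketch ignores.

```python
def plan_buckets(layer_grad_mb, cap_mb):
    """Group gradients into all-reduce buckets in the order backward() produces them."""
    buckets, current, size = [], [], 0.0
    # backward() visits layers from the output back to the input.
    for name, mb in reversed(layer_grad_mb):
        if current and size + mb > cap_mb:
            buckets.append(current)  # bucket full: its all-reduce could launch now
            current, size = [], 0.0
        current.append(name)
        size += mb
    if current:
        buckets.append(current)
    return buckets


# Hypothetical per-layer gradient sizes in MB, listed input -> output.
layers = [("embed", 20.0), ("block1", 12.0), ("block2", 12.0), ("block3", 12.0), ("head", 8.0)]
for i, bucket in enumerate(plan_buckets(layers, cap_mb=25.0)):
    print(f"bucket {i}: {bucket}")
```

Bucket 0 holds the output-side layers, so its all-reduce can run while block2, block1, and embed are still computing their gradients.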

Common Misconception

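People think adding more GPUs always makes training faster proportionally. In reality, there’s a batch size ceiling. Beyond a certain effective batch size, known as the “critical batch size,” larger batches barely reduce the number of training steps needed, so extra GPUs buy little speedup. Without changes to the optimizer, model quality can also degrade. The ceiling depends on the task and the optimizer: BERT pretraining, for example, relied on a specialized layer-wise adaptive optimizer (LAMB) to reach batch sizes in the tens of thousands. Adding more GPUs past the ceiling wastes resources.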

The learning rate also needs adjustment. The linear scaling rule says: if you multiply batch size by K, multiply learning rate by K. But this breaks down at large scales, requiring warmup schedules and careful tuning.
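The linear scaling rule combined with a linear warmup can be sketched as a schedule function. The base learning rate, batch sizes, and warmup length are illustrative assumptions; real recipes tune all of them.

```python
def scaled_lr(base_lr: float, base_batch: int, batch: int) -> float:
    """Linear scaling rule: multiply the LR by K when the batch grows by K."""
    return base_lr * batch / base_batch


def lr_at_step(step: int, target_lr: float, warmup_steps: int) -> float:
    """Ramp linearly from target_lr / warmup_steps up to target_lr, then hold it."""
    if step < warmup_steps:
        return target_lr * (step + 1) / warmup_steps
    return target_lr


# Assumed recipe: LR 0.1 tuned at batch 256, now training at batch 8192 (K = 32).
target = scaled_lr(0.1, 256, 8192)
schedule = [lr_at_step(s, target, warmup_steps=500) for s in range(1000)]
print(f"target LR {target:.2f}; step 0: {schedule[0]:.4f}; step 999: {schedule[999]:.2f}")
```

Warmup keeps the early updates small while the weights are far from a good region, which is when a 32× larger learning rate is most likely to diverge.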

PyTorch’s Distributed Toolkit

PyTorch provides several levels of abstraction:

  • DistributedDataParallel (DDP): The go-to for data parallelism. Handles gradient synchronization automatically
  • FullyShardedDataParallel (FSDP): Shards model parameters across GPUs, reducing memory per GPU. Developed from Facebook’s FairScale
  • torch.distributed module: Low-level primitives (send, recv, all-reduce) for custom communication patterns
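  • torchrun: Launch utility that starts one process per GPU and sets the environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that init_process_group reads to form the process group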
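A minimal DDP training loop follows, written to be launched with torchrun, for example torchrun --nproc_per_node=4 train.py, where train.py is a hypothetical file holding this code. It also runs as a single plain Python process on CPU through the gloo backend. The synthetic dataset, the model, and the hyperparameters are placeholders. DistributedSampler gives each rank a disjoint slice of the data, and DDP averages gradients during backward().

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def train(epochs: int = 2) -> float:
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
    # These defaults let the script also run as one plain process.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(device)

    torch.manual_seed(0)  # identical synthetic dataset on every rank
    x = torch.randn(1024, 16)
    y = x.sum(dim=1, keepdim=True)
    dataset = TensorDataset(x, y)
    # Each rank iterates over a disjoint 1/world_size share of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    model = nn.Linear(16, 1).to(device)
    # DDP broadcasts rank 0's weights at construction and all-reduces gradients in backward().
    ddp_model = DDP(model, device_ids=[local_rank] if use_cuda else None)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.05)
    loss_fn = nn.MSELoss()

    loss = torch.tensor(float("nan"))
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # different shuffle each epoch, consistent across ranks
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = loss_fn(ddp_model(xb), yb)
            loss.backward()  # gradient all-reduce happens here, bucket by bucket
            opt.step()

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    print(f"final loss: {train():.4f}")
```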

Choosing Your Strategy

Situation | Recommended Approach
Model fits on one GPU, want faster training | DDP (data parallelism)
Model barely fits on one GPU | FSDP, or gradient checkpointing + DDP
Model doesn’t fit on one GPU | FSDP, or pipeline + tensor parallelism
Training on 100+ GPUs | FSDP with hybrid sharding
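A rough way to place a model in this table is to estimate its training-state memory. The sketch below assumes mixed-precision training with Adam at about 16 bytes per parameter: FP16 weights and gradients plus FP32 master weights, momentum, and variance, following the ZeRO paper's accounting. It ignores activations, so real usage is higher. Under full sharding, FSDP divides that state roughly evenly across GPUs. The 80 GB GPU size is an assumption.

```python
# fp16 weights + fp16 grads + fp32 master weights + Adam momentum + Adam variance
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4
GPU_MEMORY_GB = 80  # assumed: one 80 GB A100


def training_state_gb(num_params: float, num_gpus: int = 1, fully_sharded: bool = False) -> float:
    """Approximate per-GPU memory for weights, gradients and optimizer state (no activations)."""
    total_bytes = num_params * BYTES_PER_PARAM
    per_gpu = total_bytes / num_gpus if fully_sharded else total_bytes
    return per_gpu / 1e9


for params in (1e9, 7e9, 70e9):
    replicated = training_state_gb(params)  # DDP keeps a full copy on every GPU
    sharded = training_state_gb(params, num_gpus=8, fully_sharded=True)
    print(f"{params / 1e9:>4.0f}B params: {replicated:7.1f} GB replicated, "
          f"{sharded:6.1f} GB/GPU sharded over 8 GPUs ({GPU_MEMORY_GB} GB available)")
```

Under these assumptions, a 7B model needs about 112 GB of training state, more than one 80 GB GPU holds, but only about 14 GB per GPU when fully sharded over 8 GPUs. That gap is the memory argument for FSDP.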

The one thing to remember: Distributed training is fundamentally about balancing computation and communication — more GPUs mean more compute power but also more coordination overhead, and the art is maximizing the former while minimizing the latter.
