PyTorch Distributed Training — Core Concepts
Why Single-GPU Training Hits a Wall
Modern deep learning models have billions of parameters. Training them means pushing terabytes of data through billions of floating-point operations per sample, for hundreds of thousands of steps. A single NVIDIA A100 GPU peaks at about 312 TFLOPS of dense FP16/BF16 tensor-core throughput. That is impressive, but training a 70B-parameter model at that rate would take well over a year, even at perfect utilization. Distributed training is not optional for state-of-the-art AI; it’s required.
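To see where the "over a year" figure comes from, here is a back-of-the-envelope estimate. It rests on assumptions that are not in the text above: the common approximation of about 6 FLOPs per parameter per training token (forward plus backward pass), a hypothetical dataset of 1 trillion tokens, and the A100's 312 TFLOPS peak sustained at 100% utilization, which real jobs never reach. The helper name `training_years` is illustrative.

```python
# Back-of-the-envelope: how long would one GPU take to train a 70B model?
# Assumptions (not from the text): ~6 FLOPs per parameter per token,
# a hypothetical 1-trillion-token dataset, and 100% of peak throughput.

SECONDS_PER_YEAR = 365 * 24 * 3600


def training_years(params: float, tokens: float, flops_per_sec: float,
                   utilization: float = 1.0) -> float:
    """Estimated wall-clock years to train on a single device."""
    total_flops = 6 * params * tokens  # forward + backward, common approximation
    seconds = total_flops / (flops_per_sec * utilization)
    return seconds / SECONDS_PER_YEAR


years = training_years(params=70e9, tokens=1e12, flops_per_sec=312e12)
print(f"{years:.1f} years")  # prints "42.7 years"
```

Even under these generous assumptions the answer is decades, and any realistic utilization below 100% only makes it longer.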
Data Parallelism: The Workhorse
Data parallelism is the most common distributed strategy. The idea:
- Replicate the full model on every GPU
- Partition each training batch so each GPU processes a different slice
- Compute forward and backward passes independently on each GPU
- Synchronize gradients across all GPUs (via all-reduce)
- Update weights identically on every GPU
Because every GPU applies the same averaged gradients, all replicas keep identical weights after each step. The effective batch size scales linearly with GPU count: 8 GPUs with a per-GPU batch size of 32 give an effective batch size of 256.
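The five steps above can be simulated without any GPUs. This is a minimal sketch in plain Python under stated assumptions: a one-weight model `y = w * x`, mean-squared-error loss, a global batch that divides evenly across four hypothetical workers, and an all-reduce that averages gradients (as PyTorch's DDP does). The helper names are illustrative, not a library API.

```python
# Simulated data parallelism: each "worker" is a list slot holding its own
# replica of a single weight w for the hypothetical model y = w * x.

def local_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over this worker's slice of the batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Every worker ends up holding the average of all workers' values."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_step(replicas, xs, ys, lr):
    n = len(replicas)
    per_worker = len(xs) // n  # assumes the batch divides evenly
    grads = []
    for rank, w in enumerate(replicas):
        lo, hi = rank * per_worker, (rank + 1) * per_worker
        grads.append(local_grad(w, xs[lo:hi], ys[lo:hi]))  # independent fwd/bwd
    grads = all_reduce_mean(grads)  # synchronize gradients
    return [w - lr * g for w, g in zip(replicas, grads)]  # identical updates


xs = [float(i) for i in range(1, 9)]  # global batch of 8 samples
ys = [3.0 * x for x in xs]            # true weight is 3.0
replicas = [0.0] * 4                  # 4 workers, 2 samples each
for _ in range(50):
    replicas = data_parallel_step(replicas, xs, ys, lr=0.01)
print(replicas)  # four identical values, all close to 3.0
```

With equal-sized slices, the average of the per-worker mean gradients equals the full-batch gradient, which is why data parallelism behaves like single-GPU training with a larger batch.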
Model Parallelism: When the Model Won’t Fit
Some models are too large for a single GPU’s memory. Model parallelism splits the model itself:
- Tensor parallelism: Splits individual layers across GPUs. A large matrix multiplication is divided so each GPU computes part of the result
- Pipeline parallelism: Assigns different layers to different GPUs. Data flows through GPU 1’s layers, then GPU 2’s layers, like an assembly line
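Tensor parallelism is easiest to see on a single matrix multiplication `Y = X @ W`. The sketch below assumes a column-wise split, one common scheme: each hypothetical device holds a block of W's columns, computes its slice of Y, and the slices are concatenated (the all-gather step). Plain nested lists stand in for GPU tensors, and the helper names are illustrative.

```python
# Column-parallel matmul: split W by columns across "devices", compute
# partial outputs independently, then concatenate them (all-gather).

def matmul(a, b):
    """(n x k) @ (k x m) using nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def split_columns(w, parts):
    """Give each device a contiguous block of W's columns."""
    cols = len(w[0]) // parts  # assumes the columns divide evenly
    return [[row[p * cols:(p + 1) * cols] for row in w] for p in range(parts)]


def column_parallel_matmul(x, w, devices):
    shards = split_columns(w, devices)
    partial = [matmul(x, shard) for shard in shards]  # each device's slice of Y
    # all-gather: stitch the column blocks back together, row by row
    return [sum((p[i] for p in partial), []) for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]   # 2 x 2 activations
w = [[1.0, 0.0, 2.0, -1.0],
     [0.5, 1.0, 0.0, 3.0]]     # 2 x 4 weights
print(column_parallel_matmul(x, w, devices=2) == matmul(x, w))  # True
```

Each device stores only half of W and does half of the multiply-adds; the price is a communication step to assemble the full output.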
Pipeline parallelism introduces “bubble time” — GPUs waiting for their input. Techniques like micro-batching (splitting each batch into smaller chunks) reduce this idle time.
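The bubble can be quantified under a simplifying assumption: a GPipe-style fill-and-drain schedule in which every stage takes the same time per micro-batch. With p stages and m micro-batches, the pipeline needs m + p - 1 time slots but each stage is busy in only m of them, so the idle fraction is (p - 1)/(m + p - 1). The helper below (a hypothetical name) evaluates that ratio.

```python
from fractions import Fraction


def bubble_fraction(stages: int, micro_batches: int) -> Fraction:
    """Idle share of each stage's time in a simple fill-and-drain schedule.

    The pipeline needs micro_batches + stages - 1 time slots to push every
    micro-batch through, but each stage does useful work in only
    micro_batches of them.
    """
    return Fraction(stages - 1, micro_batches + stages - 1)


for m in (1, 4, 16, 64):
    print(f"{m:>2} micro-batches: {float(bubble_fraction(4, m)):.0%} idle")
# 4 stages: 75%, 43%, 16%, 4% idle
```

More micro-batches shrink the bubble, at the cost of smaller per-micro-batch work on each GPU.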
Communication: The Hidden Cost
At scale, distributed training’s bottleneck is often communication rather than computation. The main collective operations:
| Operation | What It Does | When It’s Used |
|---|---|---|
| All-reduce | Every GPU contributes its gradients; every GPU receives the combined result (a sum, which DDP divides by the GPU count to get the average) | Data parallelism gradient sync |
| All-gather | Every GPU shares data, every GPU gets the complete set | Collecting outputs from all GPUs |
| Broadcast | One GPU sends data to all others | Distributing initial weights |
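The input/output behavior of the three collectives in the table can be pinned down with a few lines of plain Python, treating each list entry as one GPU's tensor. This is a reference-semantics sketch only; NCCL and other libraries produce the same results with far more efficient communication. The function names mirror the operations but are hypothetical helpers, not the `torch.distributed` API.

```python
# Reference semantics of three collectives; per_gpu[r] is rank r's tensor.

def all_reduce(per_gpu, op=sum):
    """Every GPU gets op() applied elementwise across all GPUs' tensors."""
    reduced = [op(vals) for vals in zip(*per_gpu)]
    return [list(reduced) for _ in per_gpu]


def all_gather(per_gpu):
    """Every GPU gets the concatenation of all tensors, in rank order."""
    gathered = [x for chunk in per_gpu for x in chunk]
    return [list(gathered) for _ in per_gpu]


def broadcast(per_gpu, src=0):
    """Every GPU gets a copy of rank `src`'s tensor."""
    return [list(per_gpu[src]) for _ in per_gpu]


def mean(vals):
    return sum(vals) / len(vals)


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 GPUs, 2 values each
print(all_reduce(grads))            # [[9.0, 12.0], [9.0, 12.0], [9.0, 12.0]]
print(all_reduce(grads, op=mean))   # [[3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
print(all_gather(grads)[0])         # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(broadcast(grads))             # [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
```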
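For data parallelism, all-reduce is the critical operation. With N GPUs and M parameters, a ring all-reduce has each GPU send and receive about 2(N-1)/N × M values: O(M), and nearly independent of N. NCCL (NVIDIA Collective Communication Library) implements all-reduce with bandwidth-efficient ring and tree algorithms, and PyTorch’s DDP overlaps these collectives with the backward pass by launching them on gradient buckets as soon as each bucket is ready.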
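Why per-GPU traffic stays O(M) is clearest in a ring all-reduce, simulated below in plain Python. The sketch assumes the vector length divides evenly by the number of workers, and `ring_all_reduce` is an illustrative name. It runs the two classic phases: a reduce-scatter, after which each worker owns the full sum of one chunk, and an all-gather, which circulates those finished chunks around the ring. Each worker sends one chunk of M/N values per step for 2(N-1) steps.

```python
def ring_all_reduce(vectors):
    """Sum equal-length vectors via ring reduce-scatter + all-gather.

    Returns (results, values_sent_per_worker). Assumes len(vector) is
    divisible by the number of workers so chunks are equal-sized.
    """
    n = len(vectors)
    size = len(vectors[0]) // n
    bufs = [list(v) for v in vectors]  # each worker's local buffer
    sent = [0] * n

    def chunk(c):
        return range(c * size, (c + 1) * size)

    def ring_step(pick_chunk, combine):
        # Snapshot all outgoing chunks first so sends happen "simultaneously".
        msgs = []
        for i in range(n):
            c = pick_chunk(i)
            msgs.append((c, [bufs[i][k] for k in chunk(c)]))
            sent[i] += size
        for i, (c, data) in enumerate(msgs):
            dst = (i + 1) % n  # always send to the right-hand neighbor
            for k, v in zip(chunk(c), data):
                bufs[dst][k] = combine(bufs[dst][k], v)

    # Phase 1, reduce-scatter: worker i ends up with the full sum of chunk (i+1) % n.
    for step in range(n - 1):
        ring_step(lambda i: (i - step) % n, lambda old, new: old + new)
    # Phase 2, all-gather: pass the finished chunks along, overwriting stale copies.
    for step in range(n - 1):
        ring_step(lambda i: (i + 1 - step) % n, lambda old, new: new)
    return bufs, sent


vectors = [[float(rank + 1)] * 8 for rank in range(4)]  # 4 workers, M = 8
results, sent = ring_all_reduce(vectors)
print(results[0])  # every entry is 10.0 (= 1 + 2 + 3 + 4)
print(sent)        # [12, 12, 12, 12], i.e. 2 * (4 - 1) / 4 * 8 per worker
```

The per-worker total 2(N-1)/N × M approaches 2M as N grows, so adding workers does not increase how much each one must send; latency per step and link speed become the limiting factors instead.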
Scaling Efficiency
Perfect linear scaling means 8 GPUs train 8× faster. In practice, typical efficiencies (rough, workload-dependent figures) look like this:
- 2-4 GPUs on one machine: 90-95% efficiency (fast NVLink interconnect)
- 8 GPUs on one machine: 85-90% efficiency
- Multi-node (across machines): 70-85% efficiency (limited by network bandwidth)
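Scaling efficiency is the measured speedup divided by the GPU count. The helper below computes it; the timings in the example are made-up numbers chosen to fall inside the ranges above, not measurements.

```python
def scaling_efficiency(time_1gpu: float, time_ngpu: float, n: int) -> float:
    """Speedup divided by GPU count; 1.0 means perfect linear scaling."""
    speedup = time_1gpu / time_ngpu
    return speedup / n


# Hypothetical wall-clock hours for the same training run (not real data).
runs = {1: 96.0, 8: 13.5, 32: 4.0}
for n, hours in runs.items():
    print(f"{n:>2} GPUs: {scaling_efficiency(runs[1], hours, n):.0%}")
# 1 GPU: 100%, 8 GPUs: 89%, 32 GPUs: 75%
```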
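Efficiency drops because coordination costs (synchronization latency, slower links between machines, and waiting on the slowest GPU) take a growing share of each step as the GPU count rises. Techniques to improve scaling: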
- Gradient compression: Send approximations instead of full gradients
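- Overlap communication with computation: Backpropagation produces gradients from the last layer backward, so all-reduce can start on the last layers’ gradients while earlier layers’ gradients are still being computed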
- Large batch training: More compute per communication round
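As one concrete form of gradient compression, the sketch below uses top-k sparsification with error feedback, a technique chosen here for illustration (the text above does not name a specific method). Each worker sends only the k largest-magnitude gradient entries as (index, value) pairs and carries the unsent remainder into its next step, so small gradients are delayed rather than lost. The function names are hypothetical.

```python
def topk_compress(grad, k, residual):
    """Top-k sparsification with error feedback.

    Adds the residual left over from earlier steps, keeps the k entries
    with the largest magnitude as (index, value) pairs to communicate,
    and returns everything else as the new residual.
    """
    corrected = [g + r for g, r in zip(grad, residual)]
    ranked = sorted(range(len(corrected)), key=lambda i: abs(corrected[i]),
                    reverse=True)
    keep = set(ranked[:k])
    sparse = [(i, corrected[i]) for i in sorted(keep)]
    new_residual = [0.0 if i in keep else v for i, v in enumerate(corrected)]
    return sparse, new_residual


def decompress(sparse, length):
    """Rebuild a dense gradient with zeros wherever nothing was sent."""
    dense = [0.0] * length
    for i, v in sparse:
        dense[i] = v
    return dense


grad = [0.1, -2.0, 0.05, 1.5, -0.3, 0.0]
sparse, residual = topk_compress(grad, k=2, residual=[0.0] * len(grad))
print(sparse)    # [(1, -2.0), (3, 1.5)]: 2 pairs sent instead of 6 values
print(residual)  # [0.1, 0.0, 0.05, 0.0, -0.3, 0.0]: kept for the next step
```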
Common Misconception
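Many people assume that adding GPUs always speeds up training proportionally. In reality, data parallelism grows the effective batch size, and there’s a ceiling: beyond a certain effective batch size, larger batches stop reducing the number of steps needed, and model quality can degrade. This threshold is called the “critical batch size,” and it depends on the model, the data, and the optimizer. BERT pretraining, for example, degrades at very large batch sizes with standard optimizers; reaching batches of 32K required a layer-wise adaptive optimizer (LAMB). Adding GPUs past this point wastes resources.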
The learning rate also needs adjustment. The linear scaling rule says: if you multiply the batch size by K, multiply the learning rate by K. In practice the rule is paired with a learning-rate warmup at the start of training, and it breaks down at very large batch sizes, where careful tuning is required.
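The linear scaling rule, plus a warmup, fits in two small functions. This is a sketch under assumed numbers: a baseline learning rate of 0.1 at batch size 256, scaled to batch 2048 (K = 8), with a hypothetical 5-step linear warmup. Real runs warm up over thousands of steps and then apply a decay schedule, which is omitted here; the function names are illustrative.

```python
def scaled_lr(base_lr: float, base_batch: int, batch: int) -> float:
    """Linear scaling rule: multiply the LR by K = batch / base_batch."""
    return base_lr * batch / base_batch


def lr_at_step(step: int, peak_lr: float, warmup_steps: int) -> float:
    """Linear warmup up to peak_lr, then constant (decay omitted)."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr


peak = scaled_lr(base_lr=0.1, base_batch=256, batch=2048)  # K = 8, so 0.8
schedule = [lr_at_step(s, peak, warmup_steps=5) for s in range(8)]
print([round(lr, 2) for lr in schedule])  # [0.16, 0.32, 0.48, 0.64, 0.8, 0.8, 0.8, 0.8]
```

Starting small gives the randomly initialized model a few steps to stabilize before it sees the full, scaled-up learning rate.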
PyTorch’s Distributed Toolkit
PyTorch provides several levels of abstraction:
- DistributedDataParallel (DDP): The go-to for data parallelism. Handles gradient synchronization automatically
- FullyShardedDataParallel (FSDP): Shards model parameters, gradients, and optimizer state across GPUs, reducing memory per GPU. Upstreamed from Facebook’s FairScale library
- torch.distributed module: Low-level primitives (send, recv, all_reduce) for custom communication patterns
- torchrun: Launch utility that starts one process per GPU and sets the environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that process-group initialization reads
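The toolkit pieces listed above fit together in a minimal DDP training loop, sketched below. It is a hedged example, not a recipe: it uses the CPU-only `gloo` backend and a toy `nn.Linear` model so it can run anywhere, and the function names `init_distributed` and `train` are illustrative. Launched with `torchrun --nproc_per_node=2 ddp_sketch.py` (the file name is hypothetical), torchrun supplies the rendezvous environment variables; run as a plain `python` script, it falls back to a single-process group backed by a temporary file. For real GPU training you would switch to the `nccl` backend, call `torch.cuda.set_device(local_rank)`, wrap the model as `DDP(model.to(local_rank), device_ids=[local_rank])`, and shard the data with `DistributedSampler`.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def init_distributed() -> None:
    """Join the default process group.

    Under torchrun, RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already in
    the environment, so the default env:// rendezvous works. As a plain
    script, fall back to a one-process group backed by a temporary file.
    """
    if "RANK" in os.environ:
        dist.init_process_group(backend="gloo")  # use "nccl" for GPU training
    else:
        store = os.path.join(tempfile.mkdtemp(), "pg_store")
        dist.init_process_group(backend="gloo", init_method=f"file://{store}",
                                rank=0, world_size=1)


def train(steps: int = 20) -> torch.Tensor:
    init_distributed()
    rank = dist.get_rank()
    torch.manual_seed(0)          # same init everywhere; DDP also broadcasts rank 0's weights
    model = DDP(nn.Linear(4, 1))  # CPU module, so no device_ids
    opt = torch.optim.SGD(model.parameters(), lr=0.05)
    gen = torch.Generator().manual_seed(1000 + rank)  # different data per rank
    for _ in range(steps):
        x = torch.randn(32, 4, generator=gen)
        y = x.sum(dim=1, keepdim=True)  # toy target: the sum of the inputs
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()                 # DDP all-reduces (averages) gradients here
        opt.step()                      # identical update on every rank
    weight = model.module.weight.detach().clone()
    dist.destroy_process_group()
    return weight


if __name__ == "__main__":
    w = train()
    print(f"rank {os.environ.get('RANK', '0')}: weight = {w.tolist()}")
```

The training loop itself looks like ordinary single-GPU code; the distributed parts are confined to process-group setup, the `DDP` wrapper, and teardown.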
Choosing Your Strategy
| Situation | Recommended Approach |
|---|---|
| Model fits on one GPU, want faster training | DDP (data parallelism) |
| Model barely fits on one GPU | FSDP or gradient checkpointing + DDP |
| Model doesn’t fit on one GPU | FSDP, or pipeline + tensor parallelism |
| Training on 100+ GPUs | FSDP with hybrid sharding |
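Whether a model "fits on one GPU" can be estimated roughly from its parameter count. The sketch assumes mixed-precision training with Adam and a rule of thumb of 16 bytes of model state per parameter (the accounting used in the ZeRO paper), and it ignores activations, temporary buffers, and fragmentation, so real usage is higher. Under replication (DDP) each GPU holds all of it; under full sharding (FSDP's `FULL_SHARD`, similar to ZeRO stage 3) it is divided across GPUs. The helper name is hypothetical.

```python
# Assumed per-parameter model state for mixed-precision Adam:
# fp16 weights (2) + fp16 grads (2) + fp32 master weights, Adam m, Adam v (4 + 4 + 4).
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4


def model_state_gb(params: float, gpus: int = 1, sharded: bool = False) -> float:
    """Per-GPU model-state memory in GB (excludes activations and buffers)."""
    total_bytes = params * BYTES_PER_PARAM
    per_gpu = total_bytes / gpus if sharded else total_bytes  # DDP replicates everything
    return per_gpu / 1e9


for name, params in [("1B", 1e9), ("7B", 7e9), ("70B", 70e9)]:
    replicated = model_state_gb(params)
    sharded = model_state_gb(params, gpus=8, sharded=True)
    print(f"{name}: replicated {replicated:.0f} GB/GPU, "
          f"sharded over 8 GPUs {sharded:.0f} GB/GPU")
```

By this estimate a 7B model needs about 112 GB of model state per GPU under DDP, more than an 80 GB A100 holds, but about 14 GB per GPU when fully sharded over 8 GPUs. A 70B model still needs about 140 GB per GPU over 8 GPUs, which is why larger jobs shard across more GPUs or add pipeline and tensor parallelism.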
The one thing to remember: Distributed training is fundamentally about balancing computation and communication — more GPUs mean more compute power but also more coordination overhead, and the art is maximizing the former while minimizing the latter.