Distributed Training
Training large models on a single GPU is either too slow (days instead of hours) or impossible (the model does not fit in memory). Distributed training solves both problems by splitting work across multiple GPUs. The challenge is doing this efficiently -- naively partitioning work introduces communication overhead that can negate the benefit of extra GPUs. This chapter covers the four main strategies (data parallelism, tensor parallelism, pipeline parallelism, and fully sharded data parallelism), the communication patterns they use, and the practical considerations that determine which approach to use.
Data Parallelism (DDP)
DistributedDataParallel (DDP) is the simplest and most commonly used strategy. Each GPU holds a complete copy of the model and processes a different shard of the data. After each backward pass, gradients are all-reduced across GPUs so every replica updates identically:
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)
    model = MyModel().to(rank)   # MyModel and dataset defined elsewhere
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # Reshuffle data across GPUs each epoch
        for data, target in loader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            loss = model(data, target)
            loss.backward()  # Gradients are all-reduced automatically
            optimizer.step()

    dist.destroy_process_group()

# Launch: torchrun --nproc_per_node=4 train.py
```
| Mistake | Symptom | Fix |
|---|---|---|
| Forgetting sampler.set_epoch(epoch) | All GPUs see the same data order every epoch | Call it before each epoch |
| Logging/saving from all ranks | Duplicate checkpoints, garbled logs | Guard with if rank == 0: |
| Different batch sizes per rank | Training hangs (collective mismatch) | Use DistributedSampler (handles uneven data) |
| Unused parameters in forward | Hangs during backward (unused grad bucket never fires) | Set find_unused_parameters=True (slower) or fix the model |
| Not synchronizing BatchNorm | Each GPU normalizes over its local batch only | Use torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
| Not scaling the learning rate | Underfitting with more GPUs | Linear scaling rule: multiply LR by world_size (with warmup) |
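Several of the fixes above reduce to the same idiom: a rank-0 guard around side effects. A minimal stdlib-only sketch (torchrun sets the RANK environment variable for each worker; `log` here is an illustrative helper, not a DDP API):

```python
import os

def is_main_process() -> bool:
    # torchrun sets RANK on every worker; treat a missing RANK as rank 0
    # so the same code also runs in a plain single-process session.
    return int(os.environ.get("RANK", "0")) == 0

def log(msg: str) -> None:
    # Guarding side effects keeps N workers from writing N copies.
    if is_main_process():
        print(msg)

# Checkpointing follows the same pattern inside the training loop:
#     if is_main_process():
#         torch.save(model.module.state_dict(), "ckpt.pt")
```

Note the `model.module` unwrap when saving: DDP wraps the model, and saving the inner module keeps the checkpoint loadable without DDP.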
```python
from contextlib import nullcontext

# Gradient accumulation is essential when GPU memory limits batch size.
# Effective batch = batch_size_per_gpu * accumulation_steps * world_size
accumulation_steps = 8
for step, (data, target) in enumerate(loader):
    data, target = data.to(rank), target.to(rank)

    # DDP syncs gradients on every backward() by default.
    # Skip sync for intermediate accumulation steps (saves communication).
    context = model.no_sync() if (step + 1) % accumulation_steps != 0 else nullcontext()
    with context:
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            loss = model(data, target) / accumulation_steps  # Scale loss
        loss.backward()  # Gradients accumulate (not zeroed yet)

    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()  # Zero only after accumulation is complete
```
Model Parallelism
When a model does not fit in a single GPU's memory, you must split the model itself across GPUs. Three complementary strategies do this (tensor, pipeline, and sequence parallelism), summarized here alongside data parallelism for comparison:
| Strategy | What Is Split | Communication Pattern | Memory Savings | Compute Efficiency | When to Use |
|---|---|---|---|---|---|
| Data parallel (DDP) | Data batches | AllReduce gradients | None (full copy) | Near-linear scaling | Model fits in 1 GPU |
| Tensor parallel (TP) | Weight matrices (columns/rows) | AllReduce or AllGather activations | ~N (N = TP degree) | High (if interconnect is fast) | Large layers, NVLink nodes |
| Pipeline parallel (PP) | Model layers | Point-to-point activations | ~N (N = PP stages) | Lower (pipeline bubbles) | Many layers, across nodes |
| Sequence parallel (SP) | Sequence dimension | AllGather/ReduceScatter | Activation memory reduction | High | Long sequences, with TP |
```python
# Tensor parallelism splits weight matrices across GPUs.
# For a Linear layer Y = XW + b:
#
# Column parallel: split W along columns
#   GPU 0: Y_0 = X @ W[:, :d//2]   (computes first half of output features)
#   GPU 1: Y_1 = X @ W[:, d//2:]   (computes second half of output features)
#   -> AllGather Y = [Y_0 | Y_1] to reconstruct the full output
#
# Row parallel: split W along rows
#   GPU 0: Y_0 = X_0 @ W[:d//2, :]   (each GPU has a partial input)
#   GPU 1: Y_1 = X_1 @ W[d//2:, :]
#   -> AllReduce Y = Y_0 + Y_1 to combine partial sums
#
# In practice, column-parallel followed by row-parallel avoids
# an AllGather between layers (Megatron-LM pattern)
```
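The two sharding schemes are easy to verify on paper. This toy check uses plain Python lists in place of GPU tensors and simulates the collectives with list operations (a sketch of the algebra only; real tensor parallelism uses sharded torch tensors plus NCCL collectives):

```python
def matmul(A, B):
    # (m x k) @ (k x n) -> (m x n), pure-Python reference implementation
    return [[sum(A[i][p] * B[p][j] for p in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

X = [[1.0, 2.0]]                # 1 x 2 input
W = [[1.0, 2.0, 3.0, 4.0],     # 2 x 4 weight
     [5.0, 6.0, 7.0, 8.0]]

Y = matmul(X, W)               # Reference: unsharded matmul

# Column parallel: each "GPU" holds half the output columns.
W_cols = [[row[:2] for row in W], [row[2:] for row in W]]
Y0, Y1 = matmul(X, W_cols[0]), matmul(X, W_cols[1])
Y_col = [Y0[0] + Y1[0]]        # AllGather = concatenate the outputs

# Row parallel: each "GPU" holds half the input features and weight rows.
X_parts = [[[X[0][0]]], [[X[0][1]]]]
W_rows = [[W[0]], [W[1]]]
partials = [matmul(X_parts[g], W_rows[g]) for g in range(2)]
Y_row = [[a + b for a, b in zip(partials[0][0], partials[1][0])]]  # AllReduce = sum

assert Y_col == Y and Y_row == Y  # Both schemes recover the full result
```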
```python
# Pipeline parallelism: split model into stages across GPUs
from torch.distributed.pipelining import SplitPoint, pipeline

pipe = pipeline(
    model,
    mb_args=(input,),
    split_spec={
        "layer12": SplitPoint.BEGINNING,  # Stage 0: layers 0-11 (GPU 0)
        "layer24": SplitPoint.BEGINNING,  # Stage 1: layers 12-23 (GPU 1)
    },                                    # Stage 2: layers 24+ (GPU 2)
)
```
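The "pipeline bubble" mentioned in the table has a simple closed form for a GPipe-style schedule: with p stages and m microbatches, (p - 1) of the (m + p - 1) time slots are spent filling and draining the pipeline. A quick calculation shows why you want many more microbatches than stages:

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    # GPipe-style schedule: each stage idles while the pipeline fills and
    # drains, wasting (p - 1) of the (m + p - 1) total time slots.
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)

# With 4 stages and only 4 microbatches, ~43% of GPU time is idle;
# 32 microbatches shrink the bubble to under 9%.
print(bubble_fraction(4, 4))   # 3/7 ≈ 0.4286
print(bubble_fraction(4, 32))  # 3/35 ≈ 0.0857
```

Interleaved schedules (e.g. 1F1B variants) shrink the bubble further, but the basic m >> p intuition carries over.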
FSDP (Fully Sharded Data Parallel)
FSDP extends DDP by sharding not just gradients but also parameters and optimizer states across GPUs. Each GPU stores only a shard of the model. Before each layer's forward/backward, the full parameters are gathered; after, they are discarded:
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # Parameters gathered in BF16
    reduce_dtype=torch.bfloat16,  # Gradient reduction in BF16
    buffer_dtype=torch.bfloat16,  # Buffers (e.g., BatchNorm running stats)
)

model = FSDP(
    model,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard everything
    use_orig_params=True,  # Required for torch.compile compatibility
    # auto_wrap_policy: controls which submodules get their own FSDP wrapper.
    # Each wrapped module gathers/discards its parameters independently.
)
```
The memory payoff, for a 7B-parameter model in FP32 (4 bytes per parameter, hence 28 GB of weights) trained with AdamW:

| Component | DDP (per GPU) | FSDP (4 GPUs) | FSDP (8 GPUs) |
|---|---|---|---|
| Parameters | 28 GB | 7 GB | 3.5 GB |
| Gradients | 28 GB | 7 GB | 3.5 GB |
| Optimizer states (AdamW: m, v) | 56 GB | 14 GB | 7 GB |
| Total per GPU | 112 GB | 28 GB | 14 GB |
| Communication overhead | AllReduce gradients | AllGather params (2x per layer: fwd + bwd) | Same, more shards |
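The table's arithmetic is just even division: FULL_SHARD splits parameters, gradients, and both AdamW moment buffers across the group (activations excluded; the 28 GB parameter figure from the rows above is taken as given):

```python
def fsdp_per_gpu_gb(param_gb: float, world_size: int) -> dict:
    # FULL_SHARD divides every persistent training-state component evenly.
    # AdamW keeps two moment buffers (m, v), each the size of the parameters.
    return {
        "params": param_gb / world_size,
        "grads": param_gb / world_size,
        "optimizer (m, v)": 2 * param_gb / world_size,
        "total": 4 * param_gb / world_size,
    }

print(fsdp_per_gpu_gb(28, 4))  # {'params': 7.0, 'grads': 7.0, 'optimizer (m, v)': 14.0, 'total': 28.0}
print(fsdp_per_gpu_gb(28, 8))  # total 14.0 GB, matching the 8-GPU column
```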
DeepSpeed ZeRO
DeepSpeed's ZeRO (Zero Redundancy Optimizer) provides three progressive stages of memory optimization, equivalent to FSDP but with a different API and additional features:
| Stage | What Is Sharded | Memory per GPU | Communication |
|---|---|---|---|
| ZeRO-1 | Optimizer states only | ~4x reduction | Same as DDP (AllReduce gradients) |
| ZeRO-2 | Optimizer states + gradients | ~8x reduction | ReduceScatter gradients |
| ZeRO-3 | Optimizer states + gradients + parameters | ~Nx reduction | AllGather params (like FSDP) |
| ZeRO-Offload | + CPU/NVMe offloading | Train ~10x larger models | CPU-GPU transfer overhead |
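The ~4x / ~8x / ~Nx factors fall out of the ZeRO paper's accounting for mixed-precision Adam: 2 bytes of FP16 parameters, 2 bytes of FP16 gradients, and 12 bytes of optimizer state per parameter (FP32 master weights plus momentum and variance). A sketch of that accounting, ignoring activations and buffers:

```python
def zero_memory_per_gpu(num_params: float, stage: int, world_size: int) -> float:
    """Per-GPU bytes of persistent training state under each ZeRO stage.

    Mixed-precision Adam accounting: 2 bytes fp16 params + 2 bytes fp16
    grads + 12 bytes optimizer state per parameter. Activations ignored --
    this is a back-of-the-envelope sketch, not a memory profiler.
    """
    p, g, opt = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        opt /= world_size   # ZeRO-1: shard optimizer states
    if stage >= 2:
        g /= world_size     # ZeRO-2: also shard gradients
    if stage >= 3:
        p /= world_size     # ZeRO-3: also shard parameters
    return p + g + opt

gb = 1024 ** 3
for stage in (0, 1, 2, 3):
    print(stage, round(zero_memory_per_gpu(7e9, stage, world_size=8) / gb, 1))
```

As world_size grows, stage 1 approaches a 16/4 = 4x reduction and stage 2 a 16/2 = 8x reduction, while stage 3 keeps shrinking with N.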
DeepSpeed configuration (`ds_config.json`):

```json
{
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu", "pin_memory": true},
    "offload_optimizer": {"device": "cpu", "pin_memory": true},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "prefetch_bucket_size": 5e7,
    "param_persistence_threshold": 1e5
  },
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "gradient_clipping": 1.0
}
```
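The config above implies a global batch size, which DeepSpeed validates at startup (global batch = micro batch per GPU x accumulation steps x world size), so it is worth computing explicitly. A sketch, where the world size of 8 is an assumed launch configuration, not part of the config file:

```python
micro_batch_per_gpu = 4   # train_micro_batch_size_per_gpu from the config
grad_accum_steps = 8      # gradient_accumulation_steps from the config
world_size = 8            # GPUs in the job (assumption for this example)

global_batch = micro_batch_per_gpu * grad_accum_steps * world_size
print(global_batch)  # 256
```

Changing the world size without adjusting accumulation steps silently changes the effective batch size, which in turn interacts with the learning-rate scaling rule from the DDP section.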
3D Parallelism
For the largest models (70B+), you combine all three strategies:
Example: 70B model on 64 GPUs (8 nodes, 8 GPUs each)

- Tensor parallel (TP=8): split each layer across the 8 GPUs within a node (NVLink)
- Pipeline parallel (PP=4): split 80 layers into 4 stages across 4 node groups
- Data parallel (DP=2): 2 replicas of the full pipeline, each spanning 32 GPUs
- Total: TP(8) x PP(4) x DP(2) = 64 GPUs

Communication:

- TP: AllReduce within a node (NVLink, ~900 GB/s) -> fast
- PP: point-to-point between nodes (InfiniBand, ~400 Gb/s) -> moderate
- DP: AllReduce across DP replicas (InfiniBand) -> infrequent (once per step)
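Each GPU's global rank decomposes into one coordinate per parallel dimension. The sketch below uses a common convention (TP varies fastest so TP groups stay within a node, then PP, then DP); frameworks differ in the exact ordering, so treat this as illustrative rather than Megatron-LM's precise mapping:

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int):
    # Decompose a global rank into (tp_rank, pp_rank, dp_rank) with TP
    # fastest-varying, so consecutive ranks (same node) share a TP group.
    assert 0 <= rank < tp * pp * dp
    tp_rank = rank % tp
    pp_rank = (rank // tp) % pp
    dp_rank = rank // (tp * pp)
    return tp_rank, pp_rank, dp_rank

# 64 GPUs = TP(8) x PP(4) x DP(2): ranks 0-7 form one TP group on one
# node; rank 8 begins the next pipeline stage; rank 32 starts replica 2.
print(rank_to_coords(0, 8, 4, 2))   # (0, 0, 0)
print(rank_to_coords(7, 8, 4, 2))   # (7, 0, 0)
print(rank_to_coords(8, 8, 4, 2))   # (0, 1, 0)
print(rank_to_coords(63, 8, 4, 2))  # (7, 3, 1)
```

This layout is why TP belongs inside a node: its frequent AllReduces land on NVLink, while the rarer PP and DP traffic crosses InfiniBand.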
| Model Size | GPUs Available | Recommended Strategy |
|---|---|---|
| < 1B | 1-8 | DDP |
| 1-13B | 4-8 | FSDP or ZeRO-3 |
| 13-70B | 8-64 | FSDP + TP (within node) |
| 70B+ | 64-512 | TP + PP + DP (3D parallelism) |
| 200B+ | 512+ | 3D parallelism + sequence parallelism + expert parallelism |