Distributed Training

Training large models on a single GPU is either too slow (days instead of hours) or impossible (the model does not fit in memory). Distributed training solves both problems by splitting work across multiple GPUs. The challenge is doing this efficiently -- naively partitioning work introduces communication overhead that can negate the benefit of extra GPUs. This chapter covers the four main strategies (data parallelism, tensor parallelism, pipeline parallelism, and fully sharded data parallelism), the communication patterns they use, and the practical considerations that determine which approach to use.

Data Parallelism (DDP)

DistributedDataParallel (DDP) is the simplest and most commonly used strategy. Each GPU holds a complete copy of the model and processes a different shard of the data. After each backward pass, gradients are all-reduced across GPUs so every replica updates identically:


```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)

    # MyModel and dataset are defined elsewhere in your training script
    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # Reshuffle data across GPUs each epoch
        for data, target in loader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            loss = model(data, target)
            loss.backward()  # Gradients are all-reduced automatically
            optimizer.step()

    dist.destroy_process_group()

# Launch: torchrun --nproc_per_node=4 train.py
```
**How DDP overlaps communication with computation.** DDP does not wait until all gradients are computed before communicating. It groups parameters into **buckets** (default 25 MB) and begins all-reducing each bucket as soon as all gradients in that bucket are computed during backward. Because backward computes gradients in reverse order (last layer first), the first bucket communicated contains the last layer's gradients -- which are typically computed first. By the time the optimizer runs, all all-reduce operations are complete.
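The bucket assignment can be sketched in plain Python. This is a simplification for intuition only (real DDP also groups parameters by dtype and device, and sizes the first bucket separately); `assign_buckets` is an illustrative name, not a PyTorch API:

```python
# Simplified sketch of DDP-style gradient bucketing (illustrative only).
def assign_buckets(param_numels, bucket_cap_bytes=25 * 1024 * 1024,
                   bytes_per_elem=4):
    """Walk parameters in reverse registration order (last layer first,
    matching the order backward produces gradients) and close a bucket
    once it reaches the cap."""
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(param_numels))):
        current.append(idx)
        current_bytes += param_numels[idx] * bytes_per_elem
        if current_bytes >= bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)
    return buckets

# Toy model: ten parameter tensors of 4M elements (16 MB each in FP32).
buckets = assign_buckets([4_000_000] * 10)
# The first bucket holds the *last* parameters, so its all-reduce can
# start while earlier layers are still computing gradients.
print(buckets[0])  # -> [9, 8]
```

The bucket size is tunable in real DDP via the `bucket_cap_mb` constructor argument; profiling is the only reliable way to pick a value.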
| Mistake | Symptom | Fix |
|---|---|---|
| Forgetting `sampler.set_epoch(epoch)` | All GPUs see the same data order every epoch | Call it before each epoch |
| Logging/saving from all ranks | Duplicate checkpoints, garbled logs | Guard with `if rank == 0:` |
| Different batch sizes per rank | Training hangs (collective mismatch) | Use `DistributedSampler` (handles uneven data) |
| Unused parameters in forward | Hangs during backward (unused grad bucket never fires) | Set `find_unused_parameters=True` (slower) or fix the model |
| Not synchronizing BatchNorm | Each GPU normalizes over its local batch only | Use `torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)` |
| Not scaling the learning rate | Underfitting with more GPUs | Linear scaling rule: multiply LR by `world_size` (with warmup) |
**The linear scaling rule.** With DDP, each GPU processes `batch_size` samples, so the effective global batch size is `batch_size * world_size`. To maintain the same training dynamics, scale the learning rate proportionally: `lr = base_lr * world_size`. Always combine this with a warmup period (typically 1-5% of training) to stabilize early training with the larger learning rate.
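A quick sketch of the rule with linear warmup (the base LR, world size, and warmup length below are illustrative values, not recommendations):

```python
# Illustrative numbers: base LR tuned for one GPU, scaled for 4 GPUs.
base_lr, world_size = 3e-4, 4
scaled_lr = base_lr * world_size      # linear scaling rule
warmup_steps = 500                    # e.g. ~5% of a 10k-step run

def lr_at(step):
    """Ramp linearly from ~0 to scaled_lr over warmup_steps, then hold
    (a decay schedule would normally take over after warmup)."""
    return scaled_lr * min(1.0, (step + 1) / warmup_steps)

print(lr_at(0))      # tiny LR on the first step
print(lr_at(499))    # warmup done: full scaled LR
```

In PyTorch this shape is easy to express with `torch.optim.lr_scheduler.LambdaLR`.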

```python
from contextlib import nullcontext

# Gradient accumulation is essential when GPU memory limits batch size.
# Effective batch = batch_size_per_gpu * accumulation_steps * world_size
accumulation_steps = 8

for step, (data, target) in enumerate(loader):
    data, target = data.to(rank), target.to(rank)

    # DDP syncs gradients on every backward() by default.
    # Skip sync for intermediate accumulation steps (saves communication).
    context = model.no_sync() if (step + 1) % accumulation_steps != 0 else nullcontext()

    with context:
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            loss = model(data, target) / accumulation_steps  # Scale loss
        loss.backward()  # Gradients accumulate (not zeroed yet)

    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()  # Zero only after accumulation is complete
```
**Why `model.no_sync()` matters.** Without it, DDP all-reduces gradients on every `backward()` call -- even during intermediate accumulation steps where the gradients will be added to, not used. With 8 accumulation steps, this wastes 7 all-reduce operations per optimizer step. `model.no_sync()` disables the gradient hook for intermediate steps, performing a single all-reduce when the final accumulated gradient is ready. This can reduce communication overhead by up to $(\text{accumulation\_steps} - 1) / \text{accumulation\_steps}$ (87.5% for 8 steps).

Model Parallelism

When a model does not fit in a single GPU's memory, you must split the model itself across GPUs. Three complementary strategies do this -- tensor, pipeline, and sequence parallelism (data parallelism is included in the table for comparison):

| Strategy | What Is Split | Communication Pattern | Memory Savings | Compute Efficiency | When to Use |
|---|---|---|---|---|---|
| Data parallel (DDP) | Data batches | AllReduce gradients | None (full copy) | Near-linear scaling | Model fits in 1 GPU |
| Tensor parallel (TP) | Weight matrices (columns/rows) | AllReduce or AllGather activations | ~N× (N = TP degree) | High (if interconnect is fast) | Large layers, NVLink nodes |
| Pipeline parallel (PP) | Model layers | Point-to-point activations | ~N× (N = PP stages) | Lower (pipeline bubbles) | Many layers, across nodes |
| Sequence parallel (SP) | Sequence dimension | AllGather/ReduceScatter | Activation memory reduction | High | Long sequences, with TP |

```python
# Tensor parallelism splits weight matrices across GPUs.
# For a Linear layer Y = XW + b:

# Column parallel: split W along columns
# GPU 0: Y_0 = X @ W[:, :d//2]  (computes first half of output features)
# GPU 1: Y_1 = X @ W[:, d//2:]  (computes second half of output features)
# -> AllGather Y = [Y_0 | Y_1] to reconstruct full output

# Row parallel: split W along rows
# GPU 0: Y_0 = X_0 @ W[:d//2, :]  (each GPU has partial input)
# GPU 1: Y_1 = X_1 @ W[d//2:, :]
# -> AllReduce Y = Y_0 + Y_1 to combine partial sums

# In practice, column-parallel followed by row-parallel avoids
# an AllGather between layers (Megatron-LM pattern)
```
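The column-then-row pattern can be verified in a single process by treating each "GPU" as a slice of the weight matrices. This sketch uses plain Python lists as a stand-in for the distributed matmuls:

```python
# Single-process simulation of the Megatron column-then-row pattern.
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

X = [[1.0, 2.0]]                 # input, shape (1, 2)
W1 = [[1.0, 2.0, 3.0, 4.0],      # first linear, shape (2, 4)
      [5.0, 6.0, 7.0, 8.0]]
W2 = [[1.0, 0.0],                # second linear, shape (4, 2)
      [0.0, 1.0],
      [1.0, 1.0],
      [2.0, 0.0]]

# Column parallel on W1: each "GPU" keeps half the output features.
H0 = matmul(X, [row[:2] for row in W1])   # GPU 0's hidden shard
H1 = matmul(X, [row[2:] for row in W1])   # GPU 1's hidden shard

# Row parallel on W2: each GPU's hidden shard hits the matching rows.
# No AllGather was needed between the two layers.
Y0 = matmul(H0, W2[:2])
Y1 = matmul(H1, W2[2:])

# The single "AllReduce": sum the partial outputs.
Y = [[a + b for a, b in zip(Y0[0], Y1[0])]]

# Reference: full unsharded computation gives the same result.
Y_ref = matmul(matmul(X, W1), W2)
print(Y, Y_ref)  # both [[68.0, 31.0]]
```

The point of the simulation: two sharded layers required exactly one collective, where a column-parallel layer alone would have needed an AllGather before the next layer.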

```python
# Pipeline parallelism: split model into stages across GPUs
from torch.distributed.pipelining import SplitPoint, pipeline

# `input` is an example micro-batch tensor
pipe = pipeline(
    model,
    mb_args=(input,),
    split_spec={
        "layer12": SplitPoint.BEGINNING,  # Stage 0: layers 0-11 (GPU 0)
        "layer24": SplitPoint.BEGINNING,  # Stage 1: layers 12-23 (GPU 1)
    },                                    # Stage 2: layers 24+ (GPU 2)
)
```
**Pipeline bubbles.** Naive pipeline parallelism wastes $(P-1)/P$ of each GPU's time, where $P$ is the number of pipeline stages, because each GPU must wait for the previous stage's output. **Micro-batching** (splitting each mini-batch into $M$ micro-batches and pipelining them) reduces the bubble to $(P-1)/(P-1+M)$. With $M \gg P$, the bubble becomes negligible. Advanced schedules (1F1B, interleaved 1F1B, zero-bubble) further reduce idle time.
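The bubble formula is worth plugging numbers into:

```python
def bubble_fraction(P, M):
    """Idle fraction of a GPipe-style schedule with P pipeline stages
    and M micro-batches: (P - 1) / (P - 1 + M)."""
    return (P - 1) / (P - 1 + M)

print(bubble_fraction(4, 1))    # naive (M=1): 0.75 -- 3/4 of time idle
print(bubble_fraction(4, 16))   # ~0.16
print(bubble_fraction(4, 64))   # ~0.045 -- bubble nearly gone
```

This is why pipeline parallelism only pays off with a large global batch: you need enough micro-batches to keep all stages busy.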

FSDP (Fully Sharded Data Parallel)

FSDP extends DDP by sharding not just gradients but also parameters and optimizer states across GPUs. Each GPU stores only a shard of the model. Before each layer's forward/backward, the full parameters are gathered; after, they are discarded:


```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # Parameters gathered in BF16
    reduce_dtype=torch.bfloat16,  # Gradient reduction in BF16
    buffer_dtype=torch.bfloat16,  # Buffers (e.g., BatchNorm running stats)
)

model = FSDP(
    model,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard everything
    use_orig_params=True,  # Required for torch.compile compatibility
    # auto_wrap_policy: controls which submodules get their own FSDP wrapper.
    # Each wrapped module gathers/discards parameters independently.
)
```
For a 7B-parameter model trained in FP32 (4 bytes per parameter), the sharding math works out as:

| Component | DDP (1 GPU) | FSDP (4 GPUs) | FSDP (8 GPUs) |
|---|---|---|---|
| Parameters | 28 GB | 7 GB | 3.5 GB |
| Gradients | 28 GB | 7 GB | 3.5 GB |
| Optimizer states (AdamW: m, v) | 56 GB | 14 GB | 7 GB |
| Total per GPU | 112 GB | 28 GB | 14 GB |
| Communication overhead | AllReduce gradients | AllGather params (2x per layer: fwd + bwd) | Same, more shards |
**When to use FSDP vs DDP.** Use DDP when the model fits comfortably in one GPU's memory (including optimizer states and activations). Switch to FSDP when you run out of memory. FSDP introduces additional AllGather communication (parameters must be gathered before each forward and backward), so it is slower than DDP for models that fit in memory. The communication overhead is typically 5-15% compared to DDP.
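A back-of-envelope helper makes the decision concrete. It uses the same accounting as the table above (FP32 params, grads, and AdamW m/v at 4 + 4 + 8 bytes per parameter) and deliberately ignores activations, which depend on batch size and architecture:

```python
def training_memory_gb(params_billion, world_size=1, shard=False):
    """Rough per-GPU memory for parameters (4 B), gradients (4 B), and
    AdamW m/v (8 B) per parameter, activations excluded. shard=True
    models FSDP FULL_SHARD, which divides all three across ranks."""
    total = params_billion * (4 + 4 + 8)   # 16 bytes/param -> GB per 1e9
    return total / world_size if shard else total

print(training_memory_gb(7))                            # DDP: 112.0 GB
print(training_memory_gb(7, world_size=8, shard=True))  # FSDP/8: 14.0 GB
```

If the DDP number (plus an activation budget) fits in one GPU, stay with DDP; otherwise increase `world_size` under `shard=True` until it fits.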

DeepSpeed ZeRO

DeepSpeed's ZeRO (Zero Redundancy Optimizer) provides three progressive stages of memory optimization, equivalent to FSDP but with a different API and additional features:

| Stage | What Is Sharded | Memory per GPU | Communication |
|---|---|---|---|
| ZeRO-1 | Optimizer states only | ~4× reduction | Same as DDP (AllReduce gradients) |
| ZeRO-2 | Optimizer states + gradients | ~8× reduction | ReduceScatter gradients |
| ZeRO-3 | Optimizer states + gradients + parameters | ~N× reduction | AllGather params (like FSDP) |
| ZeRO-Offload | + CPU/NVMe offloading | Train 10× larger models | CPU-GPU transfer overhead |
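The per-stage reductions follow directly from the ZeRO paper's mixed-precision accounting (2 bytes of FP16 params + 2 bytes of FP16 grads + 12 bytes of optimizer state per parameter), which this sketch reproduces:

```python
def zero_memory_gb(params_billion, n_gpus, stage):
    """Per-GPU GB under ZeRO: 2 B FP16 params + 2 B FP16 grads + 12 B
    optimizer state (FP32 master copy, m, v) per parameter. Stage 1
    shards optimizer state, stage 2 adds gradients, stage 3 adds
    parameters; stage 0 means plain data parallelism."""
    p = params_billion * 2 / (n_gpus if stage >= 3 else 1)
    g = params_billion * 2 / (n_gpus if stage >= 2 else 1)
    o = params_billion * 12 / (n_gpus if stage >= 1 else 1)
    return p + g + o

# The ZeRO paper's running example: 7.5B parameters on 64 GPUs.
for stage in (0, 1, 2, 3):
    print(stage, round(zero_memory_gb(7.5, 64, stage), 1))
# 0 -> 120.0, 1 -> 31.4, 2 -> 16.6, 3 -> 1.9
```

Note how stage 1 already delivers most of the win: optimizer state dominates (12 of the 16 bytes per parameter), which is why the table shows ~4× for ZeRO-1 alone.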

DeepSpeed configuration (`ds_config.json`):

```json
{
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu", "pin_memory": true},
    "offload_optimizer": {"device": "cpu", "pin_memory": true},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "prefetch_bucket_size": 5e7,
    "param_persistence_threshold": 1e5
  },
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "gradient_clipping": 1.0
}
```
**FSDP vs DeepSpeed ZeRO-3.** Both achieve the same memory savings through parameter sharding. Key differences:

- **FSDP** is native PyTorch (no external dependency), integrates better with `torch.compile`, and is the recommended approach for new projects.
- **DeepSpeed** offers CPU/NVMe offloading (ZeRO-Offload/Infinity), is more mature for very large models (100B+), and has additional features like sparse attention and compression.
- For models up to ~13B on 8 GPUs, FSDP is sufficient. For 70B+ models, combine FSDP or DeepSpeed with tensor parallelism.

3D Parallelism

For the largest models (70B+), you combine all three strategies:


```text
Example: 70B model on 64 GPUs (8 nodes, 8 GPUs each)

Tensor parallel (TP=8): Split each layer across 8 GPUs within a node (NVLink)
Pipeline parallel (PP=4): Split 80 layers into 4 stages across 4 groups
Data parallel (DP=2): 2 replicas of the full pipeline, each on 32 GPUs

Total: TP(8) x PP(4) x DP(2) = 64 GPUs

Communication:
- TP: AllReduce within node (NVLink, ~900 GB/s) -> fast
- PP: Point-to-point between nodes (InfiniBand, ~400 Gb/s) -> moderate
- DP: AllReduce across DP replicas (InfiniBand) -> infrequent (once per step)
```
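One conventional way to map global ranks onto this 3D grid is to let TP vary fastest, so each TP group stays inside one NVLink node. This is a sketch of that assumption only; real frameworks (Megatron-LM, DeepSpeed) make the ordering configurable:

```python
# Assumed layout: TP varies fastest (keeps each 8-way TP group inside
# one node), then PP, then DP.
TP, PP, DP = 8, 4, 2
assert TP * PP * DP == 64   # matches the 64-GPU example above

def rank_to_coords(rank):
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return tp, pp, dp

print(rank_to_coords(0))    # (0, 0, 0)
print(rank_to_coords(7))    # (7, 0, 0) -- same node/TP group as rank 0
print(rank_to_coords(8))    # (0, 1, 0) -- next pipeline stage
print(rank_to_coords(32))   # (0, 0, 1) -- second data-parallel replica
```

With this ordering, the chatty TP collectives never cross a node boundary, while PP and DP traffic rides the slower inter-node fabric.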
| Model Size | GPUs Available | Recommended Strategy |
|---|---|---|
| < 1B | 1-8 | DDP |
| 1-13B | 4-8 | FSDP or ZeRO-3 |
| 13-70B | 8-64 | FSDP + TP (within node) |
| 70B+ | 64-512 | TP + PP + DP (3D parallelism) |
| 200B+ | 512+ | 3D parallelism + sequence parallelism + expert parallelism |