Mixed Precision Training
Modern GPUs have specialized hardware (Tensor Cores) that compute matrix multiplications 2-8x faster in reduced precision (FP16, BF16, FP8) than in FP32. Mixed precision training exploits this by running the forward and backward passes in reduced precision while keeping a master copy of the weights in FP32 for accurate gradient accumulation. This gives 1.5-2x speedup and nearly 2x memory savings with virtually no accuracy loss. Mixed precision should be the default for all GPU training.
Number Formats
A floating-point format's behavior is determined by how its bits are split between the exponent (dynamic range) and the mantissa (precision):
| Format | Bits | Exponent | Mantissa | Dynamic Range | Precision (decimal digits) | Primary Use |
|---|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | ±3.4e38 | ~7.2 | Master weights, optimizer states |
| TF32 | 19 | 8 | 10 | ±3.4e38 | ~3.3 | Automatic on Ampere+ for matmul |
| FP16 | 16 | 5 | 10 | ±65504 | ~3.3 | Mixed precision (older GPUs) |
| BF16 | 16 | 8 | 7 | ±3.4e38 | ~2.4 | Mixed precision (Ampere+) |
| FP8 E4M3 | 8 | 4 | 3 | ±448 | ~1.0 | Forward pass (Hopper+) |
| FP8 E5M2 | 8 | 5 | 2 | ±57344 | ~0.6 | Backward pass (Hopper+) |
```
FP32: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM   (1 + 8 + 23 = 32 bits)
BF16: S EEEEEEEE MMMMMMM                   (1 + 8 + 7  = 16 bits)
FP16: S EEEEE    MMMMMMMMMM                (1 + 5 + 10 = 16 bits)
```

BF16 keeps FP32's full 8-bit exponent, so it covers the same dynamic range; FP16 instead spends those bits on a longer mantissa, trading range for precision.
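The effect of the split can be checked directly: converting FP32 to BF16 is (up to rounding) just dropping the low 16 bits of the FP32 encoding. A minimal pure-Python sketch of that truncation:

```python
import struct

def bf16_truncate(x: float) -> float:
    # BF16 keeps FP32's sign and 8-bit exponent but only the top 7
    # mantissa bits, so conversion is (ignoring rounding) zeroing the
    # low 16 bits of the FP32 bit pattern.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

print(bf16_truncate(3.14159265))  # 3.140625: only ~2-3 decimal digits survive
print(bf16_truncate(1e38))        # still finite: BF16 inherits FP32's range
```

A real BF16 cast rounds to nearest rather than truncating, but the order of precision loss is the same.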
torch.amp (Automatic Mixed Precision)
PyTorch's autocast context manager automatically selects the precision for each operation. Operations that benefit from Tensor Cores (matmul, conv) run in reduced precision; operations that need accuracy (softmax, layernorm, loss) run in FP32:
```python
import torch
from torch.amp import autocast

# BF16 (preferred on Ampere+; no loss scaling needed)
for data, target in loader:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    # Backward runs outside autocast; gradients take the dtypes
    # the corresponding forward ops ran in
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```
```python
import torch
from torch.amp import autocast, GradScaler

# FP16 (requires GradScaler for loss scaling)
scaler = GradScaler()
for data, target in loader:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # Scale loss to prevent gradient underflow
    scaler.unscale_(optimizer)     # Unscale gradients so clipping sees true values
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)         # Skips the step if inf/nan gradients are found
    scaler.update()                # Adjust the scale factor
```
Loss Scaling (FP16 Only)
FP16 has limited dynamic range: gradients below about 2^-24 (~6e-8, the smallest FP16 subnormal) underflow to zero, causing training to stall or diverge. Loss scaling multiplies the loss by a large factor before backward, effectively shifting the gradient distribution into FP16's representable range:
```
Forward pass:
  Weights (FP32 master) ──cast──> FP16/BF16
  Input (FP16/BF16) ──matmul──> Activations (FP16/BF16)   [Tensor Cores]
  Activations ──softmax/norm──> Output (FP32)             [precision-sensitive]
  Output ──loss──> Loss (FP32)

Loss scaling (FP16 only):
  Loss (FP32) ──multiply by scale──> Scaled loss (FP32)

Backward pass:
  Scaled loss ──backward──> Scaled gradients (FP16)       [Tensor Cores]

Unscale and update:
  Scaled gradients ──divide by scale──> Gradients (FP32)
  If inf/nan detected: skip update, reduce scale factor
  If no overflow:      update FP32 master weights, increase scale factor
```
This means the first few steps may have skipped updates while the scaler searches for the right range. Monitor `scaler.get_scale()`: if it keeps decreasing, the model has numerical stability issues independent of precision.
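The underflow problem and its fix can be reproduced on the CPU by round-tripping values through IEEE half precision (a sketch using Python's `struct` `'e'` format; 2^16 = 65536 is GradScaler's default initial scale):

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip a Python float through IEEE half precision
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8                 # a tiny but meaningful gradient value
print(to_fp16(grad))        # 0.0: underflows below 2**-24 ~ 6e-8

scale = 2.0 ** 16           # 65536, GradScaler's default initial scale
scaled = to_fp16(grad * scale)
print(scaled)               # nonzero (~6.55e-4): representable after scaling
print(scaled / scale)       # dividing in FP32 recovers the gradient
```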
What Runs in Which Precision
autocast maintains an internal list of operations and their target precision. Understanding this prevents debugging surprises:
| Operation | Precision under autocast | Reason |
|---|---|---|
| `torch.mm`, `torch.matmul` | FP16/BF16 | Compute-bound; Tensor Cores give 2-8x speedup |
| `torch.nn.Linear` | FP16/BF16 | Internally a matmul |
| `torch.nn.Conv1d/2d/3d` | FP16/BF16 | Compute-bound; Tensor Cores |
| Multi-head attention | FP16/BF16 | Q@K^T and attn@V are matmuls |
| `torch.nn.LayerNorm` | FP32 | Reduction over features; variance needs precision |
| `torch.nn.BatchNorm` | FP32 | Running mean/variance accumulation |
| `torch.softmax` | FP32 | Exponentiation can overflow in FP16 |
| `torch.nn.CrossEntropyLoss` | FP32 | Log-sum-exp needs precision |
| `torch.log`, `torch.exp` | FP32 | Transcendental functions; overflow risk |
| Optimizer step | FP32 | Small updates relative to weight magnitude would vanish in FP16 |
| Master weights | FP32 | Accumulating many small updates needs full precision |
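The softmax entry can be illustrated on the CPU: exponentials overflow FP16 long before FP32. A sketch using Python's half-precision `struct` format (the logit value 12 is an arbitrary illustrative choice):

```python
import math
import struct

# exp(12) ~ 162755, beyond FP16's largest finite value of 65504
try:
    struct.pack("<e", math.exp(12.0))
except (OverflowError, struct.error):
    print("exp(12) does not fit in FP16")

# FP32 softmax implementations subtract the max logit first
# (the log-sum-exp trick), which keeps every exponent <= 0:
logits = [5.0, 12.0, 3.0]
m = max(logits)
exps = [math.exp(v - m) for v in logits]
probs = [e / sum(exps) for e in exps]
print(probs)  # dominated by the 12.0 logit; sums to 1
```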
Performance and Memory Impact
| Metric | FP32 Baseline | Mixed Precision (BF16) | Improvement |
|---|---|---|---|
| Compute-bound ops (matmul, conv) | 1x | 2-8x (Tensor Cores) | Depends on GPU generation |
| Memory-bound ops (activation, elementwise) | 1x | ~2x (half the bytes transferred) | Memory bandwidth limited |
| Parameter memory | 4 bytes/param | 2 bytes/param (+ 4-byte FP32 master) | ~1.5x increase (small next to activations) |
| Activation memory | 4 bytes/element | 2 bytes/element | 2x reduction |
| Gradient memory | 4 bytes/param | 2 bytes/param | 2x reduction |
| Optimizer state (AdamW) | 8 bytes/param (m + v) | 8 bytes/param (still FP32) | No change |
| Net training throughput | 1x | 1.5-2x | Typical end-to-end |
| Net GPU memory | 1x | ~0.6-0.7x | Allows larger batch or model |
Peak dense throughput by GPU generation:

| GPU | FP32 (TFLOPS) | TF32 (TFLOPS) | FP16/BF16 (TFLOPS) | FP8 (TFLOPS) | INT8 (TOPS) |
|---|---|---|---|---|---|
| V100 | 15.7 | N/A | 125 | N/A | N/A |
| A100 | 19.5 | 156 | 312 | N/A | 624 |
| H100 | 66.9 | 495 | 990 | 1979 | 1979 |
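The memory rows above collapse into a rough per-parameter calculator (a sketch; assumes AdamW with FP32 m and v, an FP32 master copy under BF16, gradients kept in the training dtype, and activations excluded):

```python
def training_bytes_per_param(bf16: bool = True) -> int:
    # Rough per-parameter accounting; activations are excluded
    weights = 2 if bf16 else 4   # BF16 working copy vs plain FP32 weights
    master  = 4 if bf16 else 0   # FP32 master kept alongside BF16 weights
    grads   = 2 if bf16 else 4   # gradients in the training dtype
    optim   = 8                  # AdamW m + v, both FP32
    return weights + master + grads + optim

print(training_bytes_per_param(bf16=True))   # 16
print(training_bytes_per_param(bf16=False))  # 16
```

Both cases come out to 16 bytes/param: the parameter-side footprint barely changes, which is why the net memory saving in the table comes almost entirely from halved activation memory and the larger batches it allows.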
For new training code, start with BF16 on Ampere+ GPUs, or FP16 with GradScaler on older hardware.
FP8 Training (Hopper+)
NVIDIA H100 GPUs introduce FP8 Tensor Cores with 2x throughput over FP16. FP8 training uses two formats, E4M3 for the forward pass and E5M2 for gradients, and is typically accessed through NVIDIA's Transformer Engine library:
```python
# FP8 training with NVIDIA Transformer Engine
import transformer_engine.pytorch as te

# Replace nn.Linear with te.Linear for FP8 support;
# te.Linear automatically handles FP8 casting with per-tensor scaling
model = MyModel()
model.ffn = te.Linear(hidden_dim, 4 * hidden_dim)

# The forward pass runs under fp8_autocast
with te.fp8_autocast(enabled=True):
    output = model(input)
    loss = criterion(output, target)
loss.backward()
optimizer.step()
```
FP8 relies on per-tensor scaling factors maintained by Transformer Engine (not a single loss scale):

- E4M3 (4-bit exponent, 3-bit mantissa): forward activations and weights. Higher precision (8 mantissa steps per power of two), smaller range (max ±448).
- E5M2 (5-bit exponent, 2-bit mantissa): gradients in the backward pass. Lower precision (4 mantissa steps), larger range (max ±57344).
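Both maxima follow directly from the bit layouts. A small sketch computing the largest finite value of each format (assuming the OCP FP8 convention: E4M3 reserves only the all-ones encoding for NaN, while E5M2 handles inf/NaN IEEE-style):

```python
def fp8_max_finite(exp_bits: int, man_bits: int, ieee_like: bool) -> float:
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # E5M2-style: the top exponent code is reserved for inf/NaN
        max_exp = (2 ** exp_bits - 2) - bias
        max_mantissa = 2 - 2.0 ** (-man_bits)
    else:
        # E4M3-style: only the all-ones encoding is NaN, so the top
        # exponent is usable with every mantissa pattern but all-ones
        max_exp = (2 ** exp_bits - 1) - bias
        max_mantissa = 2 - 2.0 ** (-(man_bits - 1))
    return max_mantissa * 2.0 ** max_exp

print(fp8_max_finite(4, 3, ieee_like=False))  # 448.0 (E4M3)
print(fp8_max_finite(5, 2, ieee_like=True))   # 57344.0 (E5M2)
```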