Mixed Precision Training

Modern GPUs have specialized hardware (Tensor Cores) that compute matrix multiplications 2-8x faster in reduced precision (FP16, BF16, FP8) than in FP32. Mixed precision training exploits this by running the forward and backward passes in reduced precision while keeping a master copy of the weights in FP32 for accurate gradient accumulation. This gives 1.5-2x speedup and nearly 2x memory savings with virtually no accuracy loss. Mixed precision should be the default for all GPU training.

Number Formats

The precision of a floating-point number is determined by its bit allocation between exponent (range) and mantissa (precision):

| Format | Bits | Exponent | Mantissa | Dynamic Range | Precision (decimal digits) | Primary Use |
|---|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | $\pm 3.4 \times 10^{38}$ | ~7.2 | Master weights, optimizer states |
| TF32 | 19 | 8 | 10 | $\pm 3.4 \times 10^{38}$ | ~3.3 | Automatic on Ampere+ for matmul |
| FP16 | 16 | 5 | 10 | $\pm 6.5 \times 10^{4}$ | ~3.3 | Mixed precision (older GPUs) |
| BF16 | 16 | 8 | 7 | $\pm 3.4 \times 10^{38}$ | ~2.4 | Mixed precision (Ampere+) |
| FP8 E4M3 | 8 | 4 | 3 | $\pm 448$ | ~1.0 | Forward pass (Hopper+) |
| FP8 E5M2 | 8 | 5 | 2 | $\pm 57344$ | ~0.6 | Backward pass (Hopper+) |

**BF16 vs FP16: why BF16 is preferred for training.** BF16 has the same 8-bit exponent as FP32, giving it the same dynamic range ($10^{\pm 38}$). This means gradients that are representable in FP32 are almost always representable in BF16 (with less precision, but without underflow). FP16 has only a 5-bit exponent, so gradients smaller than $6 \times 10^{-8}$ underflow to zero -- this is common in deep networks with many layers. FP16 requires loss scaling to work around this; BF16 does not.
```
FP32: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM   (1 + 8 + 23 = 32 bits)
BF16: S EEEEEEEE MMMMMMM                   (1 + 8 + 7  = 16 bits)
FP16: S EEEEE MMMMMMMMMM                   (1 + 5 + 10 = 16 bits)
        ^^^^^^^^              ^^^^^^^^^^
        BF16: same exponent   FP16: more mantissa
        as FP32               bits than BF16
```
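The range difference is easy to see numerically. A minimal stdlib sketch: `to_fp16` round-trips through IEEE half precision, and `to_bf16` emulates BF16 by truncating FP32 bits (real hardware rounds, but truncation is enough to show the effect):

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE half precision ('e' format)."""
    return struct.unpack('e', struct.pack('e', x))[0]

def to_bf16(x):
    """Emulate BF16 by keeping only the top 16 bits of the FP32 pattern."""
    (bits,) = struct.unpack('I', struct.pack('f', x))
    return struct.unpack('f', struct.pack('I', bits & 0xFFFF0000))[0]

grad = 1e-8                   # a typical tiny gradient in a deep network
print(to_fp16(grad))          # 0.0 -- underflows below FP16's ~6e-8 floor
print(to_bf16(grad))          # ~1e-8 -- BF16 keeps FP32's exponent range
print(to_fp16(grad * 65536))  # nonzero -- loss scaling shifts it into range
```

The last line previews why FP16 needs loss scaling: multiplying before the backward pass moves small gradients into the representable range.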

torch.amp (Automatic Mixed Precision)

PyTorch's autocast context manager automatically selects the precision for each operation. Operations that benefit from Tensor Cores (matmul, conv) run in reduced precision; operations that need accuracy (softmax, layernorm, loss) run in FP32:


```python
from torch.amp import autocast

# BF16 (preferred on Ampere+, no loss scaling needed)
for data, target in loader:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    # Backward runs outside autocast but uses BF16 for gradients
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```

```python
from torch.amp import autocast, GradScaler

# FP16 (requires GradScaler for loss scaling)
scaler = GradScaler()
for data, target in loader:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()  # Scale loss to prevent gradient underflow
    scaler.unscale_(optimizer)     # Unscale gradients for clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)         # Step only if no inf/nan gradients
    scaler.update()                # Adjust scale factor
```

Loss Scaling (FP16 Only)

FP16 has limited dynamic range. Small gradients ($< 6 \times 10^{-8}$) underflow to zero, causing training to stall or diverge. Loss scaling multiplies the loss by a large factor before backward, effectively shifting the gradient distribution into FP16's representable range:


```
Forward pass:
  Weights (FP32 master) ──cast──> FP16/BF16
  Input (FP16/BF16) ──matmul──> Activations (FP16/BF16)   [Tensor Cores]
  Activations ──softmax/norm──> Output (FP32)             [precision-sensitive]
  Output ──loss──> Loss (FP32)

Loss scaling (FP16 only):
  Loss (FP32) ──multiply by scale──> Scaled loss (FP32)

Backward pass:
  Scaled loss ──backward──> Scaled gradients (FP16)       [Tensor Cores]

Unscale and update:
  Scaled gradients ──divide by scale──> Gradients (FP32)
  If inf/nan detected: skip update, reduce scale factor
  If no overflow: update FP32 master weights, increase scale factor
```
**GradScaler dynamics.** The scaler starts with a large scale factor (default 65536) and adjusts dynamically:

- Every N consecutive steps without overflow (2000 by default): multiply the scale by 2 -- a more aggressive scale preserves more precision
- On any step with an overflow/NaN: halve the scale and skip the weight update

This means the first few steps may have skipped updates while the scaler finds the right range. Monitor `scaler.get_scale()` -- if it keeps decreasing, the model has numerical stability issues independent of precision.
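The bookkeeping above can be sketched in a few lines (a simplified model, not GradScaler's actual implementation; the defaults mirror PyTorch's documented `init_scale`, `growth_factor`, `backoff_factor`, and `growth_interval`):

```python
class ScaleTracker:
    """Simplified sketch of GradScaler's dynamic scale adjustment."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def update(self, found_inf):
        if found_inf:
            # Overflow: back off and reset the streak (the weight update
            # for this step is skipped by the caller).
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor  # try a larger scale
                self._good_steps = 0

tracker = ScaleTracker()
tracker.update(found_inf=True)
print(tracker.scale)  # 32768.0 -- halved after the overflow
```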

What Runs in Which Precision

`autocast` maintains an internal list of operations and their target precision. Understanding this prevents debugging surprises:

| Operation | Precision under autocast | Reason |
|---|---|---|
| `torch.mm`, `torch.matmul` | FP16/BF16 | Compute-bound, Tensor Cores give 2-8x speedup |
| `torch.nn.Linear` | FP16/BF16 | Internally a matmul |
| `torch.nn.Conv1d/2d/3d` | FP16/BF16 | Compute-bound, Tensor Cores |
| Multi-head attention | FP16/BF16 | `Q@K^T` and `attn@V` are matmuls |
| `torch.nn.LayerNorm` | FP32 | Reduction over features, variance needs precision |
| `torch.nn.BatchNorm` | FP32 | Running mean/variance accumulation |
| `torch.softmax` | FP32 | Exponentiation can overflow in FP16 |
| `torch.nn.CrossEntropyLoss` | FP32 | Log-sum-exp needs precision |
| `torch.log`, `torch.exp` | FP32 | Transcendental functions, overflow risk |
| Optimizer step | FP32 | Small updates ($\text{lr} \times \text{grad}$) would vanish in FP16 |
| Master weights | FP32 | Accumulation of many small updates needs full precision |

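The softmax row is easy to demonstrate with NumPy's `float16` standing in for GPU FP16 arithmetic (a sketch of the failure mode; under `autocast` this dispatch happens for you):

```python
import numpy as np

logits = np.array([12.0, 11.0, 0.0], dtype=np.float16)

# Naive FP16 softmax numerator: exp(12) ~ 162755 > 65504 (FP16 max)
naive = np.exp(logits)
print(naive[0])  # inf -- the exponentiation overflows in FP16

# The stable FP32 form: subtract the max before exponentiating
shifted = logits.astype(np.float32) - logits.astype(np.float32).max()
stable = np.exp(shifted) / np.exp(shifted).sum()
print(stable.sum())  # finite, and the probabilities sum to 1
```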
**TF32: the invisible mixed precision.** On Ampere+ GPUs, cuBLAS and cuDNN automatically use TF32 for FP32 matmul and convolution operations. TF32 uses FP32 range (8-bit exponent) but only 10-bit mantissa, giving 8x throughput on Tensor Cores compared to full FP32. This happens by default -- you do not need `autocast`. To disable it (for debugging): `torch.backends.cuda.matmul.allow_tf32 = False`.
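TF32's precision loss can be emulated with a stdlib sketch (truncation rather than the hardware's rounding, but enough to see which mantissa bits survive):

```python
import struct

def to_tf32(x):
    """Emulate TF32 by truncating an FP32 mantissa from 23 to 10 bits."""
    (bits,) = struct.unpack('I', struct.pack('f', x))
    # Zero the low 13 mantissa bits, keeping sign, 8-bit exponent, 10 bits
    return struct.unpack('f', struct.pack('I', bits & 0xFFFFE000))[0]

print(to_tf32(1.0 + 2**-10))  # 1.0009765625 -- the 10th mantissa bit survives
print(to_tf32(1.0 + 2**-11))  # 1.0 -- the 11th bit is truncated away
```

This is why TF32 is safe for matmul inputs but would be a poor format for accumulating weight updates: differences below about one part in a thousand vanish.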

Performance and Memory Impact

| Metric | FP32 Baseline | Mixed Precision (BF16) | Notes |
|---|---|---|---|
| Compute-bound ops (matmul, conv) | 1x | 2-8x (Tensor Cores) | Depends on GPU generation |
| Memory-bound ops (activation, elementwise) | 1x | ~2x (half the bytes transferred) | Memory bandwidth limited |
| Parameter memory | 4 bytes/param | 2 bytes/param (+ 4 bytes master) | ~1.5x increase (extra reduced-precision copy) |
| Activation memory | 4 bytes/element | 2 bytes/element | 2x reduction |
| Gradient memory | 4 bytes/param | 2 bytes/param | 2x reduction |
| Optimizer state (AdamW) | 8 bytes/param (m + v) | 8 bytes/param (still FP32) | No change |
| Net training throughput | 1x | 1.5-2x | Typical end-to-end |
| Net GPU memory | 1x | ~0.6-0.7x | Allows larger batch or model |

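A back-of-envelope check of the per-parameter rows above (a simplified accounting; `param_state_bytes` is a hypothetical helper using the byte counts from the table):

```python
def param_state_bytes(n_params, mixed_precision):
    """Per-parameter training state for AdamW, per the table's byte counts."""
    if mixed_precision:
        # BF16 weights + FP32 master + BF16 grads + FP32 AdamW m and v
        return n_params * (2 + 4 + 2 + 8)
    # FP32 weights + FP32 grads + FP32 AdamW m and v
    return n_params * (4 + 4 + 8)

n = 7_000_000_000  # e.g. a 7B-parameter model
print(param_state_bytes(n, True) / 2**30)   # ~104 GiB
print(param_state_bytes(n, False) / 2**30)  # ~104 GiB -- identical
```

Note that per-parameter state comes out the same either way: the ~0.6-0.7x net memory figure comes almost entirely from halved activation memory, which is why the savings grow with batch size and sequence length.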
Peak dense throughput by GPU generation:

| GPU | FP32 (TFLOPS) | TF32 (TFLOPS) | FP16/BF16 (TFLOPS) | FP8 (TFLOPS) | INT8 (TOPS) |
|---|---|---|---|---|---|
| V100 | 15.7 | N/A | 125 | N/A | N/A |
| A100 | 19.5 | 156 | 312 | N/A | 624 |
| H100 | 66.9 | 495 | 990 | 1979 | 1979 |

**Mixed precision should be the default.** There is almost no reason to train in pure FP32 on modern GPUs. The only exceptions are:

1. **Debugging numerical issues** -- temporarily disable mixed precision to isolate precision-related bugs
2. **Very small models** where Tensor Core overhead dominates the small matrix multiplications
3. **Specific operations** that are known to be numerically sensitive (usually handled by autocast's FP32 fallback list)

For new training code, start with BF16 on Ampere+ GPUs, or FP16 with GradScaler on older hardware.

FP8 Training (Hopper+)

NVIDIA H100 GPUs introduce FP8 Tensor Cores with 2x throughput over FP16. FP8 training uses two formats:


```python
# FP8 training with NVIDIA Transformer Engine
import transformer_engine.pytorch as te

# Replace nn.Linear with te.Linear for FP8 support
model = MyModel()
# te.Linear automatically handles FP8 casting with per-tensor scaling
model.ffn = te.Linear(hidden_dim, 4 * hidden_dim)

# Training loop uses fp8_autocast
with te.fp8_autocast(enabled=True):
    output = model(input)
    loss = criterion(output, target)
loss.backward()
optimizer.step()
```

FP8 uses per-tensor scaling factors (not loss scaling):

- **E4M3** (4-bit exponent, 3-bit mantissa): forward activations and weights -- higher precision (8 mantissa values per binade), smaller range ($\pm 448$)
- **E5M2** (5-bit exponent, 2-bit mantissa): gradients in the backward pass -- lower precision (4 mantissa values per binade), larger range ($\pm 57344$)
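The two ranges follow directly from the bit layouts, under the OCP FP8 convention that Transformer Engine uses (E5M2 reserves its whole top exponent code for inf/NaN; E4M3 reserves only the all-ones mantissa there):

```python
# E5M2: bias 15; largest finite value has exponent 15, mantissa 11
e5m2_max = (2 - 2**-2) * 2**15   # 1.75 * 32768 = 57344.0

# E4M3: bias 7; exponent=1111 with mantissa=110 is still a number,
# so the largest finite value has exponent 8, mantissa 110
e4m3_max = (1 + 6/8) * 2**8      # 1.75 * 256 = 448.0

print(e4m3_max, e5m2_max)  # 448.0 57344.0
```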