Training Loop

The training loop is where everything comes together: data loading, forward pass, loss computation, backward pass, and optimizer step. This chapter presents a production-quality training loop with mixed precision, gradient clipping, learning rate scheduling, and checkpointing -- and explains why each component is there.

Complete Training Loop


```python
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast

def train(model, train_loader, val_loader, epochs, lr=3e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Optimizer: AdamW is the standard for transformers
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=0.01,   # L2 regularization (decoupled from LR)
        betas=(0.9, 0.999),  # Momentum and RMSProp coefficients
        eps=1e-8,            # Numerical stability
    )

    # Learning rate schedule: warmup + cosine decay
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)  # 10% warmup
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
        anneal_strategy='cos',
    )

    # Mixed precision: GradScaler for FP16 (not needed for BF16)
    use_bf16 = torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
    scaler = GradScaler(enabled=not use_bf16)  # Only needed for FP16

    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    global_step = 0

    for epoch in range(epochs):
        # ===== Training Phase =====
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # 1. Zero gradients
            optimizer.zero_grad(set_to_none=True)  # set_to_none=True saves memory

            # 2. Forward pass with mixed precision
            with autocast(device_type='cuda', dtype=amp_dtype):
                output = model(data)
                loss = criterion(output, target)

            # 3. Backward pass (scaled for FP16 stability)
            scaler.scale(loss).backward()

            # 4. Gradient clipping (unscale first, then clip, then step)
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=1.0
            )

            # 5. Optimizer step (scaler skips step if gradients contain Inf/NaN)
            scaler.step(optimizer)
            scaler.update()

            # 6. Learning rate step (per-iteration, not per-epoch)
            scheduler.step()

            # Logging
            total_loss += loss.item()
            num_batches += 1
            global_step += 1

            if global_step % 100 == 0:
                avg_loss = total_loss / num_batches
                current_lr = scheduler.get_last_lr()[0]
                print(f"Step {global_step} | Loss: {avg_loss:.4f} | "
                      f"Grad Norm: {grad_norm:.2f} | LR: {current_lr:.2e}")

        # ===== Validation Phase =====
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, amp_dtype)

        print(f"\nEpoch {epoch+1}/{epochs} | "
              f"Train Loss: {total_loss/num_batches:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2%}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, scheduler, epoch, 'best_model.pt')

    return model
```

**The order of operations matters.** Each step in the training loop has a specific reason:

```text
zero_grad()       # Must be FIRST: clear stale gradients from the previous step
forward()         # Compute predictions (inside autocast for mixed precision)
loss.backward()   # Compute gradients (inside scaler.scale for FP16)
unscale_()        # Convert gradients back to FP32 (for gradient clipping)
clip_grad()       # Clip AFTER unscaling, BEFORE the optimizer step
step()            # Update weights (scaler.step checks for Inf/NaN)
scaler.update()   # Adjust the loss scale for the next iteration
scheduler.step()  # Update the learning rate
```

Changing the order (e.g., clipping before unscaling, or scheduling before stepping) will produce incorrect results.

Evaluation


```python
@torch.no_grad()
def evaluate(model, loader, criterion, device, amp_dtype=torch.bfloat16):
    """Evaluate model on a dataset. Returns (loss, accuracy)."""
    model.eval()  # CRITICAL: disables dropout and batch norm training mode
    total_loss = 0.0
    correct = 0
    total = 0

    for data, target in loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.autocast(device_type='cuda', dtype=amp_dtype):
            output = model(data)
            loss = criterion(output, target)

        total_loss += loss.item() * target.size(0)  # Weight by batch size
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)

    model.train()  # Restore training mode

    return total_loss / total, correct / total
```

**Common evaluation mistakes:**

| Mistake | Consequence | Fix |
|---|---|---|
| Forgetting `model.eval()` | Dropout active, BN uses batch stats | Add `model.eval()` before validation |
| Forgetting `torch.no_grad()` | Wastes memory storing activations | Wrap eval in `@torch.no_grad()` |
| Averaging loss per batch (not per sample) | Wrong average if last batch is smaller | Weight by batch size: `loss * batch_size` |
| Not restoring `model.train()` after eval | Dropout stays off, BN uses running stats | Add `model.train()` after validation |
| Reporting train loss as validation loss | Misleading metrics | Use separate data loader |
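
The per-sample vs. per-batch averaging mistake is easy to see with concrete numbers. An illustrative sketch (batch sizes and losses below are made up):

```python
# Why averaging per-batch means is wrong when the last batch is smaller.
batch_sizes = [32, 32, 32, 4]        # last batch is smaller
batch_losses = [2.0, 2.0, 2.0, 8.0]  # mean loss per batch (last is an outlier)

# Naive: average of per-batch means -- overweights the 4-sample batch
naive = sum(batch_losses) / len(batch_losses)

# Correct: weight each batch mean by its size, divide by total samples
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(naive)     # 3.5
print(weighted)  # 2.24
```

The naive average reports a loss 56% higher than the true per-sample average, purely because of the small final batch.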

Mixed Precision Training

Mixed precision uses lower-precision floating point (FP16 or BF16) for most operations while keeping a master copy of weights in FP32 for numerical stability:


```python
# ===== Option 1: BF16 (preferred on Ampere+ GPUs) =====
# No GradScaler needed -- BF16 has the same exponent range as FP32
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    output = model(data)
    loss = criterion(output, target)

loss.backward()   # Gradients computed in BF16
optimizer.step()  # Optimizer updates master FP32 weights


# ===== Option 2: FP16 (needed on V100) =====
# GradScaler prevents gradient underflow in FP16
scaler = torch.amp.GradScaler()

with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(data)
    loss = criterion(output, target)

scaler.scale(loss).backward()  # Scale loss to prevent gradient underflow
scaler.unscale_(optimizer)     # Unscale for gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)         # Skip step if gradients contain Inf/NaN
scaler.update()                # Adjust scale factor
```

**What runs in which precision.** `torch.autocast` does not convert everything to FP16/BF16. It maintains a list of operations and their preferred precision:

| Operation Type | Precision Under autocast | Why |
|---|---|---|
| GEMM (linear, matmul, conv) | FP16/BF16 | Tensor Cores give 8-16x speedup |
| Elementwise (ReLU, GELU) | FP16/BF16 | Memory-bound, benefits from smaller tensors |
| Reduction (sum, mean, norm) | FP32 | Accumulation needs higher precision |
| Softmax | FP32 | Numerical stability (exp overflow) |
| Loss computation | FP32 | Log, exp need precision |
| Batch norm | FP32 | Running stats need precision |
| Layer norm | Depends | Some implementations use FP32 internally |

Autocast handles this automatically -- you do not need to manually cast tensors. It inserts casts at operation boundaries.
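
The per-operation casting can be observed directly. A minimal sketch, run on CPU (CPU autocast also supports bfloat16, so no GPU is needed; matmul behaves as in the table above):

```python
import torch

a = torch.randn(4, 4)  # FP32 inputs
b = torch.randn(4, 4)

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    c = a @ b          # GEMM: autocast runs this in bfloat16
    print(c.dtype)     # torch.bfloat16

print((a @ b).dtype)   # outside autocast: torch.float32
```

Note that the inputs were never cast manually; autocast inserted the casts at the matmul boundary.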

| Metric | FP32 | Mixed Precision (BF16/FP16) | Savings |
|---|---|---|---|
| Model weights memory | 4P bytes | 4P (master) + 2P (working) = 6P bytes | None (need master copy) |
| Activation memory | 4A bytes | 2A bytes | 2x |
| Gradient memory | 4P bytes | 2P bytes | 2x |
| GEMM throughput | Peak FP32 | 8-16x faster (Tensor Cores) | 8-16x |
| Memory bandwidth | Baseline | ~2x (half the data to move) | 2x |
| Net training speedup | Baseline | 1.5-3x | Depends on model |

Optimizers

| Optimizer | Key Property | Typical Use | Default Hyperparams |
|---|---|---|---|
| SGD | Simple, well-understood | CNNs, when compute budget is large | lr=0.1, momentum=0.9 |
| SGD + Nesterov | Better convergence than vanilla SGD | CNNs with cosine schedule | lr=0.1, momentum=0.9, nesterov=True |
| Adam | Adaptive per-parameter LR | General default | lr=3e-4, betas=(0.9, 0.999) |
| AdamW | Adam with decoupled weight decay | Transformers (standard) | lr=3e-4, wd=0.01 |
| Adafactor | Memory-efficient Adam (factored 2nd moment) | Very large models (memory-constrained) | lr=1e-3 |
| LAMB | Layer-wise adaptive rates | Large-batch training | lr=1e-3 |
| Lion | Sign-based optimizer | Recent alternative to Adam | lr=3e-4, betas=(0.9, 0.99) |

**Optimizer memory overhead.** The optimizer stores state for each parameter:

| Optimizer | State per Parameter | Total Memory (for P params) |
|---|---|---|
| SGD | None (or momentum buffer: P) | 0 or 4P bytes |
| SGD + momentum | Momentum buffer | 4P bytes |
| Adam/AdamW | 1st moment (m) + 2nd moment (v) | 8P bytes |
| Adam + FP32 master weights | Master weights + m + v | 12P bytes |
| Adafactor | Row + column factors (no full 2nd moment) | ~4√P bytes |

For a 7B parameter model in FP32: Adam requires 7B * 8 = 56 GB just for optimizer states. This is why large model training uses optimizer state sharding (ZeRO/FSDP) and lower-precision optimizers.
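
The table's byte counts make this arithmetic mechanical. A back-of-envelope sketch, using the 4-bytes-per-FP32-value assumption from the table:

```python
# Optimizer state memory for a 7B-parameter model, per the table above.
P = 7_000_000_000  # 7B parameters

sgd_momentum = 4 * P  # one FP32 momentum buffer
adam = 8 * P          # first + second moment, both FP32
adam_master = 12 * P  # + FP32 master weights for mixed-precision training

print(f"SGD+momentum: {sgd_momentum / 1e9:.0f} GB")  # 28 GB
print(f"Adam states:  {adam / 1e9:.0f} GB")          # 56 GB
print(f"Adam+master:  {adam_master / 1e9:.0f} GB")   # 84 GB
```

None of this counts gradients or activations, which is why sharding the optimizer state (ZeRO/FSDP) is usually the first memory optimization applied.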

Learning Rate Schedules


```python
import math
import torch

# 1. Cosine annealing (most common for transformers)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_epochs, eta_min=lr * 0.01
)

# 2. OneCycleLR (warmup + cosine in one)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr, total_steps=total_steps,
    pct_start=0.1,  # 10% warmup
    anneal_strategy='cos',
)

# 3. Linear warmup + cosine decay (LLM standard)
from torch.optim.lr_scheduler import LambdaLR

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps  # Linear warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))  # Cosine decay

scheduler = LambdaLR(optimizer, lr_lambda)

# 4. Step decay (legacy, for CNNs)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
```

**Why warmup is necessary.** At initialization, the model produces random outputs and the loss landscape is highly curved. A large learning rate at this point can push parameters to extreme values, causing divergence or NaN. Warmup gradually increases the LR from near-zero to the target, allowing the model to settle into a reasonable region of parameter space first.

| Setting | Without Warmup | With Warmup |
|---|---|---|
| Small LR (1e-5) | Usually works but slow | Not needed |
| Medium LR (3e-4) | Sometimes diverges | Stable |
| Large LR (1e-3) | Often diverges | Usually stable |
| Large batch training | Frequently diverges | Required |

Common warmup lengths: 1-10% of total training steps. LLM training typically uses 1-2% warmup with thousands of total steps.
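
The warmup-plus-cosine multiplier from schedule 3 can be sanity-checked numerically. A self-contained sketch with illustrative step counts (1000 total steps, 10% warmup):

```python
import math

total_steps = 1000
warmup_steps = 100  # 10% warmup (illustrative)

def lr_lambda(step):
    """LR multiplier: linear 0 -> 1 during warmup, then cosine 1 -> 0."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

print(lr_lambda(0))              # 0.0  (start from zero)
print(lr_lambda(100))            # 1.0  (peak LR at end of warmup)
print(round(lr_lambda(550), 3))  # 0.5  (halfway through the decay)
print(round(lr_lambda(1000), 3)) # 0.0  (fully decayed)
```

The multiplier is applied on top of the base LR by `LambdaLR`, so `lr_lambda(100) == 1.0` means the schedule reaches exactly the configured peak at the end of warmup.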

Checkpointing


```python
def save_checkpoint(model, optimizer, scheduler, epoch, path, **kwargs):
    """Save a full training checkpoint for resuming."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'rng_state': torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state_all()
    checkpoint.update(kwargs)  # Additional metadata (loss, metrics, etc.)
    torch.save(checkpoint, path)

def load_checkpoint(model, optimizer, scheduler, path, device='cuda'):
    """Load a checkpoint and restore all training state."""
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore RNG state for reproducibility
    if 'rng_state' in checkpoint:
        torch.random.set_rng_state(checkpoint['rng_state'])
    if 'cuda_rng_state' in checkpoint and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])

    return checkpoint['epoch']
```

**Checkpointing best practices:**

  1. Save periodically (every N steps, not just every epoch). If training crashes at epoch 9 of 10, you lose 90% of the compute without mid-epoch checkpoints.

  2. Keep multiple checkpoints (a rotating buffer of the last K checkpoints). A corrupted checkpoint with no backup is catastrophic.

  3. Include everything needed to resume. Model weights alone are not enough for exact resumption. You also need optimizer state (momentum buffers), scheduler state (LR position), and RNG states (for data ordering).

  4. Async saving for large models. Checkpointing a 70B model takes 30+ seconds. Save asynchronously to avoid blocking training (snapshot the checkpoint dict first so the background write does not race with the next optimizer step):

     ```python
     import threading
     thread = threading.Thread(target=torch.save, args=(checkpoint, path))
     thread.start()
     ```
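
Practice 2 (a rotating buffer of the last K checkpoints) can be sketched as a small helper. `rotate_checkpoints` and the `ckpt_<step>.pt` naming scheme are illustrative choices, not a standard API:

```python
from pathlib import Path

def rotate_checkpoints(ckpt_dir, keep_last_k=3):
    """Delete all but the newest keep_last_k checkpoints in ckpt_dir.

    Illustrative helper: assumes checkpoints are named 'ckpt_<step>.pt'
    so the step number orders them. Returns the names that were kept.
    """
    ckpts = sorted(
        Path(ckpt_dir).glob('ckpt_*.pt'),
        key=lambda p: int(p.stem.split('_')[1]),  # sort by step number
    )
    for stale in ckpts[:-keep_last_k]:
        stale.unlink()  # remove the oldest files
    return [p.name for p in ckpts[-keep_last_k:]]
```

Called right after each `save_checkpoint`, this bounds disk usage while always leaving a fallback if the newest file turns out to be corrupted.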

Gradient Clipping


```python
# Method 1: Clip by global norm (standard for transformers)
# Scales all gradients by the same factor if total norm exceeds max_norm
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)
# Returns the total norm BEFORE clipping -- log this!

# Method 2: Clip by value (less common)
# Clamps each gradient element independently
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# Monitoring gradient norms (essential for debugging training)
if global_step % 10 == 0:
    # Per-layer gradient norms (find which layers have exploding gradients)
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"{name}: grad_norm={param.grad.norm():.4f}")
```

**Gradient clipping is not optional for transformers.** Without it, a single bad batch can produce a gradient spike that destabilizes the entire training run. Common settings:

| Model Type | max_norm | Notes |
|---|---|---|
| GPT/LLaMA (LLMs) | 1.0 | Standard; occasionally lowered to 0.5 |
| BERT/RoBERTa | 1.0 | Standard |
| Vision Transformer | 1.0-5.0 | More tolerant |
| Diffusion models | 1.0 | Standard |
| RL (PPO, etc.) | 0.5-1.0 | More aggressive clipping |

Monitor gradient norms. Plotting the gradient norm over training reveals:

  • Spikes: Potential instability. If spikes are frequent, reduce LR or increase clipping.
  • Gradual increase: Normal early in training as the model learns.
  • Sudden collapse to 0: Potential vanishing gradients. Check for dead ReLUs or exploded loss.
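
Clip-by-global-norm (Method 1) is simple arithmetic: take the L2 norm over all gradients jointly, and if it exceeds `max_norm`, scale every gradient by `max_norm / total_norm`, which preserves the gradient's direction. A sketch with made-up per-tensor norms:

```python
import math

# Made-up per-tensor gradient L2 norms for a three-layer model
per_tensor_norms = [3.0, 4.0, 12.0]
max_norm = 1.0

# Global norm = L2 norm of all gradients concatenated = sqrt of summed squares
total_norm = math.sqrt(sum(n ** 2 for n in per_tensor_norms))
print(total_norm)  # 13.0 (sqrt(9 + 16 + 144))

# One shared scale factor, applied only when the threshold is exceeded
scale = min(1.0, max_norm / total_norm)
clipped_norm = math.sqrt(sum((n * scale) ** 2 for n in per_tensor_norms))
print(round(clipped_norm, 6))  # 1.0 -- the new global norm equals max_norm
```

Contrast this with clip-by-value (Method 2), which clamps elements independently and therefore can change the gradient's direction.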

Training Diagnostics

| Metric | What to Monitor | Healthy Signal | Warning Sign |
|---|---|---|---|
| Train loss | Should decrease | Smooth, monotonic decrease | Oscillation, plateau, spike, NaN |
| Val loss | Should decrease (slower than train) | Follows train loss with small gap | Increases while train loss decreases (overfitting) |
| Gradient norm | Should be stable | 0.1-10, occasional small spikes | Persistent spikes > 100, collapse to 0 |
| Learning rate | Should follow schedule | Warmup then decay | Stuck at wrong value (scheduler bug) |
| GPU utilization | Should be > 90% | Consistently high | Drops indicate data loading or sync issues |
| Memory usage | Should be stable | Constant after warmup | Monotonic increase (memory leak) |
| Throughput (samples/sec) | Should be stable | Consistent across steps | Drops indicate I/O or communication issues |
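
Some of these checks can be automated inside the loop itself. A minimal sketch of a loss-health check (the spike threshold and window size are arbitrary illustrative choices, not standard values):

```python
import math
from collections import deque

def check_loss(loss, history, spike_factor=3.0):
    """Flag NaN/Inf losses and spikes relative to the recent average.

    Illustrative helper: `spike_factor` and the window size (set by the
    deque's maxlen) are arbitrary. Returns a warning string or None.
    """
    if math.isnan(loss) or math.isinf(loss):
        return "non-finite loss: inspect the batch, LR, and loss scale"
    warning = None
    if len(history) == history.maxlen:  # only judge once the window is full
        recent_avg = sum(history) / len(history)
        if loss > spike_factor * recent_avg:
            warning = f"loss spike: {loss:.2f} vs recent avg {recent_avg:.2f}"
    history.append(loss)
    return warning

history = deque(maxlen=5)  # tiny window for demonstration
for step_loss in [2.0, 2.0, 2.0, 2.0, 2.0, 9.0]:
    msg = check_loss(step_loss, history)
    if msg:
        print(msg)  # fires on the 9.0 spike
```

Calling such a check every step, and halting (or rolling back to the last checkpoint) on a non-finite loss, turns silent divergence into an immediate, diagnosable failure.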