Debugging

Debugging deep learning code is uniquely challenging. Unlike traditional software bugs that produce immediate errors, ML bugs often manifest as silently wrong results: the model trains, produces numbers, but those numbers are incorrect. Training might converge to a suboptimal loss, or the model might fail to generalize. This chapter covers the most common bugs, systematic debugging strategies, and the tools PyTorch provides for diagnosing issues.

The Debugging Hierarchy

When something goes wrong with training, work through this checklist in order:

| Priority | Category | Examples | How to Detect |
|---|---|---|---|
| 1 | Shape errors | Wrong broadcasting, incorrect transpose | Assertions, explicit shape checks |
| 2 | Device errors | CPU/GPU mismatch | RuntimeError messages |
| 3 | Data bugs | Wrong labels, incorrect normalization, data leakage | Inspect random samples, check label distribution |
| 4 | Numerical issues | NaN/Inf, gradient explosion/vanishing | Monitor loss, grad norms, detect_anomaly() |
| 5 | Hyperparameter issues | LR too high/low, wrong schedule | Loss curves, learning rate sweeps |
| 6 | Architecture bugs | Wrong activation, missing layer, incorrect connection | Overfit a single batch first |
| 7 | Training bugs | Not zeroing grads, eval() not called, wrong loss | Code review, unit tests |
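
Shape errors sit at the top of the list partly because broadcasting can hide them. The sketch below is a minimal illustration, not taken from a real model: the tensors and the `checked_mse` helper are hypothetical. A prediction of shape `(N, 1)` and a target of shape `(N,)` broadcast to `(N, N)`, so the loss is averaged over every prediction/target pair and no error is raised.

```python
import torch

def checked_mse(pred, target):
    """Mean squared error that refuses to broadcast (hypothetical helper)."""
    assert pred.shape == target.shape, (
        f"shape mismatch: pred {tuple(pred.shape)} vs target {tuple(target.shape)}"
    )
    return ((pred - target) ** 2).mean()

target = torch.arange(4.0)      # shape (4,)
pred = target.unsqueeze(1)      # shape (4, 1), a perfect prediction

# Silent bug: (4, 1) - (4,) broadcasts to (4, 4)
buggy_diff = pred - target
buggy_loss = (buggy_diff ** 2).mean()
print(buggy_diff.shape)         # torch.Size([4, 4])
print(buggy_loss.item())        # 2.5, even though the prediction is perfect

# With matching shapes the loss is the expected zero
correct_loss = checked_mse(pred.squeeze(1), target)
print(correct_loss.item())      # 0.0
```

Calling `checked_mse(pred, target)` without the `squeeze` fails immediately with a readable message, which is the behavior an explicit shape check is meant to provide.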

Anomaly Detection

PyTorch's anomaly detection mode checks the output of every backward function for NaN and, when it finds one, reports the forward operation that created the offending tensor:


import torch

# Enable anomaly detection globally (adds significant overhead, debug only)
torch.autograd.set_detect_anomaly(True)

# Or as a context manager (scoped to specific code)
with torch.autograd.detect_anomaly():
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    # If a backward function produces NaN, PyTorch raises an error naming
    # that function and prints a traceback to the forward op that created it

# Example error output:
# RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.
# The above operation was created at:
#   File "model.py", line 42, in forward
#     return F.log_softmax(logits, dim=-1)

**Anomaly detection overhead.** When enabled, PyTorch stores the Python stack trace for every operation, which:

- **Slows training by 2-5x** (extra memory allocation and bookkeeping)
- **Increases memory usage** (stack traces for every tensor)

Use it only to locate the source of NaN/Inf. Once found, disable it and fix the underlying issue. Do not leave it enabled in production training.
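
As a minimal, self-contained illustration, the sketch below builds a case where the forward value is finite but the backward pass produces NaN. The `sqrt(x) * 0` expression and the helper name `run_with_anomaly_detection` are made up for this example, not taken from a real model.

```python
import torch

def run_with_anomaly_detection(x):
    """Run one forward/backward pass under anomaly detection.

    Returns the error message if a backward function produced NaN,
    or None if the backward pass completed cleanly.
    """
    try:
        with torch.autograd.detect_anomaly():
            # Forward value is finite: sqrt(0) * 0 = 0.
            y = (torch.sqrt(x) * 0.0).sum()
            # Backward is not: the upstream gradient is 0 and the local
            # derivative of sqrt at 0 is infinite, which yields nan.
            y.backward()
    except RuntimeError as err:
        return str(err)
    return None

x = torch.zeros(3, requires_grad=True)
message = run_with_anomaly_detection(x)
print(message)  # names the sqrt backward function
```

Without the context manager the same code runs silently and leaves NaN in `x.grad`, which is the failure mode anomaly detection is meant to expose.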

Shape Debugging

Shape mismatches are the most common bug in deep learning code. They often produce incorrect results without raising an error (due to broadcasting):


import torch
import torch.nn as nn

class MyAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape
        assert D == self.d_model, f"Expected d_model={self.d_model}, got {D}"

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head attention
        # (B, T, D) -> (B, T, H, D/H) -> (B, H, T, D/H)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        assert q.shape == (B, self.n_heads, T, self.head_dim), \
            f"q shape mismatch: expected {(B, self.n_heads, T, self.head_dim)}, got {q.shape}"

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        assert attn.shape == (B, self.n_heads, T, T), \
            f"attn shape mismatch: expected {(B, self.n_heads, T, T)}, got {attn.shape}"

        # Returns the raw (pre-softmax) scores; v is reshaped but not applied here
        return attn

def debug_tensor(name, tensor):
    """Print comprehensive tensor diagnostics."""
    stats = {
        'shape': tensor.shape,
        'dtype': tensor.dtype,
        'device': tensor.device,
        'min': tensor.min().item(),
        'max': tensor.max().item(),
        'mean': tensor.float().mean().item(),
        'std': tensor.float().std().item(),
        'nan': tensor.isnan().sum().item(),
        'inf': tensor.isinf().sum().item(),
        'requires_grad': tensor.requires_grad,
    }
    parts = [f"{name}:"]
    parts.append(f"  shape={stats['shape']}, dtype={stats['dtype']}, device={stats['device']}")
    parts.append(f"  range=[{stats['min']:.4f}, {stats['max']:.4f}], "
                 f"mean={stats['mean']:.4f}, std={stats['std']:.4f}")
    if stats['nan'] > 0 or stats['inf'] > 0:
        parts.append(f"  WARNING: {stats['nan']} NaN, {stats['inf']} Inf values!")
    print('\n'.join(parts))

# Usage in forward pass
def forward(self, x):
    debug_tensor("input", x)
    x = self.encoder(x)
    debug_tensor("after_encoder", x)
    x = self.classifier(x)
    debug_tensor("logits", x)
    return x
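
As an alternative to editing `forward` by hand, forward hooks can record the output shape of every layer without touching the model code. The sketch below uses `register_forward_hook`; the helper name `trace_shapes` and the toy `nn.Sequential` model are assumptions made for illustration.

```python
import torch
import torch.nn as nn

def trace_shapes(model, example_input):
    """Run one forward pass and return [(module_name, output_shape), ...]
    for every leaf module, in execution order."""
    records = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                records.append((name, tuple(output.shape)))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()  # always detach the hooks afterwards
    return records

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
for name, shape in trace_shapes(model, torch.randn(2, 8)):
    print(f"{name}: {shape}")
```

Removing the hooks in a `finally` block keeps the model unchanged for later training even if the traced forward pass raises.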

NaN Detection and Prevention


def check_for_nan(model, loss, step):
    """Check for NaN in loss, parameters, and gradients."""
    issues = []

    if torch.isnan(loss):
        issues.append(f"NaN loss at step {step}")

    if torch.isinf(loss):
        issues.append(f"Inf loss at step {step} (value: {loss.item()})")

    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            issues.append(f"NaN in parameter: {name}")
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                issues.append(f"NaN in gradient: {name}")
            if torch.isinf(param.grad).any():
                issues.append(f"Inf in gradient: {name} (norm: {param.grad.norm():.2e})")

    if issues:
        for issue in issues:
            print(f"  [NaN CHECK] {issue}")
        raise ValueError(f"Numerical issues detected at step {step}")

# Integrate into training loop
for step, (data, target) in enumerate(loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

    # Check BEFORE backward (catches forward-pass NaN)
    if step % 10 == 0:
        check_for_nan(model, loss, step)

    loss.backward()

    # Check AFTER backward (catches backward-pass NaN)
    if step % 10 == 0:
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN gradient in {name} at step {step}")

    # Apply the update only after the gradient checks
    optimizer.step()

| Cause | How It Happens | Symptom | Fix |
|---|---|---|---|
| LR too high | Large weight update overshoots | Loss spikes then NaN | Reduce LR; add warmup |
| Division by zero | 1/x where x can be 0 | NaN in specific layer | Add epsilon: x / (y + 1e-8) |
| Log of zero/negative | log(prob) where prob <= 0 | NaN in loss | Clamp: torch.log(x.clamp(min=1e-8)) |
| Exp overflow | exp(x) where x > 88 (FP32) | Inf then NaN | Use log-space: log_softmax instead of softmax + log |
| FP16 overflow | Gradients > 65504 | NaN with mixed precision | Use BF16, or use GradScaler with a lower initial scale |
| FP16 underflow | Gradients below the smallest normal value ($\approx 6 \times 10^{-5}$) lose precision in the subnormal range and flush to zero below about $6 \times 10^{-8}$ | Gradients become 0, loss plateaus | Use BF16, or GradScaler |
| Unstable normalization | LayerNorm/BatchNorm with zero variance | NaN after norm layer | Increase epsilon in norm layers |
| Attention score overflow | $QK^T/\sqrt{d}$ produces large values | NaN after softmax | Use scaled dot-product attention, or clamp scores |
| Incorrect loss function | Using loss that expects different input | Persistent NaN | Verify loss function API (log-probs vs logits) |

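The two FP16 rows can be checked directly by casting a few values. This sketch only uses dtype casts on CPU tensors; the values 70000 and 1e-8 are arbitrary choices that fall outside the FP16 range on either side.

```python
import torch

big = torch.tensor(70000.0)    # above the FP16 maximum of 65504
tiny = torch.tensor(1e-8)      # below the smallest FP16 subnormal (~6e-8)

print(torch.finfo(torch.float16).max)   # 65504.0
print(big.to(torch.float16))            # overflows to inf
print(tiny.to(torch.float16))           # underflows to 0
print(big.to(torch.bfloat16))           # stays finite in BF16
print(tiny.to(torch.bfloat16))          # stays nonzero in BF16
```

BF16 keeps the exponent range of FP32 and gives up mantissa bits instead, which is why it avoids both failure modes at the cost of coarser precision.
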
**Defensive coding patterns for numerical stability:**
# BAD: log can produce -inf or NaN
loss = -torch.log(probs)

# GOOD: clamp to prevent log(0)
loss = -torch.log(probs.clamp(min=1e-8))

# BETTER: pass raw logits to cross_entropy (numerically stable log + softmax)
loss = F.cross_entropy(logits, targets)  # Uses log_softmax internally

# BAD: softmax then log (two passes, less stable)
B = logits.shape[0]
rows = torch.arange(B, device=logits.device)
probs = F.softmax(logits, dim=-1)
loss = -torch.log(probs[rows, targets]).mean()

# GOOD: log_softmax (one pass, numerically stable)
log_probs = F.log_softmax(logits, dim=-1)
loss = -log_probs[rows, targets].mean()

# BAD: a zero (or tiny) norm produces NaN or Inf
normalized = x / x.norm()

# GOOD: add epsilon to denominator
normalized = x / (x.norm() + 1e-8)

# BEST: use the built-in normalize function
normalized = F.normalize(x, dim=-1)  # Clamps the norm to a small eps internally
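
To see why the fused form matters, the sketch below compares the two formulations on a single made-up example: one very confident prediction for the wrong class. The logit values are arbitrary and chosen so that the softmax probability underflows to exactly zero in FP32.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 200.0]])   # very confident in class 1
targets = torch.tensor([0])             # but the true class is 0

# Naive: softmax underflows to exactly 0 for class 0, so log gives -inf
naive_log_probs = torch.log(F.softmax(logits, dim=-1))
naive_loss = -naive_log_probs[0, targets[0]]

# Stable: cross_entropy works in log-space and never forms exp(-200)
stable_loss = F.cross_entropy(logits, targets)

print(naive_loss.item())    # inf
print(stable_loss.item())   # 200.0
```

The stable loss is large but finite, so its gradient is still usable and the optimizer can correct the prediction.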

Symptom. A transformer trains fine for a few hundred steps, then the loss prints nan and never recovers. There is no Python exception: the model keeps running, but every subsequent loss is nan.

Step 1: localize the operation. Wrap one step in anomaly detection so PyTorch reports the exact backward op and the line where the offending tensor was created:

with torch.autograd.detect_anomaly():
    output = model(input)
    loss = criterion(output, target)
    loss.backward()

The error points at LogBackward0 created inside a custom loss that computes -torch.log(probs).

Step 2: what is the underlying cause?

The custom loss takes the log of a probability that can reach exactly 0 (for example, when softmax saturates and the probability of one class underflows). log(0) is -inf and its derivative 1/0 is inf; in the backward pass these combine with zeros or other infinities (0 × inf, inf - inf) to give nan gradients, which then corrupt the weights so that every later step produces nan. The forward loss looked finite at first only because no probability had yet collapsed to zero.

Step 3: what is the fix?

Replace the manual log of a probability with a numerically stable formulation. Either clamp the argument, or better, fuse the log into the softmax with log_softmax / cross_entropy:

# Before: -torch.log(probs) can hit log(0) = -inf
# After: let PyTorch fuse log and softmax for stability
loss = F.cross_entropy(logits, targets)

After the change, rerun the overfit-one-batch test (below) to confirm the loss now descends smoothly instead of diverging to nan.

This is the canonical loop: anomaly detection turns a silent nan into a named operation and source line (the diagnosis), the NaN-causes table maps that operation to a root cause, and the defensive-coding patterns supply the fix.

Memory Debugging


import torch

# Basic memory reporting
def print_memory():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"Max Alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Track peak memory for a code block
torch.cuda.reset_peak_memory_stats()
# ... your code ...
peak = torch.cuda.max_memory_allocated() / 1e9
print(f"Peak memory: {peak:.2f} GB")

# Memory snapshot (PyTorch 2.1+) to visualize memory over time
torch.cuda.memory._record_memory_history(max_entries=100000)
# ... run training steps ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # Stop recording
# Visualize at: https://pytorch.org/memory_viz

| Component | Memory | Formula | Typical for 7B Model |
|---|---|---|---|
| Model parameters | $4P$ bytes (FP32) or $2P$ (BF16) | $P$ = number of parameters | 14 GB (BF16) |
| Gradients | $4P$ bytes (FP32) or $2P$ (BF16) | Same as params | 14 GB (BF16) |
| Optimizer state (Adam) | $8P$ bytes (m + v, FP32) | 2x FP32 parameters | 56 GB |
| Activations (for backward) | Depends on batch size and model | Roughly $O(B \times L \times D)$ per layer, plus attention-score terms that grow with $L^2$ | 10-50 GB |
| Total | (sum of rows) | ~$16P$ + activations | 94+ GB |

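The per-parameter byte counts in the table can be turned into a small calculator as a sanity check. This is a rough sketch under stated assumptions: Adam with two FP32 moments, an FP32 master copy of the weights in the mixed-precision layout, 1 GB = 1e9 bytes, and activations excluded. The function name `training_memory_gb` is hypothetical.

```python
def training_memory_gb(num_params, mixed_precision=True):
    """Approximate training memory per component in GB (1 GB = 1e9 bytes),
    excluding activations. Assumes Adam with two FP32 moments."""
    weight_bytes = 2 if mixed_precision else 4   # BF16 vs FP32 weights
    grad_bytes = 2 if mixed_precision else 4     # gradients match the weights
    budget = {
        "parameters": weight_bytes * num_params,
        "gradients": grad_bytes * num_params,
        "adam_moments": 8 * num_params,          # m and v, FP32
        # The FP32 master copy exists only in the mixed-precision layout
        "master_weights": 4 * num_params if mixed_precision else 0,
    }
    budget["total"] = sum(budget.values())
    return {name: n_bytes / 1e9 for name, n_bytes in budget.items()}

for layout in (True, False):
    print("mixed" if layout else "fp32", training_memory_gb(7e9, mixed_precision=layout))
```

For a 7B model both layouts come to 112 GB before activations: the mixed-precision run saves 28 GB on parameters and gradients but spends it again on the master copy.
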
**Assumptions behind the budget.** The Memory column describes two layouts. In a pure-FP32 setup, parameters and gradients are stored in FP32 ($4P$ each) and Adam keeps two FP32 moments ($8P$), which is where the $16P$ total comes from. In a mixed-precision setup the layout shifts: parameters and gradients are held in BF16 ($2P$ each), but the optimizer keeps an FP32 master copy of the weights ($4P$) alongside the two FP32 moments ($8P$). The 7B column uses the BF16 figures but leaves out the master copy, so a complete mixed-precision budget adds roughly $4P \approx 28$ GB on top of the 14 + 14 + 56 GB already listed, about 112 GB before activations. Either way the total lands near $16P$ to $20P$ bytes before activations, depending on which extra FP32 buffers an implementation keeps.

**Common memory leaks and their fixes:**
# LEAK 1: Storing loss tensors (keeps entire computation graph in GPU memory)
losses = []
for batch in loader:
    loss = model(batch)
    losses.append(loss)  # BAD: keeps graph alive
    # FIX: losses.append(loss.item())  # .item() extracts Python scalar

# LEAK 2: Not clearing CUDA cache after large allocations
big_tensor = torch.randn(10000, 10000, device='cuda')
del big_tensor
# Memory is returned to PyTorch's cache but NOT to the OS
torch.cuda.empty_cache()  # Returns memory to CUDA driver (rarely needed)

# LEAK 3: Holding references to intermediate tensors
class BadModel(nn.Module):
    def forward(self, x):
        self.last_hidden = self.encoder(x)  # BAD: stored as attribute
        return self.decoder(self.last_hidden)
# FIX: Don't store intermediates, or detach them:
#     self.last_hidden = self.encoder(x).detach()

# LEAK 4: Forgetting to delete tensors in evaluation
@torch.no_grad()
def evaluate(model, loader):
    all_preds = []
    for batch in loader:
        output = model(batch.cuda())
        all_preds.append(output.cpu())  # Move to CPU to free GPU memory
    return torch.cat(all_preds)
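
The reason LEAK 1 holds memory can be seen on a tiny CPU example. In this sketch (the `nn.Linear` model and random input are placeholders), the loss tensor carries a reference to the autograd graph, while `.item()` and `.detach()` do not.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()

# The loss tensor still references the autograd graph that produced it
print(loss.grad_fn is not None)        # True
# .item() yields a plain Python float with no graph attached
print(type(loss.item()))               # <class 'float'>
# .detach() keeps a tensor but drops the graph reference
print(loss.detach().grad_fn is None)   # True
```

Appending `loss` to a list therefore keeps every step's graph and saved activations alive, while appending `loss.item()` stores only a number.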

The "Overfit One Batch" Test

The single most useful debugging technique: verify that your model can perfectly fit a single batch of training data.


import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def overfit_one_batch(model, dataset, device='cuda', steps=1000, lr=1e-3):
    """Verify model can memorize a single batch.
    If this fails, the model or training code has a bug."""

    model = model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Get one batch
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    data, target = next(iter(loader))
    data, target = data.to(device), target.to(device)

    for step in range(steps):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            acc = (output.argmax(1) == target).float().mean()
            print(f"Step {step}: loss={loss.item():.4f}, acc={acc:.2%}")

    # Should reach near-zero loss and ~100% accuracy
    with torch.no_grad():
        final_acc = (model(data).argmax(1) == target).float().mean()
    assert final_acc > 0.99, f"Failed to overfit one batch! Final acc: {final_acc:.2%}"
    print("PASSED: model can overfit one batch")

# Run this BEFORE full training to catch bugs early
overfit_one_batch(model, train_dataset)

**If overfit-one-batch fails, the bug is almost certainly in your model or training code**, not in your dataset. The learning rate is the one hyperparameter that can still make this test fail. Common causes:

| Symptom | Likely Cause |
|---|---|
| Loss does not decrease at all | Gradients not flowing (detach, wrong loss, frozen params) |
| Loss decreases but accuracy stays at random | Wrong loss function for the task |
| Loss goes to NaN immediately | Numerical issue (see NaN table above) |
| Loss decreases very slowly | LR too low, or model too small for the task |
| Loss oscillates wildly | LR too high |
| Accuracy saturates below 100% | Architecture cannot represent the mapping (too few params, wrong activation) |
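
For the first row, one quick check is to list the parameters that received no gradient after a backward pass. The sketch below is illustrative: `params_without_grad` and `ToyModel` are hypothetical names, and the model contains a deliberate `detach()` bug so the check has something to find.

```python
import torch
import torch.nn as nn

def params_without_grad(model):
    """Return names of parameters whose gradient is missing or all zero
    after a backward pass."""
    missing = []
    for name, param in model.named_parameters():
        if param.grad is None or not torch.any(param.grad != 0):
            missing.append(name)
    return missing

class ToyModel(nn.Module):  # hypothetical model with a deliberate bug
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        h = self.encoder(x).detach()  # BUG: detach cuts the gradient path
        return self.head(h)

model = ToyModel()
loss = model(torch.randn(16, 4)).sum()
loss.backward()
print(params_without_grad(model))  # ['encoder.weight', 'encoder.bias']
```

Any layer that shows up in this list is not being trained, which narrows the search to the code between that layer and the loss.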

Reproducibility


import torch
import numpy as np
import random

def set_seed(seed=42):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Make CuDNN deterministic (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Enable deterministic algorithms (PyTorch 1.8+)
    torch.use_deterministic_algorithms(True)
    # Note: some operations do not have deterministic implementations
    # and will raise an error. For cuBLAS on CUDA 10.2+, also set the
    # environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 before launching.

set_seed(42)

**Full reproducibility requires controlling all sources of randomness:**

| Source | How to Control | Performance Impact |
|---|---|---|
| Python random | random.seed(seed) | None |
| NumPy | np.random.seed(seed) | None |
| PyTorch CPU | torch.manual_seed(seed) | None |
| PyTorch GPU | torch.cuda.manual_seed_all(seed) | None |
| CuDNN autotuner | torch.backends.cudnn.benchmark = False | Can be slower for convolution-heavy models (no autotuned kernel selection) |
| CuDNN algorithms | torch.backends.cudnn.deterministic = True | Slight slowdown |
| Nondeterministic ops | torch.use_deterministic_algorithms(True) | Some ops slower or error |
| Data loading order | DataLoader(shuffle=False) or seeded generator | None |
| Multi-worker loading | worker_init_fn with deterministic seeds | None |
| Dropout | Controlled by PyTorch seed | None |
| Data augmentation | Controlled by the seed of whichever RNG the augmentation library uses (PyTorch, NumPy, or random) | None |

In practice: Full determinism is useful for debugging but unnecessary (and slower) for production training. Use it to verify that a code change does not break training, then disable for actual runs.
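
The data-loading rows of the table can be sketched as follows. This follows the pattern of passing a seeded `torch.Generator` and a `worker_init_fn` to `DataLoader`; the helper names `seed_worker` and `make_loader` and the toy dataset are assumptions for illustration, and the demo uses `num_workers=0` so it runs anywhere.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Give each DataLoader worker a deterministic NumPy/random seed
    derived from the PyTorch seed assigned to that worker."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(dataset, seed, batch_size=4, num_workers=0):
    generator = torch.Generator()
    generator.manual_seed(seed)  # controls the shuffle order
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,  # only invoked when num_workers > 0
        generator=generator,
    )

dataset = TensorDataset(torch.arange(16))
order_a = torch.cat([batch[0] for batch in make_loader(dataset, seed=0)])
order_b = torch.cat([batch[0] for batch in make_loader(dataset, seed=0)])
print(torch.equal(order_a, order_b))  # True: same seed, same order
```

Shuffling stays enabled here; only the source of randomness is pinned, so the run remains representative of normal training while being repeatable.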

Debugging Checklist

| Step | Action | Catches |
|---|---|---|
| 1 | Overfit one batch | Model/code bugs, wrong loss function |
| 2 | Check input data (visualize samples, verify labels) | Data pipeline bugs, wrong preprocessing |
| 3 | Verify shapes at each layer | Broadcasting bugs, wrong dimensions |
| 4 | Monitor loss, gradient norms, LR per step | Training instability, bad hyperparams |
| 5 | Compare against a known baseline | Architecture bugs, missing components |
| 6 | Train on a small subset first | Verify learning before spending compute |
| 7 | Check train vs val gap | Overfitting or underfitting |
| 8 | Profile with PyTorch profiler | Performance bottlenecks |
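
Step 4 can be implemented with a few lines inside the training loop. The sketch below is one possible way to log loss, global gradient norm, and learning rate per step; the helper `global_grad_norm`, the toy `nn.Linear` model, and the `StepLR` schedule are placeholders chosen for illustration.

```python
import torch
import torch.nn as nn

def global_grad_norm(model):
    """L2 norm of all parameter gradients, treated as one long vector."""
    squared = [p.grad.detach().pow(2).sum()
               for p in model.parameters() if p.grad is not None]
    if not squared:
        return 0.0
    return torch.stack(squared).sum().sqrt().item()

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

history = []
for step in range(4):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    # Record after backward and before the optimizer step
    history.append({
        "step": step,
        "loss": loss.item(),
        "grad_norm": global_grad_norm(model),
        "lr": optimizer.param_groups[0]["lr"],
    })
    optimizer.step()
    scheduler.step()

for row in history:
    print(row)
```

A gradient norm that grows by orders of magnitude, or collapses toward zero, usually shows up in this log several steps before the loss itself becomes NaN or stops moving.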