
Debugging

Debugging deep learning code is uniquely challenging. Unlike traditional software bugs that produce immediate errors, ML bugs often manifest as silently wrong results -- the model trains, produces numbers, but those numbers are incorrect. Training might converge to a suboptimal loss, or the model might fail to generalize. This chapter covers the most common bugs, systematic debugging strategies, and the tools PyTorch provides for diagnosing issues.

The Debugging Hierarchy

When something goes wrong with training, work through this checklist in order:

| Priority | Category | Examples | How to Detect |
|---|---|---|---|
| 1 | Shape errors | Wrong broadcasting, incorrect transpose | Assertions, explicit shape checks |
| 2 | Device errors | CPU/GPU mismatch | RuntimeError messages |
| 3 | Data bugs | Wrong labels, incorrect normalization, data leakage | Inspect random samples, check label distribution |
| 4 | Numerical issues | NaN/Inf, gradient explosion/vanishing | Monitor loss, grad norms, `detect_anomaly()` |
| 5 | Hyperparameter issues | LR too high/low, wrong schedule | Loss curves, learning rate sweeps |
| 6 | Architecture bugs | Wrong activation, missing layer, incorrect connection | Overfit a single batch first |
| 7 | Training bugs | Not zeroing grads, `eval()` not called, wrong loss | Code review, unit tests |
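Categories 1 and 2 can often be ruled out before training even starts by listing the devices and dtypes in play. A minimal sketch (`audit_devices` is an illustrative helper name, not a PyTorch API):

```python
import torch
import torch.nn as nn

def audit_devices(model, *tensors):
    """Illustrative helper: report device/dtype placement before training.

    A CPU/GPU mismatch (priority 2 in the table) usually surfaces as a
    RuntimeError mid-forward; listing placements up front makes the
    culprit obvious.
    """
    param_devices = {str(p.device) for p in model.parameters()}
    param_dtypes = {str(p.dtype) for p in model.parameters()}
    print(f"model params: devices={param_devices}, dtypes={param_dtypes}")
    for i, t in enumerate(tensors):
        print(f"input[{i}]: device={t.device}, dtype={t.dtype}, shape={tuple(t.shape)}")
    return param_devices

# Usage: a healthy CPU setup reports a single device
model = nn.Linear(4, 2)
x = torch.randn(3, 4)
audit_devices(model, x)
```

If the returned set contains more than one device (or an input sits on a different device than the parameters), fix placement before debugging anything else.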

Anomaly Detection

PyTorch can track where NaN/Inf values originate by recording the operation that produced each tensor:


```python
# Enable anomaly detection (adds significant overhead -- debug only)
torch.autograd.set_detect_anomaly(True)

# As a context manager (scoped to specific code)
with torch.autograd.detect_anomaly():
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    # If any operation produces NaN, PyTorch raises an error naming the
    # EXACT operation, with a traceback to where it was created

# Example error output:
# RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.
# The above operation was created at:
#   File "model.py", line 42, in forward
#     return F.log_softmax(logits, dim=-1)
```
**Anomaly detection overhead.** When enabled, PyTorch stores the Python stack trace for every operation, which:

- **Slows training by 2-5x** (extra memory allocation and bookkeeping)
- **Increases memory usage** (stack traces for every tensor)

Use it only to locate the source of NaN/Inf. Once found, disable it and fix the underlying issue. Do not leave it enabled in production training.
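When even temporary 2-5x overhead is prohibitive, forward hooks are a cheaper (if coarser) alternative: they localize the first *module* that emits a non-finite value, rather than the exact autograd op. A sketch, assuming an illustrative helper name (`install_nan_hooks` is not a built-in):

```python
import torch
import torch.nn as nn

def install_nan_hooks(model):
    """Attach forward hooks that flag the first module emitting NaN/Inf.

    Near-zero overhead compared to set_detect_anomaly(True), but only
    identifies the module, not the specific operation inside it.
    """
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                raise RuntimeError(f"Non-finite output from module '{name}'")
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call h.remove() on each handle to uninstall

# Usage
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
handles = install_nan_hooks(model)
model(torch.randn(2, 4))  # finite input passes through silently
for h in handles:
    h.remove()
```

Hooks fire on every forward pass, so remove them once the offending module is found.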

Shape Debugging

Shape mismatches are the most common bug in deep learning code. They often produce incorrect results without raising an error (due to broadcasting):


```python
import torch
import torch.nn as nn

class MyAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape
        assert D == self.d_model, f"Expected d_model={self.d_model}, got {D}"

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head attention
        # (B, T, D) -> (B, T, H, D/H) -> (B, H, T, D/H)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        assert q.shape == (B, self.n_heads, T, self.head_dim), \
            f"q shape mismatch: expected {(B, self.n_heads, T, self.head_dim)}, got {q.shape}"

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        assert attn.shape == (B, self.n_heads, T, T), \
            f"attn shape mismatch: expected {(B, self.n_heads, T, T)}, got {attn.shape}"

        return attn


def debug_tensor(name, tensor):
    """Print comprehensive tensor diagnostics."""
    stats = {
        'shape': tensor.shape,
        'dtype': tensor.dtype,
        'device': tensor.device,
        'min': tensor.min().item(),
        'max': tensor.max().item(),
        'mean': tensor.float().mean().item(),
        'std': tensor.float().std().item(),
        'nan': tensor.isnan().sum().item(),
        'inf': tensor.isinf().sum().item(),
        'requires_grad': tensor.requires_grad,
    }
    parts = [f"{name}:"]
    parts.append(f"  shape={stats['shape']}, dtype={stats['dtype']}, device={stats['device']}")
    parts.append(f"  range=[{stats['min']:.4f}, {stats['max']:.4f}], "
                 f"mean={stats['mean']:.4f}, std={stats['std']:.4f}")
    if stats['nan'] > 0 or stats['inf'] > 0:
        parts.append(f"  WARNING: {stats['nan']} NaN, {stats['inf']} Inf values!")
    print('\n'.join(parts))

# Usage inside a model's forward pass
def forward(self, x):
    debug_tensor("input", x)
    x = self.encoder(x)
    debug_tensor("after_encoder", x)
    x = self.classifier(x)
    debug_tensor("logits", x)
    return x
```

NaN Detection and Prevention


```python
def check_for_nan(model, loss, step):
    """Check for NaN in loss, parameters, and gradients."""
    issues = []

    if torch.isnan(loss):
        issues.append(f"NaN loss at step {step}")

    if torch.isinf(loss):
        issues.append(f"Inf loss at step {step} (value: {loss.item()})")

    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            issues.append(f"NaN in parameter: {name}")
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                issues.append(f"NaN in gradient: {name}")
            if torch.isinf(param.grad).any():
                issues.append(f"Inf in gradient: {name} (norm: {param.grad.norm():.2e})")

    if issues:
        for issue in issues:
            print(f"  [NaN CHECK] {issue}")
        raise ValueError(f"Numerical issues detected at step {step}")

# Integrate into the training loop
for step, (data, target) in enumerate(loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

    # Check BEFORE backward (catches forward-pass NaN)
    if step % 10 == 0:
        check_for_nan(model, loss, step)

    loss.backward()

    # Check AFTER backward (catches backward-pass NaN)
    if step % 10 == 0:
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN gradient in {name} at step {step}")
```

| Cause | How It Happens | Symptom | Fix |
|---|---|---|---|
| LR too high | Large weight update overshoots | Loss spikes then NaN | Reduce LR; add warmup |
| Division by zero | `1/x` where `x` can be 0 | NaN in specific layer | Add epsilon: `x / (y + 1e-8)` |
| Log of zero/negative | `log(prob)` where `prob <= 0` | NaN in loss | Clamp: `torch.log(x.clamp(min=1e-8))` |
| Exp overflow | `exp(x)` where `x > 88` (FP32) | Inf then NaN | Use log-space: `log_softmax` instead of `softmax` + `log` |
| FP16 overflow | Gradients > 65504 | NaN with mixed precision | Use BF16, or increase GradScaler initial scale |
| FP16 underflow | Gradients < ~6e-8 | Gradients become 0, loss plateaus | Use BF16, or GradScaler |
| Unstable normalization | LayerNorm/BatchNorm with zero variance | NaN after norm layer | Increase epsilon in norm layers |
| Attention score overflow | `QK^T / sqrt(d)` produces large values | NaN after softmax | Use scaled dot-product attention, or clamp scores |
| Incorrect loss function | Using loss that expects different input | Persistent NaN | Verify loss function API (log-probs vs logits) |
**Defensive coding patterns for numerical stability:**

```python
# BAD: log can produce -inf or NaN
loss = -torch.log(probs)

# GOOD: clamp to prevent log(0)
loss = -torch.log(probs.clamp(min=1e-8))

# BETTER: use cross_entropy, which applies log_softmax internally
loss = F.cross_entropy(logits, targets)

# BAD: softmax then log (two passes, less stable)
probs = F.softmax(logits, dim=-1)
loss = -torch.log(probs[torch.arange(B), targets]).mean()

# GOOD: log_softmax (one pass, numerically stable)
log_probs = F.log_softmax(logits, dim=-1)
loss = -log_probs[torch.arange(B), targets].mean()

# BAD: division can produce Inf or NaN when the norm is zero
normalized = x / x.norm()

# GOOD: add epsilon to the denominator
normalized = x / (x.norm() + 1e-8)

# BEST: use the built-in
normalized = F.normalize(x, dim=-1)  # Handles zero-norm via its eps argument
```

Memory Debugging


```python
import torch

# Basic memory reporting
def print_memory():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"Max Alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Track peak memory for a code block
torch.cuda.reset_peak_memory_stats()
# ... your code ...
peak = torch.cuda.max_memory_allocated() / 1e9
print(f"Peak memory: {peak:.2f} GB")

# Memory snapshot (PyTorch 2.1+) -- visualize memory over time
torch.cuda.memory._record_memory_history(max_entries=100000)
# ... run training steps ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # Stop recording
# Visualize at: https://pytorch.org/memory_viz
```

| Component | Memory | Formula | Typical for 7B Model |
|---|---|---|---|
| Model parameters | 4P bytes (FP32) or 2P (BF16) | - | 28 GB (FP32) |
| Gradients | 4P bytes (FP32) or 2P (BF16) | Same as params | 28 GB (FP32) |
| Optimizer state (Adam) | 8P bytes (m + v, FP32) | 2x parameters | 56 GB |
| Activations (for backward) | Depends on batch size and model | O(B x L x D^2) | 10-50 GB |
| Total | - | ~16P + activations | 112+ GB |
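These formulas are easy to sanity-check in code. A back-of-envelope estimator (an illustrative helper, with activations excluded because they depend on batch size and architecture):

```python
def estimate_training_memory_gb(n_params, dtype_bytes=4, optimizer_bytes_per_param=8):
    """Rough training-memory estimate, activations excluded.

    Assumptions: gradients stored in the same dtype as parameters, and an
    Adam-style optimizer holding two FP32 moments (8 bytes per parameter).
    """
    params = n_params * dtype_bytes          # parameter storage
    grads = n_params * dtype_bytes           # one gradient per parameter
    optim = n_params * optimizer_bytes_per_param  # m and v buffers
    return (params + grads + optim) / 1e9

# 7B parameters in FP32: 4P + 4P + 8P = 16P bytes
print(f"{estimate_training_memory_gb(7e9):.0f} GB")  # prints: 112 GB
```

Swapping `dtype_bytes=2` models BF16 parameters and gradients with FP32 Adam state, which is the usual mixed-precision accounting.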
**Common memory leaks and their fixes:**

```python
# LEAK 1: Storing loss tensors (keeps the entire computation graph in GPU memory)
losses = []
for batch in loader:
    loss = model(batch)
    losses.append(loss)  # BAD: keeps graph alive
    # FIX: losses.append(loss.item())  # .item() extracts a Python scalar

# LEAK 2: Not clearing the CUDA cache after large allocations
big_tensor = torch.randn(10000, 10000, device='cuda')
del big_tensor
# Memory is returned to PyTorch's cache but NOT to the OS
torch.cuda.empty_cache()  # Returns cached memory to the CUDA driver (rarely needed)

# LEAK 3: Holding references to intermediate tensors
class BadModel(nn.Module):
    def forward(self, x):
        self.last_hidden = self.encoder(x)  # BAD: stored as an attribute
        return self.decoder(self.last_hidden)
# FIX: Don't store intermediates, or detach them:
#     self.last_hidden = self.encoder(x).detach()

# LEAK 4 (avoided): move results to CPU during evaluation
@torch.no_grad()
def evaluate(model, loader):
    all_preds = []
    for batch in loader:
        output = model(batch.cuda())
        all_preds.append(output.cpu())  # Move to CPU to free GPU memory
    return torch.cat(all_preds)
```
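When a leak's origin is unclear, Python's garbage collector can enumerate every tensor still alive on the GPU. A leak-hunting sketch (`live_cuda_tensors` is an illustrative helper name):

```python
import gc
import torch

def live_cuda_tensors():
    """Count live CUDA tensors, grouped by (shape, dtype).

    Running this between training steps and diffing the counts points at
    whatever keeps accumulating (e.g. stored losses or cached activations).
    """
    report = {}
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                key = (tuple(obj.shape), str(obj.dtype))
                report[key] = report.get(key, 0) + 1
        except Exception:
            continue  # some gc-tracked objects raise on attribute access
    return report

# Usage: print the census sorted by count
for key, count in sorted(live_cuda_tensors().items(), key=lambda kv: -kv[1]):
    print(count, key)
```

Note that this only sees tensors reachable from Python; memory held inside PyTorch's caching allocator shows up in `torch.cuda.memory_reserved()` instead.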

The "Overfit One Batch" Test

The single most useful debugging technique: verify that your model can perfectly fit a single batch of training data.


```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def overfit_one_batch(model, dataset, device='cuda', steps=1000, lr=1e-3):
    """Verify the model can memorize a single batch.
    If this fails, the model or training code has a bug."""
    model = model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Get one batch
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    data, target = next(iter(loader))
    data, target = data.to(device), target.to(device)

    for step in range(steps):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            acc = (output.argmax(1) == target).float().mean()
            print(f"Step {step}: loss={loss.item():.4f}, acc={acc:.2%}")

    # Should reach near-zero loss and ~100% accuracy
    final_acc = (model(data).argmax(1) == target).float().mean()
    assert final_acc > 0.99, f"Failed to overfit one batch! Final acc: {final_acc:.2%}"
    print("PASSED: model can overfit one batch")

# Run this BEFORE full training to catch bugs early
overfit_one_batch(model, train_dataset)
```
**If overfit-one-batch fails, the bug is in your model or training code**, not in your data or hyperparameters. Common causes:

| Symptom | Likely Cause |
|---|---|
| Loss does not decrease at all | Gradients not flowing (detach, wrong loss, frozen params) |
| Loss decreases but accuracy stays at random | Wrong loss function for the task |
| Loss goes to NaN immediately | Numerical issue (see NaN table above) |
| Loss decreases very slowly | LR too low, or model too small for the task |
| Loss oscillates wildly | LR too high |
| Accuracy saturates below 100% | Architecture cannot represent the mapping (too few params, wrong activation) |
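For the first symptom (loss not decreasing at all), it helps to confirm that gradients actually reach every parameter after `backward()`. A sketch, assuming an illustrative helper name (`check_grad_flow` is not a PyTorch API):

```python
import torch
import torch.nn as nn

def check_grad_flow(model):
    """Report parameters with missing or all-zero gradients after backward().

    A detach(), a frozen parameter, or a branch that never feeds the loss
    shows up here as a parameter with no gradient flow.
    """
    dead = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            dead.append((name, "requires_grad=False"))
        elif p.grad is None:
            dead.append((name, "grad is None (not in the graph?)"))
        elif p.grad.abs().max() == 0:
            dead.append((name, "all-zero grad"))
    for name, why in dead:
        print(f"  no gradient flow: {name} ({why})")
    return dead

# Usage: run one forward/backward, then inspect
model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).mean()
loss.backward()
check_grad_flow(model)  # a healthy model returns an empty list
```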

Reproducibility


```python
import os
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Make CuDNN deterministic (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Enable deterministic algorithms (PyTorch 1.8+)
    # Note: some operations have no deterministic implementation and will
    # raise an error. CuBLAS additionally requires this env var:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)

set_seed(42)
```
**Full reproducibility requires controlling all sources of randomness:**

| Source | How to Control | Performance Impact |
|---|---|---|
| Python random | `random.seed(seed)` | None |
| NumPy | `np.random.seed(seed)` | None |
| PyTorch CPU | `torch.manual_seed(seed)` | None |
| PyTorch GPU | `torch.cuda.manual_seed_all(seed)` | None |
| CuDNN autotuner | `torch.backends.cudnn.benchmark = False` | Up to 2x slower (first few iterations) |
| CuDNN algorithms | `torch.backends.cudnn.deterministic = True` | Slight slowdown |
| Nondeterministic ops | `torch.use_deterministic_algorithms(True)` | Some ops slower or error |
| Data loading order | `DataLoader(shuffle=False)` or seeded generator | None |
| Multi-worker loading | `worker_init_fn` with deterministic seeds | None |
| Dropout | Controlled by PyTorch seed | None |
| Data augmentation | Controlled by NumPy/random seed | None |
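For the data-loading rows, the pattern from PyTorch's reproducibility notes is a `worker_init_fn` that derives per-worker NumPy/`random` seeds from the worker's PyTorch seed, plus a seeded `torch.Generator` for shuffle order. A sketch:

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Derive per-worker NumPy/random seeds from PyTorch's worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

dataset = TensorDataset(torch.arange(10).float())
# num_workers=0 here so the example runs anywhere; with num_workers > 0,
# each worker process calls seed_worker before loading its first batch.
loader = DataLoader(dataset, batch_size=2, shuffle=True,
                    num_workers=0, worker_init_fn=seed_worker, generator=g)
first_epoch = [b[0].tolist() for b in loader]
```

Re-seeding `g` before iterating again reproduces the exact same shuffle order, which is how you verify determinism end to end.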

**In practice:** Full determinism is useful for debugging but unnecessary (and slower) for production training. Use it to verify that a code change does not break training, then disable it for actual runs.

Debugging Checklist

| Step | Action | Catches |
|---|---|---|
| 1 | Overfit one batch | Model/code bugs, wrong loss function |
| 2 | Check input data (visualize samples, verify labels) | Data pipeline bugs, wrong preprocessing |
| 3 | Verify shapes at each layer | Broadcasting bugs, wrong dimensions |
| 4 | Monitor loss, gradient norms, LR per step | Training instability, bad hyperparams |
| 5 | Compare against a known baseline | Architecture bugs, missing components |
| 6 | Train on a small subset first | Verify learning before spending compute |
| 7 | Check train vs val gap | Overfitting or underfitting |
| 8 | Profile with PyTorch profiler | Performance bottlenecks |
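For step 8, a minimal `torch.profiler` invocation shows which operations dominate (CPU-only here so it runs anywhere; add `ProfilerActivity.CUDA` to the activities list on a GPU):

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# A small model stands in for whatever you are training.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(64, 256)

# Profile a handful of forward/backward steps.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        model(x).sum().backward()

# Summarize the most expensive ops.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

If one op dominates unexpectedly (e.g. a copy or a dtype cast), that is usually the bottleneck to attack first.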