Autograd
Automatic differentiation (autograd) is the engine that makes training neural networks possible. When you call loss.backward(), PyTorch traces back through every operation that produced the loss, computing the gradient of the loss with respect to every parameter. This chapter explains how the computational graph works, how gradients are computed and accumulated, and the common patterns and pitfalls you will encounter.
Computational Graphs
PyTorch builds a dynamic computational graph during the forward pass. Each tensor operation creates a node in the graph, and edges track dependencies. The graph is then traversed in reverse (backpropagation) to compute gradients.
import torch
# Leaf tensors with requires_grad=True are the "roots" of the graph
x = torch.randn(3, requires_grad=True) # Leaf tensor: tracked by autograd
w = torch.randn(3, requires_grad=True) # Leaf tensor: model parameter
# Each operation creates a new node with a grad_fn
y = x * w # y.grad_fn = <MulBackward0>
z = y.sum() # z.grad_fn = <SumBackward0>
# The graph connects z -> y -> (x, w)
print(z.grad_fn) # <SumBackward0>
print(z.grad_fn.next_functions) # Points to MulBackward0
print(z.grad_fn.next_functions[0][0].next_functions) # Points to leaf accumulators
# Leaf tensors have no grad_fn (they are inputs, not computed)
print(x.grad_fn) # None
print(x.is_leaf) # True
print(y.is_leaf) # False
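Because the graph is rebuilt on every forward pass, ordinary Python control flow simply changes which nodes get created. A minimal sketch (the branching condition is arbitrary, chosen only for illustration):
def forward_with_branching(x):
    # Whichever branch runs is the graph that backward will traverse
    if x.sum() > 0:
        return (x * 2).sum()
    return (x ** 3).sum()

x = torch.randn(3, requires_grad=True)
loss = forward_with_branching(x)
loss.backward()  # Differentiates through the branch that actually executed
print(x.grad)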
| Property | Dynamic (PyTorch) | Static (TensorFlow 1.x, JAX) |
|---|---|---|
| Graph construction | Every forward pass | Once, then reused |
| Python control flow | Fully supported (if, for, etc.) | Requires special primitives |
| Debugging | Standard Python debugger (pdb, breakpoints) | Difficult (graph is compiled) |
| Variable-length inputs | Natural (different graph each time) | Requires padding/bucketing |
| Compilation | torch.compile adds JIT compilation | Built-in |
| Overhead | Higher per-op (Python dispatch) | Lower per-op (precompiled) |
The dynamic graph is why PyTorch is easy to debug and prototype with. torch.compile bridges the gap by capturing the graph for JIT optimization while preserving the eager programming model.
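As a rough sketch of that bridge (the two-layer model here is a hypothetical stand-in), torch.compile wraps an eager module and autograd still works through it:
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))  # hypothetical toy model
compiled_model = torch.compile(model)       # captures and JIT-optimizes the graph
out = compiled_model(torch.randn(8, 16))    # first call compiles; later calls reuse the compiled code
out.sum().backward()                        # gradients flow as in eager mode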
backward() and Gradient Computation
Calling .backward() on a scalar tensor computes gradients for all leaf tensors with requires_grad=True:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum() # y = x1^2 + x2^2 + x3^2
y.backward() # Compute dy/dx using the chain rule
print(x.grad) # tensor([2., 4., 6.]) -- dy/dxi = 2*xi
# For non-scalar outputs, you must provide a gradient argument
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2 # y is a vector, not a scalar
# y.backward() # ERROR: grad can be implicitly created only for scalar outputs
y.backward(torch.ones_like(y)) # Equivalent to (y * ones).sum().backward()
print(x.grad) # tensor([2., 4., 6.])
Forward:   x → y = x*w → z = y.sum()
               grad_fn:      grad_fn:
               MulBackward0  SumBackward0

Backward:  dz/dx ← dz/dy * dy/dx ← dz/dz = 1
           dz/dx = 1 * w = w
           dz/dw = 1 * x = x
Each node receives the gradient from downstream (dz/d(output), passed in as grad_output) and computes the gradient to send upstream (dz/d(input)) by multiplying with the local derivative (d(output)/d(input)).
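To check the hand-derived result above, torch.autograd.grad computes the same vector-Jacobian products directly; a minimal sketch that does not touch .grad:
x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
z = (x * w).sum()
dz_dx, dz_dw = torch.autograd.grad(z, (x, w))  # returns gradients instead of accumulating into .grad
print(torch.allclose(dz_dx, w), torch.allclose(dz_dw, x))  # True True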
Gradient Accumulation
x = torch.tensor([1.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
print(x.grad) # tensor([2.])
y = (x ** 2).sum()
y.backward()
print(x.grad) # tensor([4.]) -- ACCUMULATED! Not [2.]
Why this design? Accumulation enables gradient accumulation across mini-batches (see below) and is needed for models where a parameter is used multiple times in the forward pass (e.g., weight tying).
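A minimal sketch of the re-use case: when a parameter appears twice in the forward pass, a single backward sums both contributions (the toy expression below is hypothetical):
w = torch.tensor([2.0], requires_grad=True)
h = w * 3.0            # first use of w
out = (w * h).sum()    # second use of w: out = 3 * w**2
out.backward()
print(w.grad)          # tensor([12.]) -- 6 from each use, accumulated into one gradient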
# WRONG: gradients accumulate across iterations
for data, target in dataloader:
    loss = criterion(model(data), target)
    loss.backward()        # Gradients ADD to existing .grad
    optimizer.step()       # Update with accumulated (wrong) gradients

# CORRECT: zero gradients first
for data, target in dataloader:
    optimizer.zero_grad()  # Reset all .grad to zero (or None)
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()

# EVEN BETTER: set_to_none=True (default in PyTorch 2.0+)
optimizer.zero_grad(set_to_none=True)
# Sets .grad to None instead of a zero tensor -- saves memory and is slightly faster
Gradient Accumulation for Large Effective Batch Sizes
Simulate larger batch sizes by accumulating gradients over multiple mini-batches without increasing memory usage:
accumulation_steps = 4 # Effective batch = 4 * batch_size
for i, (data, target) in enumerate(dataloader):
    # Forward pass
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
        loss = loss / accumulation_steps  # Normalize loss to get correct average
    # Backward pass (gradients accumulate)
    loss.backward()
    # Update only every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
# Memory usage: same as batch_size (not accumulation_steps * batch_size)
# Gradient quality: nearly equivalent to the larger batch size
The result is not bit-identical to a true large batch when the model normalizes over the batch dimension (e.g., BatchNorm), because statistics are computed per mini-batch. For most workloads (especially with LayerNorm), these differences are negligible.
torch.no_grad() and Inference Mode
Disable gradient tracking to save memory and speed up inference:
# torch.no_grad(): disables gradient computation
with torch.no_grad():
    output = model(input)   # No grad_fn attached, no graph built
    output.requires_grad    # False

# torch.inference_mode(): even faster (also disables view tracking)
with torch.inference_mode():
    output = model(input)   # Cannot create views that require grad tracking

# As a decorator
@torch.inference_mode()
def predict(model, input):
    return model(input)

# Memory savings: significant
# Without no_grad: forward pass stores activations for backward (2-3x model memory)
# With no_grad: only the output tensor is allocated
| Feature | no_grad() | inference_mode() |
|---|---|---|
| Disables gradient computation | Yes | Yes |
| Disables view tracking | No | Yes |
| Speed improvement | Moderate | Better |
| Can create tensors that later need gradients | Yes | No |
| Use case | Evaluation during training | Pure inference |
Use inference_mode() for deployment and pure inference. Use no_grad() during training evaluation where you might need to create tensors that interact with the training graph later.
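The restriction shows up when an inference-mode tensor is later pulled into tracked computation; a minimal sketch (the exact error text may vary by PyTorch version):
x = torch.randn(3, requires_grad=True)
with torch.inference_mode():
    y = x * 2                     # y is an inference tensor
try:
    (y * x).sum().backward()      # autograd would need to save y for backward -> rejected
except RuntimeError as err:
    print("inference tensor rejected:", err)
y_usable = y.clone()              # cloning outside the context yields a normal tensor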
Detach, Retain Graph, and Stop Gradient
# detach(): creates a tensor that shares data but has no grad_fn
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y.detach() # z has the same values as y, but no connection to x
z.requires_grad # False
# Gradients never flow from z back to x (but z shares storage with y, so avoid in-place edits to z)
# Common use: target networks in RL / self-supervised learning
target = model(x).detach() # Stop gradients from flowing through target
# retain_graph: keep the graph for multiple backward passes
x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
y.backward(retain_graph=True) # Graph is kept (not freed)
print(x.grad) # tensor([...])
y.backward() # Second backward (gradients accumulate!)
print(x.grad) # Previous grad + new grad
# Use case: computing multiple losses that share computation
shared = encoder(input)
loss_1 = criterion_1(head_1(shared))
loss_2 = criterion_2(head_2(shared))
loss_1.backward(retain_graph=True) # Keep graph for loss_2
loss_2.backward() # Now graph is freed
| Pattern | Code | Purpose |
|---|---|---|
| Stop gradient | target = model(x).detach() | Prevent gradient flow (EMA targets, contrastive learning) |
| Logging scalars | losses.append(loss.detach().item()) | Avoid keeping computation graph in memory |
| Straight-through estimator | y_hard = y_soft + (quantize(y_soft) - y_soft).detach() | Discrete forward, continuous backward |
| Gradient penalty | x_interp = (a*real + (1-a)*fake).detach().requires_grad_(True) | Compute gradients w.r.t. interpolated input |
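To spell out the straight-through row in code, the sketch below uses rounding as a stand-in quantizer (an arbitrary choice for illustration): the forward value is the hard output, while the backward pass sees only y_soft:
x = torch.randn(4, requires_grad=True)
y_soft = torch.sigmoid(x)
y_hard = y_soft + (torch.round(y_soft) - y_soft).detach()  # value: rounded; gradient: flows through y_soft
y_hard.sum().backward()
print(x.grad)  # same as sigmoid's gradient, as if the rounding were not there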
Custom Autograd Functions
When you need to define custom forward and backward behavior -- for non-standard operations, memory-efficient backward passes, or interfacing with non-PyTorch code:
class MySiLU(torch.autograd.Function):
    """SiLU (Swish) activation: x * sigmoid(x)."""
    @staticmethod
    def forward(ctx, x):
        sigmoid_x = torch.sigmoid(x)
        ctx.save_for_backward(x, sigmoid_x)  # Save for backward
        return x * sigmoid_x

    @staticmethod
    def backward(ctx, grad_output):
        x, sigmoid_x = ctx.saved_tensors
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        #                       = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        grad_input = grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
        return grad_input  # One output for each input to forward

# Usage (always use .apply, never call forward directly)
x = torch.randn(5, requires_grad=True)
y = MySiLU.apply(x)
y.sum().backward()
print(x.grad)
# Verify correctness with gradcheck
from torch.autograd import gradcheck
x = torch.randn(5, dtype=torch.float64, requires_grad=True)
assert gradcheck(MySiLU.apply, (x,), eps=1e-6)
- Use ctx.save_for_backward() -- not ctx.x = x. PyTorch manages memory efficiently for saved tensors, including releasing them after backward and supporting gradient checkpointing.
- Match the number of returns to forward inputs. backward() must return one gradient for each input to forward(). Use return None for inputs that do not need gradients (e.g., integer arguments); see the sketch after this list.
- Use gradcheck() to verify correctness. It computes numerical gradients and compares them to your analytical backward. Always test with float64 for numerical precision:

  from torch.autograd import gradcheck
  x = torch.randn(5, dtype=torch.float64, requires_grad=True)
  assert gradcheck(MyFunction.apply, (x,), eps=1e-6, atol=1e-4)

- Memory-efficient backward: if the forward pass produces large intermediate tensors, you can recompute them in the backward pass instead of saving them. This trades compute for memory:

  @staticmethod
  def forward(ctx, x):
      ctx.save_for_backward(x)  # Save only the input, not intermediates
      return expensive_computation(x)

  @staticmethod
  def backward(ctx, grad_output):
      x, = ctx.saved_tensors
      # Recompute intermediates from x (costs compute, saves memory)
      intermediate = expensive_computation_part1(x)
      return grad_output * gradient_formula(intermediate)
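A minimal sketch of the one-gradient-per-input rule, using a hypothetical function that takes a tensor and a plain Python float; backward returns None for the non-tensor argument:
class ScaleByConstant(torch.autograd.Function):
    """Hypothetical example: y = x * scale, where scale is a plain float."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale  # non-tensor state can be stored directly on ctx
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # forward had two inputs -> backward returns two values; None for the float
        return grad_output * ctx.scale, None

x = torch.randn(3, requires_grad=True)
ScaleByConstant.apply(x, 2.5).sum().backward()
print(x.grad)  # tensor([2.5000, 2.5000, 2.5000])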
Gradient Checkpointing
Trade compute for memory by recomputing activations during the backward pass instead of storing them:
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DeepModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        # TransformerBlock is assumed to be defined elsewhere
        self.layers = nn.ModuleList([
            TransformerBlock(d_model=1024) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            # Without checkpointing: stores activations for all layers
            # x = layer(x)
            # With checkpointing: only stores the input; recomputes activations in backward
            x = checkpoint(layer, x, use_reentrant=False)
        return x

# Memory savings:
# Without checkpointing: O(L * activation_size) -- linear in depth
# With checkpointing: O(sqrt(L) * activation_size) -- if checkpointed every sqrt(L) layers
# Compute cost: ~30% more (one extra forward pass per checkpointed segment)
Most large model training (GPT, LLaMA, etc.) uses gradient checkpointing. Its roughly 30% extra compute is often offset in wall-clock time, because the memory savings allow larger batches, which improve GPU utilization.
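For models that are a plain stack of layers, torch.utils.checkpoint.checkpoint_sequential applies the sqrt(L)-style segmenting mentioned in the comments above; a minimal sketch, reusing the hypothetical TransformerBlock from the previous example:
import math
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

num_layers = 16
blocks = nn.Sequential(*[TransformerBlock(d_model=1024) for _ in range(num_layers)])
segments = int(math.sqrt(num_layers))  # ~sqrt(L) segments; only segment-boundary activations are stored

def run(x):
    return checkpoint_sequential(blocks, segments, x, use_reentrant=False)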
Autograd Debugging Checklist
| Error | Cause | Fix |
|---|---|---|
| grad can be implicitly created only for scalar outputs | Calling .backward() on a non-scalar tensor | Pass a gradient= argument, or reduce to a scalar first |
| Trying to backward through the graph a second time | Graph freed after the first .backward() | Use retain_graph=True (or restructure the code) |
| one of the variables needed for gradient computation has been modified by an inplace operation | In-place operation on a tensor in the graph | Replace x.add_(1) with x = x + 1 |
| element 0 of tensors does not require grad and does not have a grad_fn | No path from the output to any requires_grad=True tensor | Check that model parameters have requires_grad=True |
| Gradients are all zero | Using .detach() or .data incorrectly | Check that no detach breaks the gradient path |
| Gradients are unexpectedly large | Not zeroing gradients between steps | Add optimizer.zero_grad() before backward |
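When one of these failures (or a NaN gradient) is hard to localize, anomaly detection makes the backward error report the forward operation that produced the offending node; a small sketch, reusing the training-loop names from earlier (it slows execution, so enable it only while debugging):
with torch.autograd.detect_anomaly():
    loss = criterion(model(data), target)
    loss.backward()  # errors now include the stack trace of the forward op responsible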