# Gradient Descent

## The Optimization Problem

Machine learning reduces to empirical risk minimization (ERM): finding parameters $\theta$ that minimize a loss function averaged over a training dataset $\{(x_i, y_i)\}_{i=1}^N$:

$$\theta^* = \arg\min_\theta \mathcal{L}(\theta), \qquad \mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^N \ell(f_\theta(x_i), y_i)$$

For most models of interest (neural networks), $\mathcal{L}$ is non-convex and high-dimensional (billions of parameters), ruling out closed-form solutions. Gradient descent solves this iteratively by moving in the direction of steepest descent.
## Vanilla Gradient Descent

$$\theta_{t+1} = \theta_t - \eta \nabla \mathcal{L}(\theta_t)$$

where $\eta > 0$ is the learning rate (step size).
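A minimal sketch of this update loop on a toy quadratic loss (the matrix `A`, vector `b`, and step size below are illustrative assumptions, not from the text):

```python
import numpy as np

# Vanilla gradient descent on L(theta) = 0.5 * theta^T A theta - b^T theta,
# whose unique minimizer is theta* = A^{-1} b.
A = np.array([[3.0, 0.2], [0.2, 1.0]])
b = np.array([1.0, -1.0])

def grad(theta):
    return A @ theta - b  # gradient of the quadratic

eta = 0.1              # learning rate (step size)
theta = np.zeros(2)
for _ in range(500):
    theta = theta - eta * grad(theta)

# The iterates converge to the closed-form solution A^{-1} b.
print(np.allclose(theta, np.linalg.solve(A, b), atol=1e-6))
```

For a quadratic, convergence requires $\eta < 2/L$ where $L$ is the largest eigenvalue of `A`; here $\eta = 0.1$ is well inside that range.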
## Convergence Analysis

Assume $\mathcal{L}$ is $L$-smooth: $\|\nabla \mathcal{L}(x) - \nabla \mathcal{L}(y)\| \le L \|x - y\|$ for all $x, y$. Equivalently, the loss is bounded by a quadratic around any point:

$$\mathcal{L}(y) \le \mathcal{L}(x) + \nabla \mathcal{L}(x)^\top (y - x) + \frac{L}{2} \|y - x\|^2$$

For twice-differentiable $\mathcal{L}$, the constant $L$ upper-bounds the largest eigenvalue of the Hessian $\nabla^2 \mathcal{L}(\theta)$.

For convex $L$-smooth $\mathcal{L}$, gradient descent with $\eta = 1/L$ satisfies

$$\mathcal{L}(\theta_T) - \mathcal{L}(\theta^*) \le \frac{L \|\theta_0 - \theta^*\|^2}{2T}$$

This is an $O(1/T)$ rate: to achieve $\epsilon$-accuracy, we need $T = O(1/\epsilon)$ iterations.
The condition number $\kappa = L/\mu$ (the ratio of largest to smallest Hessian eigenvalue) determines convergence speed. Poorly conditioned problems ($\kappa \gg 1$) converge slowly because the loss surface is elongated: the iterates oscillate across the narrow (high-curvature) direction while making slow progress along the wide (low-curvature) direction.
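A small numerical illustration of this effect, using two diagonal quadratics with condition numbers 1 and 100 (the setup and step counts are illustrative):

```python
import numpy as np

# GD on L(theta) = 0.5 * sum(h_i * theta_i^2), where h holds the
# Hessian eigenvalues. Step size eta = 1/L is the stable choice.
def run_gd(h, steps=100):
    eta = 1.0 / h.max()              # 1 / largest eigenvalue
    theta = np.ones_like(h)
    for _ in range(steps):
        theta = theta - eta * h * theta
    return 0.5 * np.sum(h * theta**2)  # final loss

well = run_gd(np.array([1.0, 1.0]))    # kappa = 1
ill  = run_gd(np.array([100.0, 1.0]))  # kappa = 100

# The well-conditioned problem converges essentially exactly; the
# ill-conditioned one is still far from its minimum after 100 steps.
print(well < 1e-12, ill > 1e-3)
```

The slow direction contracts by only $(1 - \mu/L)$ per step, so with $\kappa = 100$ it retains a factor of roughly $0.99^{100} \approx 0.37$ of its initial error after 100 iterations.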
## Stochastic Gradient Descent (SGD)

Computing the full gradient costs $O(N)$ per step. SGD instead uses a mini-batch $\mathcal{B}_t$ of $B$ samples:

$$\theta_{t+1} = \theta_t - \eta \, g_t, \qquad g_t = \frac{1}{B} \sum_{i \in \mathcal{B}_t} \nabla \ell(f_\theta(x_i), y_i)$$

Properties of the stochastic gradient $g_t$:

- Unbiased: $\mathbb{E}[g_t] = \nabla \mathcal{L}(\theta_t)$
- Bounded variance: $\mathbb{E}\|g_t - \nabla \mathcal{L}(\theta_t)\|^2 \le \sigma^2 / B$, where $\sigma^2$ is the per-sample gradient variance
- Computational cost: $O(B)$ per step instead of $O(N)$
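A quick empirical sanity check of the $1/B$ variance scaling (the linear-regression setup, sample counts, and batch sizes are illustrative assumptions):

```python
import numpy as np

# Estimate the variance of the mini-batch gradient at a fixed theta
# for two batch sizes; the ratio should be roughly 128 / 8 = 16.
rng = np.random.default_rng(0)
N = 10_000
x = rng.normal(size=N)
y = 3.0 * x + rng.normal(size=N)
theta = 0.0

per_sample_grads = 2 * (theta * x - y) * x  # d/dtheta of (theta*x - y)^2
full_grad = per_sample_grads.mean()         # exact empirical-risk gradient

def batch_grad(B):
    idx = rng.integers(0, N, size=B)        # sample a mini-batch
    return per_sample_grads[idx].mean()

var_small = np.var([batch_grad(8) for _ in range(2000)])
var_large = np.var([batch_grad(128) for _ in range(2000)])
print(var_small / var_large)  # roughly 16
```

Each mini-batch gradient averages over its batch, so its variance is the per-sample variance divided by $B$; the printed ratio fluctuates around 16 due to sampling noise.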
For $L$-smooth non-convex $\mathcal{L}$, SGD with step size $\eta \propto 1/\sqrt{T}$ achieves $\min_{t \le T} \mathbb{E}\|\nabla \mathcal{L}(\theta_t)\|^2 = O(1/\sqrt{T})$. This is an $O(1/\sqrt{T})$ rate to a stationary point (where $\nabla \mathcal{L}(\theta) = 0$). Note: for non-convex functions, a stationary point may be a saddle point or local minimum, not necessarily a global minimum.
- SGD noise helps escape sharp minima (high curvature) in favor of flat minima (low curvature), which tend to generalize better (Keskar et al., 2017).
- The noise scale is proportional to $\eta / B$ (learning rate divided by batch size), which is why the linear scaling rule (Goyal et al., 2017) prescribes scaling $\eta$ proportionally to $B$.
- Large-batch training (small noise) can converge to sharp minima that generalize poorly, unless carefully tuned with warmup and learning rate schedules.
## Momentum

SGD oscillates in directions with high curvature (large eigenvalues of the Hessian) while making slow progress along directions with low curvature. Momentum mitigates this by accumulating a velocity vector:

$$v_{t+1} = \beta v_t + \nabla \mathcal{L}(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta \, v_{t+1}$$

where $\beta \in [0, 1)$ is the momentum coefficient (typically $\beta = 0.9$). The velocity is an exponential moving average of past gradients with effective window size $1/(1 - \beta)$.
Nesterov accelerated gradient evaluates the gradient at the "look-ahead" position $\theta_t - \eta \beta v_t$:

$$v_{t+1} = \beta v_t + \nabla \mathcal{L}(\theta_t - \eta \beta v_t), \qquad \theta_{t+1} = \theta_t - \eta \, v_{t+1}$$

Nesterov momentum achieves the optimal $O(1/T^2)$ convergence rate for smooth convex functions, compared to $O(1/T)$ for vanilla GD (Nesterov, 1983). In practice, the improvement over classical momentum is often modest for deep learning.
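A sketch comparing classical (heavy-ball) momentum against vanilla GD on an ill-conditioned quadratic; the Hessian, step size, and iteration count are illustrative choices:

```python
import numpy as np

# Diagonal quadratic L(theta) = 0.5 * sum(h * theta^2), kappa = 100.
h = np.array([100.0, 1.0])
eta, beta = 0.01, 0.9

# Heavy-ball momentum: accumulate a velocity, then step along it.
theta, v = np.ones(2), np.zeros(2)
for _ in range(200):
    g = h * theta            # gradient of the quadratic
    v = beta * v + g         # exponential average of past gradients
    theta = theta - eta * v
loss_momentum = 0.5 * np.sum(h * theta**2)

# Vanilla GD with the same step size, for comparison.
theta = np.ones(2)
for _ in range(200):
    theta = theta - eta * (h * theta)
loss_vanilla = 0.5 * np.sum(h * theta**2)

print(loss_momentum < loss_vanilla)  # momentum reaches a lower loss
```

Momentum averages out the oscillation along the stiff direction while compounding progress along the shallow one, which is exactly the failure mode of vanilla GD described above.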
## Adam (Adaptive Moment Estimation)

Adam maintains exponential moving averages of the gradient and its elementwise square:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

Bias-corrected estimates (since $m_0 = v_0 = 0$):

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

Parameter update:

$$\theta_{t+1} = \theta_t - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.
- Parameters with large gradient variance (noisy signal) have large $\hat{v}_t$, so their effective learning rate is smaller -- Adam is cautious where the signal is noisy.
- Parameters with small gradient variance (consistent signal) have small $\hat{v}_t$, so their effective learning rate is larger -- Adam makes confident updates.
- The bias correction is essential in the first few steps: without it, $m_t$ and $v_t$ are biased toward zero because they are initialized at zero and the exponential average has not converged.
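The three update equations can be sketched as a single Adam step; the toy gradient below, with a 100x scale gap between coordinates, illustrates the per-parameter normalization:

```python
import numpy as np

# One Adam step with bias correction, using the default hyperparameters.
beta1, beta2, eps, eta = 0.9, 0.999, 1e-8, 1e-3

def adam_step(theta, g, m, v, t):
    m = beta1 * m + (1 - beta1) * g         # first moment (EMA of g)
    v = beta2 * v + (1 - beta2) * g**2      # second moment (EMA of g^2)
    m_hat = m / (1 - beta1**t)              # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - eta * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = np.zeros(2), np.zeros(2), np.zeros(2)
g = np.array([10.0, 0.1])   # gradients of very different magnitude
theta, m, v = adam_step(theta, g, m, v, t=1)
print(theta)  # both coordinates move by about eta, despite the 100x gap
```

At $t = 1$ the bias correction makes $\hat{m}_1 = g_1$ and $\hat{v}_1 = g_1^2$ exactly, so the first update is approximately $\eta \cdot \mathrm{sign}(g)$ regardless of gradient scale.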
AdamW is the standard optimizer for transformer training. Typical hyperparameters for large-scale pretraining: $\beta_1 = 0.9$, $\beta_2 = 0.95$, $\epsilon = 10^{-8}$, weight decay $0.1$.
## Other Optimizers

| Optimizer | Update Rule (Simplified) | Memory | Key Property |
|---|---|---|---|
| SGD | $\theta \leftarrow \theta - \eta g$ | none extra | Simple, good generalization |
| SGD + Momentum | $v \leftarrow \beta v + g$; $\theta \leftarrow \theta - \eta v$ | $1\times$ params | Dampens oscillations |
| Adam | Uses $\hat{m}_t, \hat{v}_t$ as above | $2\times$ params | Adaptive per-parameter rates |
| AdamW | Adam + decoupled weight decay | $2\times$ params | Standard for transformers |
| Adafactor | Factored second moments | sublinear | Memory-efficient for large matrices |
| LAMB | Adam + layerwise LR scaling | $2\times$ params | Large-batch training |
| Lion | Sign-based momentum update | $1\times$ params | Memory-efficient, empirically strong |
| Muon | Orthogonalized momentum | $1\times$ params | Emerging; strong empirical results |
## Learning Rate Schedules

The learning rate is rarely constant. A schedule $\eta_t$ adapts it over training:

| Schedule | Formula | Typical Use |
|---|---|---|
| Constant | $\eta_t = \eta_0$ | Baselines, short fine-tuning |
| Step decay | $\eta_t = \eta_0 \gamma^{\lfloor t/s \rfloor}$ | CNNs (ResNet); $\gamma = 0.1$ at 30/60/90 epochs |
| Cosine annealing | $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})(1 + \cos(\pi t / T))$ | Transformer pretraining (Loshchilov & Hutter, 2017) |
| Linear warmup + cosine | Linear ramp to $\eta_0$ over $t_w$ steps, then cosine decay | Standard for LLM pretraining |
| Warmup + inverse sqrt | $\eta_t \propto \min(t^{-1/2},\; t \cdot t_w^{-3/2})$ | Original Transformer (Vaswani et al., 2017) |
| WSD (Warmup-Stable-Decay) | Warmup $\to$ constant $\to$ cosine decay | Practical for unknown training length |
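The linear-warmup + cosine-decay schedule from the table can be sketched as follows; the step counts, peak, and floor values are illustrative assumptions:

```python
import math

# Learning rate as a function of step: linear ramp to `peak` over
# `warmup` steps, then cosine decay from `peak` down to `floor`.
def lr_at(step, total=10_000, warmup=500, peak=3e-4, floor=3e-5):
    if step < warmup:
        return peak * step / warmup                 # linear warmup
    frac = (step - warmup) / (total - warmup)       # decay progress in [0, 1]
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * frac))

print(lr_at(0), lr_at(500), lr_at(10_000))  # 0, peak, floor
```

The schedule starts at zero (stabilizing the first Adam steps, whose moment estimates are still noisy), peaks at the end of warmup, and anneals smoothly to the floor.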
## Gradient Clipping

$$g \leftarrow g \cdot \min\left(1, \frac{c}{\|g\|}\right)$$

where $c$ is the maximum allowed gradient norm (typically $c = 1.0$). This preserves the gradient direction but bounds its magnitude.
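A sketch of the clipping rule (the helper name `clip_by_norm` is mine, not from a library):

```python
import numpy as np

# Rescale the gradient only when its norm exceeds the threshold c,
# keeping its direction unchanged.
def clip_by_norm(g, c=1.0):
    norm = np.linalg.norm(g)
    return g * min(1.0, c / norm) if norm > 0 else g

g = np.array([3.0, 4.0])        # norm 5, exceeds the threshold
clipped = clip_by_norm(g, c=1.0)
print(np.isclose(np.linalg.norm(clipped), 1.0))  # True: rescaled to norm c
```

Gradients already inside the ball ($\|g\| \le c$) pass through untouched; only outliers are rescaled, which is why clipping changes the trajectory far less than shrinking the learning rate would.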
## Convergence Rate Summary

| Setting | Algorithm | Rate | Metric |
|---|---|---|---|
| $L$-smooth convex | GD ($\eta = 1/L$) | $O(1/T)$ | $\mathcal{L}(\theta_T) - \mathcal{L}(\theta^*)$ |
| $L$-smooth, $\mu$-strongly convex | GD ($\eta = 1/L$) | $O((1 - \mu/L)^T)$ | $\|\theta_T - \theta^*\|^2$ |
| $L$-smooth convex | Nesterov | $O(1/T^2)$ | $\mathcal{L}(\theta_T) - \mathcal{L}(\theta^*)$ |
| $L$-smooth convex | SGD ($\eta_t \propto 1/\sqrt{t}$) | $O(1/\sqrt{T})$ | $\mathbb{E}[\mathcal{L}(\bar{\theta}_T)] - \mathcal{L}(\theta^*)$ |
| $L$-smooth non-convex | SGD ($\eta \propto 1/\sqrt{T}$) | $O(1/\sqrt{T})$ | $\min_t \mathbb{E}\|\nabla \mathcal{L}(\theta_t)\|^2$ |
## Notation Summary

| Symbol | Meaning |
|---|---|
| $\theta$ | Model parameters |
| $\eta$ | Learning rate (possibly time-varying) |
| $\mathcal{L}(\theta)$ | Loss function (empirical risk) |
| $\ell(f_\theta(x_i), y_i)$ | Per-sample loss |
| $g_t$ | Gradient (or stochastic gradient) at step $t$ |
| $v_t$ | Velocity (momentum) or second moment (Adam) |
| $m_t$ | First moment estimate (Adam) |
| $\hat{m}_t, \hat{v}_t$ | Bias-corrected moment estimates |
| $\beta_1, \beta_2$ | Exponential decay rates (Adam) |
| $\beta$ | Momentum coefficient |
| $L$ | Lipschitz constant of the gradient ($L$-smoothness) |
| $\mu$ | Strong convexity parameter |
| $\kappa = L/\mu$ | Condition number |
| $B = \lvert \mathcal{B} \rvert$ | Mini-batch size |
| $\sigma^2$ | Per-sample gradient variance |
| $c$ | Gradient clipping threshold |