
Gradient Descent

The Optimization Problem

Most of supervised machine learning reduces to empirical risk minimization (ERM): finding parameters $\theta$ that minimize a loss function averaged over a training dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N$:

$$\theta^* = \arg\min_\theta \; \mathcal{L}(\theta) = \arg\min_\theta \; \frac{1}{N} \sum_{i=1}^{N} \ell(f_\theta(x_i), y_i)$$

For most models of interest (neural networks), $\mathcal{L}(\theta)$ is non-convex and high-dimensional (billions of parameters), ruling out closed-form solutions. Gradient descent solves this iteratively by moving in the direction of steepest descent.

Vanilla Gradient Descent

Starting from an initial $\theta_0$, the update rule computes the gradient over the full dataset and takes a step in the negative gradient direction:

$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t) = \theta_t - \frac{\eta}{N} \sum_{i=1}^{N} \nabla_\theta \ell(f_{\theta_t}(x_i), y_i)$$

where $\eta > 0$ is the learning rate (step size).

**Geometric interpretation.** At each point $\theta_t$, the gradient $\nabla \mathcal{L}(\theta_t)$ is a vector pointing in the direction of steepest *increase* of $\mathcal{L}$. Moving in the direction $-\nabla \mathcal{L}$ is the locally optimal descent direction. However, "locally optimal" can be misleading: the gradient only gives a good direction for infinitesimally small steps. For finite step sizes, the curvature of $\mathcal{L}$ matters.
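The update rule can be sketched in a few lines of NumPy. This is an illustrative example, not from the text: the quadratic loss, step size, and iteration count are arbitrary choices.

```python
import numpy as np

# Quadratic loss L(theta) = 0.5 * theta^T A theta, minimized at theta = 0.
A = np.diag([1.0, 10.0])            # Hessian; largest eigenvalue gives L = 10
grad = lambda theta: A @ theta      # gradient of the quadratic

eta = 1.0 / 10.0                    # step size eta = 1/L
theta = np.array([5.0, 5.0])        # initial point theta_0
for _ in range(500):
    theta = theta - eta * grad(theta)  # theta_{t+1} = theta_t - eta * grad

print(np.linalg.norm(theta))        # essentially 0: converged to the minimizer
```

Each iteration shrinks the error along every eigendirection of $A$ by a factor $(1 - \eta \lambda_i)$, so with $\eta = 1/L$ the slowest direction contracts by $(1 - \lambda_{\min}/L)$ per step.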

Convergence Analysis

A differentiable function $\mathcal{L}$ has **$L$-Lipschitz continuous gradients** (is $L$-smooth) if:

$$\|\nabla \mathcal{L}(\theta) - \nabla \mathcal{L}(\theta')\| \leq L \|\theta - \theta'\| \quad \forall \, \theta, \theta'$$

This implies the descent lemma: the loss is bounded above by a quadratic around any point:

$$\mathcal{L}(\theta') \leq \mathcal{L}(\theta) + \nabla \mathcal{L}(\theta)^\top (\theta' - \theta) + \frac{L}{2}\|\theta' - \theta\|^2$$

For twice-differentiable $\mathcal{L}$, the smallest such $L$ is the supremum of the spectral norm of the Hessian $\nabla^2 \mathcal{L}$ (its largest absolute eigenvalue).

For an $L$-smooth convex function, gradient descent with step size $\eta = 1/L$ satisfies:

$$\mathcal{L}(\theta_T) - \mathcal{L}(\theta^*) \leq \frac{L \|\theta_0 - \theta^*\|^2}{2T}$$

This is an $O(1/T)$ rate: to achieve $\epsilon$-accuracy, we need $T = O(L \|\theta_0 - \theta^*\|^2 / \epsilon)$ iterations.

If $\mathcal{L}$ is additionally **$\mu$-strongly convex** ($\nabla^2 \mathcal{L} \succeq \mu I$), gradient descent with $\eta = 1/L$ converges linearly:

$$\mathcal{L}(\theta_T) - \mathcal{L}(\theta^*) \leq \left(1 - \frac{\mu}{L}\right)^T \left(\mathcal{L}(\theta_0) - \mathcal{L}(\theta^*)\right)$$

The condition number $\kappa = L/\mu$ determines convergence speed. Poorly conditioned problems ($\kappa \gg 1$) converge slowly because the loss surface is elongated: the gradient oscillates across the narrow direction while making slow progress along the wide direction.
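The effect of $\kappa$ can be checked numerically. In this sketch (the 2-D quadratic, tolerance, and starting point are arbitrary choices), iteration counts grow roughly linearly in $\kappa$, as the $O(\kappa \log 1/\epsilon)$ bound predicts:

```python
import numpy as np

def gd_iters(kappa, tol=1e-6, max_iter=100_000):
    """Iterations for GD with eta = 1/L to reach ||theta|| <= tol
    on a 2-D quadratic with eigenvalues (mu, L) = (1, kappa)."""
    A = np.diag([1.0, float(kappa)])
    theta = np.array([1.0, 1.0])
    eta = 1.0 / kappa
    for t in range(1, max_iter + 1):
        theta = theta - eta * (A @ theta)
        if np.linalg.norm(theta) <= tol:
            return t
    return max_iter

print(gd_iters(10), gd_iters(100))  # the second needs roughly 10x more steps
```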

Stochastic Gradient Descent (SGD)

SGD approximates the full gradient with a mini-batch $\mathcal{B} \subset \mathcal{D}$ of size $B = |\mathcal{B}|$, sampled uniformly:

$$\theta_{t+1} = \theta_t - \eta_t \cdot g_t, \quad g_t = \frac{1}{B} \sum_{i \in \mathcal{B}_t} \nabla_\theta \ell(f_{\theta_t}(x_i), y_i)$$

Properties of the stochastic gradient $g_t$:

  1. Unbiased: $\mathbb{E}_{\mathcal{B}}[g_t] = \nabla \mathcal{L}(\theta_t)$
  2. Bounded variance: $\mathbb{E}\|g_t - \nabla \mathcal{L}(\theta_t)\|^2 \leq \sigma^2 / B$, where $\sigma^2$ is the per-sample gradient variance
  3. Computational cost: $O(B)$ per step instead of $O(N)$

For an $L$-smooth function with stochastic gradients having variance bounded by $\sigma^2$, SGD with step size $\eta = c / \sqrt{T}$ satisfies:

$$\frac{1}{T}\sum_{t=0}^{T-1} \mathbb{E}\|\nabla \mathcal{L}(\theta_t)\|^2 \leq O\!\left(\frac{\sigma}{\sqrt{T}} + \frac{1}{T}\right)$$

This is an $O(1/\sqrt{T})$ rate to a stationary point (where $\nabla \mathcal{L} \approx 0$). Note: for non-convex functions, a stationary point may be a saddle point or a local minimum, not necessarily a global minimum.

**SGD noise as implicit regularization.** The stochastic noise in mini-batch gradients serves as implicit regularization. Empirical and theoretical evidence suggests:
  1. SGD noise helps escape sharp minima (high curvature) in favor of flat minima (low curvature), which tend to generalize better [@keskar2017large].
  2. The noise scale is proportional to $\eta / B$ (learning rate divided by batch size), which is why the linear scaling rule [@goyal2017accurate] prescribes scaling $\eta$ proportionally to $B$.
  3. Large-batch training (small noise) can converge to sharp minima that generalize poorly, unless carefully tuned with warmup and learning rate schedules.
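A minimal mini-batch SGD loop on least-squares regression. The synthetic data, batch size, learning rate, and step count here are illustrative choices, not prescriptions:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, B = 1000, 5, 32
X = rng.normal(size=(N, d))
w_true = rng.normal(size=d)
y = X @ w_true + 0.01 * rng.normal(size=N)      # small label noise

theta, eta = np.zeros(d), 0.1
for _ in range(2000):
    idx = rng.choice(N, size=B, replace=False)  # sample mini-batch B_t
    Xb, yb = X[idx], y[idx]
    g = Xb.T @ (Xb @ theta - yb) / B            # unbiased estimate of the full gradient
    theta -= eta * g

print(np.max(np.abs(theta - w_true)))           # small: theta is near w_true
```

Each step touches only $B = 32$ of the $N = 1000$ samples, illustrating the $O(B)$ vs. $O(N)$ cost trade-off; the iterates hover near the minimizer in a noise ball whose size scales with $\eta/B$.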

Momentum

SGD oscillates in directions with high curvature (large eigenvalues of the Hessian) while making slow progress along directions with low curvature. Momentum fixes this by accumulating a velocity vector:

$$v_{t+1} = \mu \, v_t + g_t, \qquad \theta_{t+1} = \theta_t - \eta \, v_{t+1}$$

where $\mu \in [0, 1)$ is the momentum coefficient (typically $\mu = 0.9$). The velocity $v_t$ is an exponential moving average of past gradients with effective window size $\sim 1/(1-\mu)$.

**Physical analogy.** Think of a ball rolling down the loss surface. Without momentum, the ball stops immediately when the slope changes direction (oscillation). With momentum, the ball has inertia: it accelerates along consistent downhill directions and dampens oscillations. On a quadratic with condition number $\kappa$, the optimal momentum $\mu^* = \left((\sqrt{\kappa} - 1)/(\sqrt{\kappa} + 1)\right)^2$ improves convergence from $O(\kappa \log 1/\epsilon)$ to $O(\sqrt{\kappa} \log 1/\epsilon)$ iterations.

Nesterov accelerated gradient evaluates the gradient at the "look-ahead" position $\theta_t - \eta \mu v_t$:

$$v_{t+1} = \mu \, v_t + \nabla \mathcal{L}(\theta_t - \eta \mu v_t), \qquad \theta_{t+1} = \theta_t - \eta \, v_{t+1}$$

Nesterov momentum achieves the optimal $O(1/T^2)$ convergence rate for smooth convex functions, compared to $O(1/T)$ for vanilla GD [@nesterov1983method]. In practice, the improvement over classical momentum is often modest for deep learning.
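A quick comparison of heavy-ball momentum against plain GD on an ill-conditioned quadratic ($\kappa = 100$; the specific matrix, momentum value, and iteration budget are arbitrary illustrative choices):

```python
import numpy as np

A = np.diag([1.0, 100.0])          # kappa = 100
grad = lambda th: A @ th
eta, mu = 1.0 / 100.0, 0.9         # eta = 1/L, typical momentum coefficient

theta_m, v = np.array([1.0, 1.0]), np.zeros(2)
theta_gd = np.array([1.0, 1.0])
for _ in range(300):
    v = mu * v + grad(theta_m)     # accumulate velocity (EMA of past gradients)
    theta_m = theta_m - eta * v
    theta_gd = theta_gd - eta * grad(theta_gd)   # plain GD baseline

# Momentum fixes the slow crawl along the low-curvature direction.
print(np.linalg.norm(theta_m), np.linalg.norm(theta_gd))
```

Plain GD is stuck contracting the flat direction by $(1 - 1/\kappa) = 0.99$ per step, while the momentum iterate converges orders of magnitude closer in the same budget.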

Adam (Adaptive Moment Estimation)

**Adam** [@kingma2015adam] combines momentum with per-parameter adaptive learning rates by tracking exponential moving averages of the first and second moments of the gradient:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \qquad \text{(first moment / mean)}$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \qquad \text{(second moment / uncentered variance)}$$

Bias-corrected estimates (since $m_0 = v_0 = 0$):

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

Parameter update:

$$\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.

**Why Adam works.** The effective learning rate for parameter $\theta_i$ at step $t$ is $\eta / (\sqrt{\hat{v}_{t,i}} + \epsilon)$:
  • Parameters with large gradient variance (noisy signal) have large $\hat{v}_{t,i}$, so their effective learning rate is smaller -- Adam is cautious where the signal is noisy.
  • Parameters with small gradient variance (consistent signal) have small $\hat{v}_{t,i}$, so their effective learning rate is larger -- Adam makes confident updates.
  • The bias correction is essential in the first few steps: without it, $m_t$ and $v_t$ are biased toward zero because they are initialized at zero and the exponential average has not converged.

**AdamW (decoupled weight decay)** [@loshchilov2019adamw]. In standard Adam, L2 regularization adds $\lambda \theta$ to the gradient, which then gets divided by $\sqrt{\hat{v}_t}$. This means the effective regularization strength varies per parameter -- defeating the purpose of uniform weight decay. AdamW fixes this by applying weight decay *directly* to the parameters, outside the adaptive learning rate:

$$\theta_{t+1} = (1 - \eta \lambda) \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

AdamW is the standard optimizer for transformer training. Typical hyperparameters: $\eta \in [10^{-5}, 10^{-3}]$, $\beta_1 = 0.9$, $\beta_2 \in [0.95, 0.999]$, weight decay $\lambda \in [0.01, 0.1]$.
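The AdamW update can be sketched as a pure-NumPy step function. Hyperparameter defaults follow the text; the function name and the toy quadratic used to exercise it are illustrative assumptions:

```python
import numpy as np

def adamw_step(theta, g, m, v, t, eta=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, wd=0.01):
    """One AdamW update; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g        # first-moment EMA
    v = beta2 * v + (1 - beta2) * g**2     # second-moment EMA
    m_hat = m / (1 - beta1**t)             # bias correction
    v_hat = v / (1 - beta2**t)
    theta = (1 - eta * wd) * theta - eta * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Drive a toy quadratic L(theta) = ||theta||^2 (gradient 2*theta) toward zero.
theta = np.array([1.0, -2.0])
m, v = np.zeros(2), np.zeros(2)
for t in range(1, 5001):
    theta, m, v = adamw_step(theta, 2.0 * theta, m, v, t, eta=1e-2)

print(np.linalg.norm(theta))  # small: near the minimum
```

Note that the decay term $(1 - \eta \lambda)\theta$ multiplies the parameters directly and never passes through the $\sqrt{\hat{v}_t}$ rescaling, which is exactly the decoupling the text describes.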

Other Optimizers

| Optimizer | Update Rule (Simplified) | Memory | Key Property |
|---|---|---|---|
| SGD | $\theta \leftarrow \theta - \eta g$ | $O(P)$ | Simple, good generalization |
| SGD + Momentum | $v \leftarrow \mu v + g$; $\theta \leftarrow \theta - \eta v$ | $O(2P)$ | Dampens oscillations |
| Adam | Uses $m_t, v_t$ as above | $O(3P)$ | Adaptive per-parameter rates |
| AdamW | Adam + decoupled weight decay | $O(3P)$ | Standard for transformers |
| Adafactor | Factored second moments | $O(P + \text{rows} + \text{cols})$ | Memory-efficient for large matrices |
| LAMB | Adam + layerwise LR scaling | $O(3P)$ | Large-batch training |
| Lion | Sign-based momentum update | $O(2P)$ | Memory-efficient, empirically strong |
| Muon | Orthogonalized momentum | $O(2P)$ | Emerging; strong empirical results |

Learning Rate Schedules

The learning rate $\eta_t$ is rarely constant. A schedule adapts it over training:

| Schedule | Formula | Typical Use |
|---|---|---|
| Constant | $\eta_t = \eta_0$ | Baselines, short fine-tuning |
| Step decay | $\eta_t = \eta_0 \cdot \gamma^{\lfloor t/s \rfloor}$ | CNNs (ResNet); $\gamma = 0.1$, $s$ at 30/60/90 epochs |
| Cosine annealing | $\eta_t = \eta_{\min} + \frac{1}{2}(\eta_0 - \eta_{\min})(1 + \cos(\pi t / T))$ | Transformer pretraining [@loshchilov2017sgdr] |
| Linear warmup + cosine | $\eta_t = \eta_0 \cdot \min(t/w, \; \frac{1}{2}(1+\cos(\pi(t-w)/(T-w))))$ | Standard for LLM pretraining |
| Warmup + inverse sqrt | $\eta_t = \eta_0 \cdot \min(t/w, \; \sqrt{w/t})$ | Original Transformer (Vaswani et al., 2017) |
| WSD (Warmup-Stable-Decay) | Warmup $\to$ constant $\to$ cosine decay | Practical for unknown training length |

**Why warmup?** In the first few steps, Adam's second-moment estimates $\hat{v}_t$ are based on very few samples and are unreliable. Large learning rates during this phase cause wild updates. Linear warmup (ramping $\eta$ from 0 to $\eta_0$ over $w$ steps) gives the adaptive estimates time to stabilize. Empirically, $w$ = 1--5% of total training steps works well.
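Linear warmup followed by cosine decay can be written as a small helper. The function name and the 2% default warmup are illustrative choices, consistent with the 1--5% guidance above:

```python
import math

def lr_at(step, total_steps, base_lr=3e-4, warmup_steps=None, min_lr=0.0):
    """Linear warmup from 0 to base_lr, then cosine decay to min_lr."""
    if warmup_steps is None:
        warmup_steps = max(1, total_steps // 50)       # ~2% of training
    if step < warmup_steps:
        return base_lr * step / warmup_steps           # linear ramp
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Peak LR is reached exactly at the end of warmup; decay ends at min_lr.
print(lr_at(20, 1000), lr_at(1000, 1000))
```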

Gradient Clipping

To prevent exploding gradients, clip the gradient norm before applying the update:

$$g_t \leftarrow g_t \cdot \min\!\left(1, \; \frac{c}{\|g_t\|}\right)$$

where $c$ is the maximum allowed gradient norm (typically $c = 1.0$). This preserves the gradient direction but bounds its magnitude.

Gradient clipping is essential for training transformers and RNNs. Without it, a single batch with an unusually large loss can produce a gradient spike that destabilizes all the optimizer state (Adam's running averages get corrupted). A typical setting is `max_grad_norm=1.0` in both PyTorch and JAX training loops.
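A NumPy sketch of norm clipping (the helper name is illustrative; frameworks provide equivalents such as PyTorch's `torch.nn.utils.clip_grad_norm_`):

```python
import numpy as np

def clip_grad_norm(g, max_norm=1.0):
    """Rescale g so that ||g|| <= max_norm; the direction is unchanged."""
    norm = np.linalg.norm(g)
    if norm > max_norm:
        g = g * (max_norm / norm)
    return g

g = np.array([3.0, 4.0])        # norm 5 exceeds the threshold c = 1.0
clipped = clip_grad_norm(g)     # scaled down to (approximately) unit norm
print(np.linalg.norm(clipped))
```

Gradients already inside the threshold pass through untouched, so clipping only activates on spikes.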

Convergence Rate Summary

| Setting | Algorithm | Rate | Metric |
|---|---|---|---|
| $L$-smooth convex | GD ($\eta = 1/L$) | $O(1/T)$ | $\mathcal{L}(\theta_T) - \mathcal{L}^*$ |
| $L$-smooth, $\mu$-strongly convex | GD ($\eta = 1/L$) | $O((1-\mu/L)^T)$ | $\mathcal{L}(\theta_T) - \mathcal{L}^*$ |
| $L$-smooth convex | Nesterov | $O(1/T^2)$ | $\mathcal{L}(\theta_T) - \mathcal{L}^*$ |
| $L$-smooth convex | SGD ($\eta \propto 1/\sqrt{T}$) | $O(1/\sqrt{T})$ | $\mathbb{E}[\mathcal{L}(\theta_T) - \mathcal{L}^*]$ |
| $L$-smooth non-convex | SGD ($\eta \propto 1/\sqrt{T}$) | $O(1/\sqrt{T})$ | $\frac{1}{T}\sum \mathbb{E}\Vert\nabla \mathcal{L}\Vert^2$ |

These rates are worst-case. In practice, deep learning loss surfaces have favorable structure (saddle points are easily escaped, local minima are often near-global) that makes convergence faster than theoretical bounds suggest. The rates are most useful for comparing algorithms and understanding scaling behavior (e.g., doubling the number of SGD steps gives $\sqrt{2}\times$ improvement).

Notation Summary

| Symbol | Meaning |
|---|---|
| $\theta$ | Model parameters |
| $\eta, \eta_t$ | Learning rate (possibly time-varying) |
| $\mathcal{L}$ | Loss function (empirical risk) |
| $\ell$ | Per-sample loss |
| $g_t$ | Gradient (or stochastic gradient) at step $t$ |
| $v_t$ | Velocity (momentum) or second moment (Adam) |
| $m_t$ | First moment estimate (Adam) |
| $\hat{m}_t, \hat{v}_t$ | Bias-corrected moment estimates |
| $\beta_1, \beta_2$ | Exponential decay rates (Adam) |
| $\mu$ | Momentum coefficient |
| $L$ | Lipschitz constant of the gradient ($L$-smoothness) |
| $\mu$ (strong convexity) | Strong convexity parameter |
| $\kappa = L/\mu$ | Condition number |
| $B = \lvert \mathcal{B} \rvert$ | Mini-batch size |
| $\sigma^2$ | Per-sample gradient variance |
| $c$ | Gradient clipping threshold |

References