
Non-Convex Optimization

Neural network loss functions are non-convex in their parameters. Despite the lack of convexity guarantees, stochastic gradient descent (SGD) reliably finds good solutions. Understanding why requires studying the geometry of the loss landscape and the implicit biases of optimization algorithms.

The Non-Convex Landscape of Deep Learning

A network with $P$ parameters defines a loss surface $\mathcal{L}: \mathbb{R}^P \to \mathbb{R}$. This surface has an astronomical number of critical points (where $\nabla \mathcal{L} = 0$), and most of them are saddle points, not minima.

Saddle Points vs. Local Minima

A critical point $\theta^*$ (where $\nabla \mathcal{L}(\theta^*) = 0$) is a **saddle point** if the Hessian $H = \nabla^2 \mathcal{L}(\theta^*)$ has both positive and negative eigenvalues. The number of negative eigenvalues is called the **index** of the saddle point. **Saddle points dominate in high dimensions.** Being a local minimum requires *every* Hessian eigenvalue to be positive. For the Hessian of a high-dimensional random landscape, the eigenvalues are strongly correlated (their joint distribution resembles a random-matrix ensemble, not independent coin flips), so this is a rare large-deviation event. The probability that all $P$ eigenvalues are positive decays like

$$P(\text{local minimum}) \sim e^{-c P^2}$$

for a positive constant $c$: the exponent grows with $P^2$, not $P$. The practical takeaway, supported empirically in deep networks, is that the overwhelming majority of high-loss critical points are saddle points rather than local minima, and that minima become exponentially rarer as dimension grows (Dauphin et al., 2014).
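
To make the index and the eigenvalue-counting argument concrete, here is a small NumPy sketch. It classifies a critical point by counting negative Hessian eigenvalues, then estimates how often a random symmetric Gaussian matrix has all eigenvalues positive. Using GOE-like random matrices as stand-in "Hessians" is an assumption for illustration (real network Hessians are not GOE), and the function names are hypothetical.

```python
import numpy as np


def classify_critical_point(hessian: np.ndarray) -> tuple[str, int]:
    """Classify a nondegenerate critical point from its symmetric Hessian.

    Returns (kind, index), where index counts the negative eigenvalues.
    Zero eigenvalues are not treated specially in this sketch.
    """
    eigs = np.linalg.eigvalsh(hessian)
    index = int(np.sum(eigs < 0))
    if index == 0:
        return "local minimum", index
    if index == eigs.size:
        return "local maximum", index
    return "saddle point", index


def fraction_all_positive(p: int, trials: int = 20_000, seed: int = 0) -> float:
    """Monte Carlo estimate of P(all p eigenvalues > 0) for random symmetric
    Gaussian (GOE-like) p x p matrices, used here as stand-in "Hessians"."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((trials, p, p))
    h = (a + np.swapaxes(a, 1, 2)) / 2.0      # symmetrize each sample
    smallest = np.linalg.eigvalsh(h)[:, 0]    # eigvalsh returns ascending eigenvalues
    return float(np.mean(smallest > 0))


print(classify_critical_point(np.diag([2.0, -1.0, 3.0])))  # ('saddle point', 1)
for p in (1, 2, 4, 8):
    print(p, fraction_all_positive(p))
```

The estimated fraction starts near one half for $p = 1$ and falls off far faster than geometrically as $p$ grows, in line with the $e^{-cP^2}$ scaling; by $p = 8$ positive-definite samples are already vanishingly rare.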

The practical implication: **the main obstacle to optimization is not bad local minima but saddle points and plateaus.** SGD's stochastic noise naturally helps escape saddle points because the noise has a component along the negative curvature directions, pushing the iterate away from the saddle. Formal results show that perturbed gradient descent escapes strict saddle points in polynomial time, with only a polylogarithmic dependence on the dimension (Jin et al., 2017).
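
A toy illustration of noise-driven saddle escape, as a sketch: the loss $f(x, y) = \tfrac{1}{2}x^2 - \tfrac{1}{2}y^2 + \tfrac{1}{4}y^4$ (chosen only for illustration) has a strict saddle at the origin, with Hessian eigenvalues $+1$ and $-1$, and minima at $(0, \pm 1)$. Started on the line $y = 0$, plain GD never leaves it and converges to the saddle. The noisy variant below adds small isotropic Gaussian noise at every step; this is a simplification closer to SGD noise than to the occasional perturbations used by Jin et al., but it shows the same mechanism.

```python
import numpy as np


def grad(p: np.ndarray) -> np.ndarray:
    """Gradient of the toy loss f(x, y) = x**2/2 - y**2/2 + y**4/4."""
    x, y = p
    return np.array([x, -y + y**3])


def descend(p0, lr: float = 0.1, steps: int = 500, noise: float = 0.0, seed: int = 0) -> np.ndarray:
    """Plain GD when noise == 0; otherwise add isotropic Gaussian noise after every step."""
    rng = np.random.default_rng(seed)
    p = np.asarray(p0, dtype=float)
    for _ in range(steps):
        p = p - lr * grad(p)
        if noise > 0:
            p = p + noise * rng.standard_normal(2)
    return p


plain = descend([1.0, 0.0])              # y-gradient is exactly 0 on y = 0: stuck at the saddle
noisy = descend([1.0, 0.0], noise=1e-3)  # noise seeds the negative-curvature direction y
print(plain, noisy)
```

With `lr = 0.1`, a perturbation along $y$ grows by a factor of about $1.1$ per step near the saddle, so the noisy run escapes within roughly a hundred steps and then settles near one of the two minima.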

Loss Surface Geometry

Empirical and theoretical studies have revealed rich structure in neural network loss landscapes:

| Property | Observation | Implication |
|---|---|---|
| Flat vs. sharp minima | Small-batch SGD tends to converge to wide, flat valleys; large-batch training finds sharper minima (Keskar et al., 2017) | Flat minima tend to generalize better (robust to weight perturbation) |
| Mode connectivity | Good local minima are connected by paths with low loss (Draxler et al., 2018; Garipov et al., 2018) | The loss landscape is more like a connected valley than a set of isolated basins |
| Linear mode connectivity | Networks trained from the same early checkpoint with different SGD noise (e.g., data order) can often be linearly interpolated without a rise in loss (Frankle et al., 2020) | Simplifies model averaging and ensembling |
| Edge of stability | GD with a large LR drives the sharpness up until $\lambda_{\max}(H) \approx 2/\eta$ (Cohen et al., 2021) | Training self-organizes to the stability boundary |
| Loss landscape simplification | BatchNorm, skip connections, and overparameterization smooth the landscape (Li et al., 2018) | Architecture design is implicitly optimization design |
| Progressive sharpening | The loss surface becomes sharper ($\lambda_{\max}$ grows) over training (Jastrzebski et al., 2020) | Motivates learning rate decay |

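
The barrier in the linear mode connectivity row can be measured by evaluating the loss along the straight line between two parameter vectors. Below is a minimal sketch on a hypothetical two-dimensional loss with two wells. Frankle et al. (2020) measure the rise in error for real trained networks; here the barrier is defined, as one common convention, as the largest rise of the loss above the straight line joining the two endpoint losses. Endpoints in different wells show a barrier of height 1 at the midpoint, while endpoints in the same well show none.

```python
import numpy as np


def loss(theta: np.ndarray) -> float:
    """Toy loss: a double well in theta[0] (minima at -1 and +1) plus a quadratic bowl in the rest."""
    return float((theta[0] ** 2 - 1.0) ** 2 + 0.5 * np.sum(theta[1:] ** 2))


def linear_barrier(theta_a, theta_b, n_points: int = 101) -> float:
    """Largest rise of the loss along the segment from theta_a to theta_b above
    the straight line joining the endpoint losses (0 means no barrier)."""
    theta_a = np.asarray(theta_a, dtype=float)
    theta_b = np.asarray(theta_b, dtype=float)
    la, lb = loss(theta_a), loss(theta_b)
    rises = [
        loss((1 - a) * theta_a + a * theta_b) - ((1 - a) * la + a * lb)
        for a in np.linspace(0.0, 1.0, n_points)
    ]
    return float(max(rises))


print(linear_barrier([-1.0, 0.0], [1.0, 0.0]))  # different wells: barrier at the midpoint
print(linear_barrier([1.0, 0.5], [1.0, -0.5]))  # same well: no barrier
```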

**Edge of stability.** Classical optimization theory says GD converges only if $\eta < 2/\lambda_{\max}(H)$. Cohen et al. (2021) showed that in practice, the sharpness $\lambda_{\max}$ first rises during GD training until it reaches $\approx 2/\eta$, then hovers at this boundary while the loss keeps decreasing over long timescales, even though it is non-monotone from step to step. This "edge of stability" phenomenon is not predicted by classical convergence theory and suggests that neural network training dynamics are qualitatively different from classical optimization.

**Computing the edge-of-stability threshold.** Suppose we train with a fixed learning rate $\eta = 0.01$. Classical theory guarantees descent only while $\eta < 2/\lambda_{\max}(H)$, equivalently while

$$\lambda_{\max}(H) < \frac{2}{\eta} = \frac{2}{0.01} = 200.$$

Step 1: At initialization the sharpness is small, say $\lambda_{\max}(H) = 50$. Then $\eta \lambda_{\max} = 0.01 \times 50 = 0.5 < 2$, so we are comfortably inside the stable regime and the loss decreases monotonically.

Step 2: Over training, progressive sharpening drives $\lambda_{\max}$ upward. It does not blow up past $200$; instead it climbs toward the threshold $2/\eta = 200$ and then hovers there. At $\lambda_{\max} = 200$ we have $\eta \lambda_{\max} = 2$, exactly the stability boundary.

Step 3: At the boundary the loss no longer decreases at every step. It oscillates from step to step yet still trends down when averaged over a window. If we now halved the learning rate to $\eta = 0.005$, the new threshold would be $2/\eta = 400$, so the same $\lambda_{\max} = 200$ would again satisfy $\eta \lambda_{\max} = 1 < 2$, returning us to the stable regime until sharpening catches up to the new boundary.
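
The stability arithmetic above can be checked on a one-dimensional quadratic $\mathcal{L}(x) = \tfrac{1}{2}\lambda x^2$, where each GD step multiplies $x$ by $(1 - \eta\lambda)$. This is a sketch of the classical condition only: a fixed quadratic cannot reproduce the self-stabilization that keeps real networks hovering at $\lambda_{\max} \approx 2/\eta$. The helper names are hypothetical.

```python
def stability_threshold(eta: float) -> float:
    """Largest sharpness for which GD on a quadratic does not diverge: 2 / eta."""
    return 2.0 / eta


def gd_on_quadratic(eta: float, lam: float, x0: float = 1.0, steps: int = 50) -> float:
    """Run GD on L(x) = 0.5 * lam * x**2. Each step maps x -> (1 - eta*lam) * x,
    so |x| shrinks iff |1 - eta*lam| < 1, i.e. iff 0 < eta*lam < 2."""
    x = x0
    for _ in range(steps):
        x -= eta * lam * x  # the gradient of 0.5*lam*x**2 is lam*x
    return x


for eta, lam in [(0.01, 50), (0.01, 200), (0.01, 250), (0.005, 200)]:
    x_final = gd_on_quadratic(eta, lam)
    print(f"eta={eta}, lam={lam}, eta*lam={eta * lam:.2f}, |x_50|={abs(x_final):.3g}")
```

For $\eta = 0.01$, $\lambda = 50$ shrinks the iterate by a factor of $0.5$ per step, $\lambda = 200$ gives a factor of $-1$ (a period-two oscillation at constant magnitude), $\lambda = 250$ gives $-1.5$ and diverges, and halving the learning rate to $\eta = 0.005$ makes $\lambda = 200$ stable again.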

**Why flat minima generalize.** A minimum is "flat" if the loss surface curves gently around it: small perturbations to $\theta$ cause small changes in $\mathcal{L}$. Formally, the gradient vanishes at a minimum $\theta^*$, so a second-order Taylor expansion gives, for a perturbation $\delta$ with $\|\delta\| = \epsilon$:

$$\mathcal{L}(\theta^* + \delta) - \mathcal{L}(\theta^*) \approx \tfrac{1}{2}\delta^\top H \delta \leq \tfrac{1}{2}\epsilon^2 \lambda_{\max}(H)$$

A smaller $\lambda_{\max}$ (a flatter minimum) means the loss is robust to weight perturbations; heuristically, the shift from the training loss to the test loss acts like such a perturbation. This is formalized by PAC-Bayesian generalization bounds (McAllester, 1999), which relate generalization to the volume of parameter space around $\theta^*$ with low loss.
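
For an exactly quadratic loss the second-order expansion has no remainder, so the bound can be checked numerically. This sketch assumes a diagonal Hessian (any symmetric Hessian can be diagonalized by rotating the coordinates, which leaves $\|\delta\|$ unchanged), samples random directions rescaled to norm $\epsilon$, and compares the loss increase around a flat and a sharp minimum. The eigenvalue choices are arbitrary illustrations.

```python
import numpy as np


def loss_increase(hessian_eigs, eps: float, n_samples: int = 2000, seed: int = 0) -> np.ndarray:
    """Loss increase 0.5 * delta^T H delta for random perturbations of norm eps
    around the minimum at 0 of a quadratic loss with diagonal Hessian H."""
    rng = np.random.default_rng(seed)
    h = np.asarray(hessian_eigs, dtype=float)
    deltas = rng.standard_normal((n_samples, h.size))
    deltas *= eps / np.linalg.norm(deltas, axis=1, keepdims=True)  # rescale each row to norm eps
    return 0.5 * (deltas**2) @ h


eps = 0.1
flat = np.full(10, 0.5)             # every curvature small
sharp = np.linspace(0.5, 20.0, 10)  # same dimension, larger lambda_max
for name, eigs in [("flat", flat), ("sharp", sharp)]:
    inc = loss_increase(eigs, eps)
    bound = 0.5 * eps**2 * eigs.max()
    print(name, inc.mean(), inc.max(), bound)
```

For random isotropic directions in $d$ dimensions the average increase is $\tfrac{1}{2}\epsilon^2 \operatorname{tr}(H)/d$, so it tracks the mean curvature, while the worst case is set by $\lambda_{\max}$.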

Why Deep Networks Train Despite Non-Convexity

Several factors explain the empirical success:

| Factor | Mechanism | Evidence |
|---|---|---|
| Overparameterization | $P \gg N$ (with $N$ training examples) creates many global minima; gradient descent reaches one (Du et al., 2019) | Networks with more parameters are easier to train |
| SGD noise | Stochastic gradients escape saddle points and sharp minima | Small-batch training generalizes better (Keskar et al., 2017) |
| Skip connections | Residual connections create smooth loss surfaces | ResNets train where plain networks fail (He et al., 2016) |
| Normalization | BatchNorm/LayerNorm smooth the loss landscape | Enables much larger learning rates (Ioffe & Szegedy, 2015; Santurkar et al., 2018) |
| Initialization | He/Xavier initialization places $\theta_0$ in a well-behaved region | Bad initialization causes vanishing/exploding gradients |
| Implicit regularization | SGD's trajectory has an inductive bias toward simple solutions | Minimum-norm solutions in linear models (Gunasekar et al., 2018) |
| Progressive learning | Warmup + LR decay matches training phases to landscape geometry | Standard in large-scale training |

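
The Initialization row can be illustrated by tracking signal scale through a deep ReLU network at initialization. This sketch (no biases, square layers, hypothetical function name) draws weights from $\mathcal{N}(0, s/\text{width})$: $s = 2$ is He initialization; $s = 1$ matches Xavier/Glorot, whose variance $2/(n_{\text{in}} + n_{\text{out}})$ equals $1/\text{width}$ for square layers, and halves the signal at each ReLU layer; $s = 4$ doubles it. Because the runs share a seed and a bias-free ReLU network is positively homogeneous in each layer's weights, the three results differ by factors of $2^{\pm 20}$ over 20 layers, so the He-initialized signal stays of order one while the others vanish or explode.

```python
import numpy as np


def output_scale(weight_var_scale: float, depth: int = 20, width: int = 512,
                 batch: int = 64, seed: int = 0) -> float:
    """Mean squared pre-activation after `depth` ReLU layers, relative to the input's.

    Weights ~ N(0, weight_var_scale / width), no biases:
    2.0 is He initialization for ReLU; 1.0 equals Xavier/Glorot for square layers.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))
    h = x
    for _ in range(depth):
        w = rng.standard_normal((width, width)) * np.sqrt(weight_var_scale / width)
        h = np.maximum(h, 0.0) @ w  # ReLU, then a linear layer
    return float(np.mean(h**2) / np.mean(x**2))


for scale in (1.0, 2.0, 4.0):
    print(scale, output_scale(scale))
```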

**Overparameterization regimes.** In the infinite-width limit, the **Neural Tangent Kernel (NTK)** regime (Jacot et al., 2018), the network behaves like its linearization around initialization, i.e., a linear model in a kernel space defined by the initialization. Training is then approximately convex, and the dynamics are well understood. Real networks operate in a "rich" or "feature learning" regime that is harder to analyze but typically achieves better performance.

Escaping Bad Regions

| Strategy | How It Works | Used In |
|---|---|---|
| SGD noise | Gradient noise provides random perturbation | Standard training |
| Learning rate warmup | Gradually increases LR to avoid early sharp minima | Transformer pretraining |
| Cyclic LR / restarts | Periodic LR increases to escape local basins | SGDR (Loshchilov & Hutter, 2017) |
| Gradient clipping | Bounds gradient magnitude to prevent overshooting | RNNs, transformers |
| Sharpness-aware minimization (SAM) | Minimizes worst-case loss in a neighborhood | Improves generalization (Foret et al., 2021) |
| Stochastic weight averaging (SWA) | Averages weights along the trajectory | Better generalization, flatter minima |
| Exponential moving average (EMA) | Maintains running average of weights | Standard in modern training |
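
Two entries in the table, SAM and EMA, reduce to short update rules. Below is a minimal NumPy sketch, assuming plain SGD as SAM's base optimizer and a caller-supplied gradient function (both helpers are hypothetical, not a library API). SAM perturbs the weights by $\rho\, g / \|g\|$ toward the approximate worst case within an $L_2$ ball of radius $\rho$, takes the gradient there, and applies it at the original weights; EMA keeps a running average of the weights.

```python
import numpy as np


def sam_step(theta, grad_fn, lr: float = 0.1, rho: float = 0.05):
    """One SAM update with a plain-SGD base optimizer:
    1) move to the first-order worst point in an L2 ball of radius rho,
    2) take the gradient there, 3) apply it at the original weights."""
    g = grad_fn(theta)
    e = rho * g / (np.linalg.norm(g) + 1e-12)  # worst-case perturbation (small constant avoids 0/0)
    g_sharp = grad_fn(theta + e)               # gradient at the perturbed point
    return theta - lr * g_sharp


def ema_update(ema, theta, decay: float = 0.999):
    """Exponential moving average of weights: ema <- decay*ema + (1-decay)*theta."""
    return decay * ema + (1.0 - decay) * theta


# Usage on a toy loss sum(theta**2), whose gradient is 2 * theta.
theta = np.array([1.0, -2.0])
ema = theta.copy()
for _ in range(100):
    theta = sam_step(theta, lambda t: 2.0 * t)
    ema = ema_update(ema, theta, decay=0.9)
print(theta, ema)
```

The EMA copy is typically used only for evaluation while the raw weights keep training; a `decay` close to 1 averages over roughly $1/(1 - \text{decay})$ recent steps.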

Learning Rate Schedules

The learning rate is often the single most important hyperparameter. It controls the noise scale ($\propto \eta / B$, where $B$ is the batch size), the convergence speed, and the type of minimum found.

Warmup starts with a small LR and linearly increases to the target:

$$\eta_t = \eta_{\max} \cdot \min\left(\frac{t}{w}, \, 1\right) \quad \text{for } t \leq w$$

Cosine decay smoothly decreases the LR:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{\pi (t - w)}{T - w}\right)\right) \quad \text{for } t > w$$

**Evaluating a warmup + cosine schedule.** Take $\eta_{\max} = 10^{-3}$, $\eta_{\min} = 10^{-5}$, warmup $w = 2000$ steps, and total $T = 100{,}000$ steps. We compute $\eta_t$ at three checkpoints.

Step 1 (inside warmup, $t = 1000 \leq w$): use the warmup rule $\eta_t = \eta_{\max}\cdot\min(t/w, 1)$. Here $t/w = 1000/2000 = 0.5$, so

$$\eta_{1000} = 10^{-3} \times 0.5 = 5\times 10^{-4}.$$

Step 2 (start of decay, $t = w = 2000$): the cosine argument is $\frac{\pi(t-w)}{T-w} = \frac{\pi \cdot 0}{98{,}000} = 0$, and $\cos(0) = 1$, so

$$\eta_{2000} = 10^{-5} + \tfrac{1}{2}(10^{-3} - 10^{-5})(1 + 1) = 10^{-5} + (10^{-3} - 10^{-5}) = 10^{-3}.$$

The schedule is continuous: the cosine branch starts exactly at $\eta_{\max}$ where warmup ended.

Step 3 (midpoint of decay, $t = 51{,}000$): now $t - w = 49{,}000$ and $T - w = 98{,}000$, so the argument is $\frac{\pi \cdot 49{,}000}{98{,}000} = \frac{\pi}{2}$, and $\cos(\pi/2) = 0$. Thus

$$\eta_{51{,}000} = 10^{-5} + \tfrac{1}{2}(10^{-3} - 10^{-5})(1 + 0) = 10^{-5} + \tfrac{1}{2}(9.9\times 10^{-4}) \approx 5.05\times 10^{-4}.$$

At the midpoint of the decay phase the learning rate sits exactly halfway (on a linear scale) between $\eta_{\max}$ and $\eta_{\min}$, at $(\eta_{\max} + \eta_{\min})/2$, as the cosine shape dictates. At the very end ($t = T$) the argument reaches $\pi$, $\cos(\pi) = -1$, and $\eta_T = \eta_{\min} = 10^{-5}$.
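
The warmup and cosine formulas translate directly into a small function. This sketch (the function name is hypothetical) reproduces the worked values $5 \times 10^{-4}$, $10^{-3}$, $5.05 \times 10^{-4}$, and $10^{-5}$ at $t = 1000$, $2000$, $51{,}000$, and $100{,}000$.

```python
import math


def warmup_cosine_lr(t: int, eta_max: float, eta_min: float, w: int, T: int) -> float:
    """Linear warmup for t <= w, then cosine decay from eta_max (t = w) to eta_min (t = T)."""
    if t <= w:
        return eta_max * min(t / w, 1.0)
    progress = (t - w) / (T - w)  # fraction of the decay phase completed
    return eta_min + 0.5 * (eta_max - eta_min) * (1.0 + math.cos(math.pi * progress))


for t in (1000, 2000, 51_000, 100_000):
    print(t, warmup_cosine_lr(t, eta_max=1e-3, eta_min=1e-5, w=2000, T=100_000))
```

Evaluating the warmup branch at $t = w$ gives $\eta_{\max}$, the same value the cosine branch starts from, so it does not matter which branch owns the boundary step.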

**Three phases of training:**
  1. Warmup phase ($t < w$): LR ramps from 0 to $\eta_{\max}$. Adam's second-moment estimates ($\hat{v}_t$) are unreliable initially; warmup gives them time to stabilize. The loss drops rapidly during warmup.
  2. Stable phase ($w < t < T_{\text{decay}}$, where $T_{\text{decay}}$ is the step at which decay begins): LR is near $\eta_{\max}$. The model explores the loss landscape, learning features and escaping shallow minima. Most learning happens here.
  3. Decay phase ($t > T_{\text{decay}}$): LR decreases toward $\eta_{\min}$. The optimizer refines its solution within the basin it reached during the stable phase, and both training and validation loss typically drop noticeably as the LR anneals.

The WSD (Warmup-Stable-Decay) schedule makes this explicit (Hu et al., 2024): a short warmup (a few percent of training), a long stable phase at roughly constant LR, then a final decay phase (typically the last 10% to 20% of training) that anneals the LR down. This schedule is increasingly popular because the stable phase can be extended indefinitely, so it does not require knowing the total training length in advance.
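
The three phases can be sketched as a single function. The phase fractions below are illustrative assumptions (warmup over 2% of training, decay over the final 15%), and the anneal is linear here as one simple choice; other decay shapes are used in practice, and this is not presented as the exact schedule of Hu et al. (2024).

```python
def wsd_lr(t: int, T: int, eta_max: float, eta_min: float,
           warmup_frac: float = 0.02, decay_frac: float = 0.15) -> float:
    """Warmup-Stable-Decay sketch for 0 <= t <= T.

    warmup_frac and decay_frac are illustrative defaults; the final anneal is
    linear here, though other decay shapes are also used in practice.
    """
    w = round(warmup_frac * T)               # end of warmup
    t_decay = round((1.0 - decay_frac) * T)  # T_decay: step at which decay begins
    if t < w:
        return eta_max * t / w               # linear warmup
    if t < t_decay:
        return eta_max                       # stable plateau
    progress = (t - t_decay) / (T - t_decay)  # goes from 0 to 1 over the decay phase
    return eta_max + progress * (eta_min - eta_max)


T = 100_000
for t in (1_000, 50_000, 85_000, 92_500, 100_000):
    print(t, wsd_lr(t, T, eta_max=1e-3, eta_min=1e-5))
```

Because the learning rate is constant throughout the stable phase, a checkpoint taken anywhere on the plateau can be branched into its own decay run, which is what lets WSD defer the choice of total training length.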

| Schedule | Formula | Best For | Key Parameter |
|---|---|---|---|
| Constant | $\eta_t = \eta_0$ | Fine-tuning, short runs | $\eta_0$ |
| Step decay | $\eta_t = \eta_0 \gamma^{\lfloor t/s \rfloor}$ (decay every $s$ steps or epochs) | CNNs (ResNet) | $\gamma = 0.1$, steps at 30/60/90 epochs ($s = 30$) |
| Cosine annealing | $\eta_t = \eta_{\min} + \frac{1}{2}(\eta_0 - \eta_{\min})(1 + \cos(\pi t/T))$ | Pretraining | Min LR, $T$ |
| Warmup + cosine | Linear warmup, then cosine | LLM pretraining | Warmup steps $w$ |
| WSD | Warmup $\to$ constant $\to$ decay | Flexible pretraining | Phase boundaries |
| 1-cycle | Ramp up, then down | Fast convergence | Max LR, $T$ |
| Inverse sqrt | $\eta_t = \eta_0 / \sqrt{t}$ (after a linear warmup) | Original Transformer | $\eta_0$ |
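
Two rows of the table written out as code, as a sketch with hypothetical function names: step decay with $s = 30$ epochs reproduces the 30/60/90 milestones, and the inverse-square-root row is shown in the warmup-then-decay form of the original Transformer, $d_{\text{model}}^{-0.5} \cdot \min(t^{-0.5},\, t \cdot w^{-1.5})$. The defaults $d_{\text{model}} = 512$ and $w = 4000$ follow that paper's base configuration.

```python
def step_decay_lr(epoch: int, eta0: float = 0.1, gamma: float = 0.1, s: int = 30) -> float:
    """eta_t = eta0 * gamma ** floor(epoch / s); with s = 30 the LR drops at epochs 30, 60, 90."""
    return eta0 * gamma ** (epoch // s)


def inverse_sqrt_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """Original Transformer schedule (valid for step >= 1): linear warmup for
    `warmup` steps, then decay proportional to 1/sqrt(step)."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for epoch in (0, 29, 30, 60, 95):
    print("epoch", epoch, step_decay_lr(epoch))
for step in (1, 2000, 4000, 16_000):
    print("step", step, inverse_sqrt_lr(step))
```

The two branches of the `min` meet at $t = w$, so the rate rises linearly to its peak at the end of warmup and then falls as $1/\sqrt{t}$; for example, quadrupling the step count past warmup halves the learning rate.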

Notation Summary

| Symbol | Meaning |
|---|---|
| $P$ | Number of model parameters |
| $N$ | Number of training examples |
| $B$ | Batch size |
| $\lambda_{\max}(H)$ | Largest eigenvalue of the Hessian |
| $\eta, \eta_t$ | Learning rate (possibly time-varying) |
| $w$ | Number of warmup steps |
| $T$ | Total training steps |
| $\eta_{\min}, \eta_{\max}$ | Minimum and maximum learning rate |
| $\gamma$ | Decay factor for step schedule |
| SAM | Sharpness-Aware Minimization |
| SWA | Stochastic Weight Averaging |
| EMA | Exponential Moving Average |
| NTK | Neural Tangent Kernel |

References