Non-Convex Optimization
Neural network loss functions are non-convex in their parameters. Despite the lack of convexity guarantees, stochastic gradient descent reliably finds good solutions. Understanding why requires studying the geometry of the loss landscape and the implicit biases of optimization algorithms.
The Non-Convex Landscape of Deep Learning
A network with parameters $\theta \in \mathbb{R}^d$ defines a loss surface $L(\theta)$. This surface has an astronomical number of critical points (where $\nabla L(\theta) = 0$), and most of them are saddle points, not minima.
Saddle Points vs. Local Minima
How likely is a random critical point to be a local minimum? A useful heuristic treats the sign of each Hessian eigenvalue at a critical point as an independent coin flip, which gives a probability of roughly $2^{-d}$ that all $d$ eigenvalues are positive, since each Hessian eigenvalue must be positive independently. For $d = 10^8$ (a 100M-parameter model), this probability is less than $10^{-10^7}$: in high dimensions, essentially every critical point is a saddle.
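The eigenvalue test for classifying a critical point can be illustrated on a toy two-dimensional loss (a numpy sketch of my own, not from the source):

```python
import numpy as np

# Toy loss L(x, y) = x^2 - y^2 with a critical point at the origin.
# Its Hessian there is diag(2, -2): one positive and one negative
# eigenvalue, so the origin is a saddle point, not a minimum.
H = np.array([[2.0, 0.0],
              [0.0, -2.0]])
eigvals = np.linalg.eigvalsh(H)
is_minimum = bool(np.all(eigvals > 0))   # minimum: all eigenvalues > 0
is_saddle = bool(np.any(eigvals > 0) and np.any(eigvals < 0))
print(eigvals, is_minimum, is_saddle)    # [-2. 2.] False True

# Coin-flip heuristic from the text: if each of d eigenvalue signs
# were independent, P(all positive) = 2^{-d}.
d = 10
p_minimum = 0.5 ** d
print(p_minimum)                         # 0.0009765625
```

Even at $d = 10$ the chance of a minimum under this heuristic is already below 0.1%; at $d = 10^8$ it is negligible.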
Loss Surface Geometry
Empirical and theoretical studies have revealed rich structure in neural network loss landscapes:
| Property | Observation | Implication |
|---|---|---|
| Flat vs. sharp minima | SGD converges to wide, flat valleys; large-batch GD finds sharper minima ([?keskar2017large]) | Flat minima generalize better (robust to weight perturbation) |
| Mode connectivity | Good local minima are connected by paths with low loss ([?draxler2018essentially]; [?garipov2018loss]) | The "loss landscape" is more like a connected valley than isolated basins |
| Linear mode connectivity | Checkpoints from the same run can be linearly interpolated with monotone loss ([?frankle2020linear]) | Simplifies model averaging and ensembling |
| Edge of stability | GD with large LR drives the sharpness up until $\lambda_{\max} \approx 2/\eta$, then hovers there ([?cohen2021gradient]) | Training self-organizes to the stability boundary |
| Loss landscape simplification | BatchNorm, skip connections, and overparameterization smooth the landscape ([?li2018visualizing]) | Architecture design = implicit optimization design |
| Progressive sharpening | Training loss surface becomes sharper over training ([?jastrzebski2020break]) | Motivates learning rate decay |
A smaller $\lambda_{\max}$ (a flatter minimum) means the loss is robust to the weight perturbations that occur when moving from training to test data. This is formalized by PAC-Bayesian generalization bounds ([?mcallester1999pac]), which relate generalization to the volume of parameter space around $\theta^*$ with low loss.
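In practice $\lambda_{\max}$ is estimated with power iteration on Hessian-vector products rather than by forming the Hessian. A minimal numpy sketch on an explicit quadratic (in a real network, `hvp` would come from autodiff, not a stored matrix):

```python
import numpy as np

rng = np.random.default_rng(0)

# Quadratic loss L(w) = 0.5 * w^T H w, so the Hessian is the fixed H.
H = np.diag([10.0, 1.0, 0.1])      # lambda_max = 10 by construction

def hvp(v):
    """Hessian-vector product; in deep learning this comes from autodiff."""
    return H @ v

# Power iteration: repeatedly apply the Hessian and renormalize.
v = rng.standard_normal(3)
for _ in range(100):
    v = hvp(v)
    v /= np.linalg.norm(v)

lambda_max = v @ hvp(v)            # Rayleigh quotient estimate
print(lambda_max)                  # ~10.0
```

Each iteration needs only one Hessian-vector product, so the cost is a few extra backward passes per estimate, which is why sharpness can be monitored during training.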
Why Deep Networks Train Despite Non-Convexity
Several factors explain the empirical success:
| Factor | Mechanism | Evidence |
|---|---|---|
| Overparameterization | $d \gg n$ (far more parameters than examples) creates many global minima; gradient descent reaches one ([?du2019gradient]) | Networks with more parameters are easier to train |
| SGD noise | Stochastic gradients escape saddle points and sharp minima | Small-batch training generalizes better ([?keskar2017large]) |
| Skip connections | Residual connections create smooth loss surfaces | ResNets train where plain networks fail (He et al., 2016) |
| Normalization | BatchNorm/LayerNorm smooth the loss landscape | Enables much larger learning rates ([?ioffe2015batch]; [?santurkar2018does]) |
| Initialization | He/Xavier initialization places $\theta_0$ in a well-behaved region | Bad initialization causes vanishing/exploding gradients |
| Implicit regularization | SGD's trajectory has an inductive bias toward simple solutions | Minimum-norm solutions in linear models ([?gunasekar2018characterizing]) |
| Progressive learning | Warmup + LR decay matches training phases to landscape geometry | Standard in all large-scale training |
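To make the initialization row concrete, here is a small numpy sketch (illustrative, not from the source) of He initialization, which scales weight variance as $2/\text{fan\_in}$ so that activation scale is roughly preserved through a deep ReLU stack:

```python
import numpy as np

rng = np.random.default_rng(0)

def he_init(fan_in, fan_out):
    # He initialization: std = sqrt(2 / fan_in), suited to ReLU layers.
    return rng.standard_normal((fan_in, fan_out)) * np.sqrt(2.0 / fan_in)

# Push activations through a deep ReLU stack and watch their scale.
x = rng.standard_normal((512, 256))
for _ in range(20):
    W = he_init(256, 256)
    x = np.maximum(x @ W, 0.0)   # ReLU zeroes ~half the pre-activations;
                                 # the 2/fan_in factor compensates
print(x.std())                   # stays O(1) instead of vanishing
```

Replacing the factor 2 with 1 (Xavier-style variance under ReLU) makes the activation scale shrink by roughly $\sqrt{2}$ per layer, which is the vanishing-gradient failure mode the table refers to.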
Escaping Bad Regions
| Strategy | How It Works | Used In |
|---|---|---|
| SGD noise | Gradient noise provides random perturbation | Standard training |
| Learning rate warmup | Gradually increases LR to avoid early sharp minima | Transformer pretraining |
| Cyclic LR / restarts | Periodic LR increases to escape local basins | SGDR ([?loshchilov2017sgdr]) |
| Gradient clipping | Bounds gradient magnitude to prevent overshooting | RNNs, transformers |
| Sharpness-aware minimization (SAM) | Minimizes worst-case loss in a neighborhood | Improves generalization ([?foret2021sam]) |
| Stochastic weight averaging (SWA) | Averages weights along the trajectory | Better generalization, flatter minima |
| Exponential moving average (EMA) | Maintains running average of weights | Standard in modern training |
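The EMA row can be illustrated in a few lines of plain Python (a dict of scalar "weights" stands in for real tensors; names are illustrative):

```python
# Exponential moving average of weights: ema <- d*ema + (1-d)*w each step.
# Evaluation and checkpointing then use the EMA weights, which trace a
# smoothed path through parameter space.
def ema_update(ema, params, decay=0.999):
    for k, w in params.items():
        ema[k] = decay * ema[k] + (1.0 - decay) * w

# Toy "training": the raw weight oscillates, the EMA stays smooth.
params = {"w": 0.0}
ema = dict(params)
for step in range(1, 2001):
    params["w"] = 1.0 + (0.5 if step % 2 == 0 else -0.5)  # noisy weight
    ema_update(ema, params, decay=0.99)
print(params["w"], ema["w"])  # EMA settles near 1.0; raw weight still jumps
```

The decay controls the averaging horizon (an effective window of roughly $1/(1-d)$ steps), which is why values like 0.999 or 0.9999 are typical in large-scale training.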
Learning Rate Schedules
The learning rate $\eta$ is the single most important hyperparameter. It controls the SGD noise scale (proportional to $\eta/B$ for batch size $B$), the convergence speed, and the type of minimum found.
Warmup starts with a small LR and linearly increases to the target:

$$\eta_t = \eta_{\max} \cdot \frac{t}{T_w}, \qquad 0 \le t \le T_w$$
Cosine decay smoothly decreases the LR:

$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\!\left(\pi \, \frac{t - T_w}{T - T_w}\right)\right), \qquad T_w < t \le T$$
- Warmup phase ($t \le T_w$): LR ramps from 0 to $\eta_{\max}$. Adam's second-moment estimates ($v_t$) are unreliable initially; warmup gives them time to stabilize. The loss drops rapidly during warmup.
- Stable phase ($T_w < t < T_d$, where $T_d$ marks the onset of decay): LR is near $\eta_{\max}$. The model explores the loss landscape, learning features and escaping shallow minima. Most learning happens here.
- Decay phase ($T_d \le t \le T$): LR decreases toward $\eta_{\min}$. The optimizer refines its solution, settling into a flat minimum. The loss decreases slowly but generalization improves.
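The warmup-then-cosine pattern above can be written as a single schedule function; a minimal sketch (function and argument names are illustrative):

```python
import math

def warmup_cosine_lr(t, eta_max, T, T_w, eta_min=0.0):
    """Linear warmup for T_w steps, then cosine decay to eta_min at step T."""
    if t < T_w:
        return eta_max * t / T_w                  # warmup: 0 -> eta_max
    progress = (t - T_w) / (T - T_w)              # 0 -> 1 over the decay
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))

# Spot-check the phases for a 10k-step run with 1k warmup steps.
print(warmup_cosine_lr(0,     3e-4, 10000, 1000))  # 0.0 at the start
print(warmup_cosine_lr(1000,  3e-4, 10000, 1000))  # eta_max at end of warmup
print(warmup_cosine_lr(10000, 3e-4, 10000, 1000))  # eta_min at the end
```

Framework schedulers (e.g. PyTorch's `LambdaLR`) typically wrap exactly this kind of step-to-multiplier function.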
The WSD (Warmup-Stable-Decay) schedule makes this explicit: warmup for 5% of training, constant LR for ~75%, then cosine decay for the final ~20%. This is increasingly popular because it does not require knowing the total training length in advance.
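A minimal sketch of a WSD schedule, with the 5%/75%/20% split from the text hard-coded as defaults (function name and fractions are illustrative, and real implementations often use a linear rather than cosine tail):

```python
import math

def wsd_lr(t, eta_max, T, warmup_frac=0.05, decay_frac=0.20, eta_min=0.0):
    """Warmup-Stable-Decay: linear warmup, flat plateau, cosine tail."""
    T_w = int(warmup_frac * T)          # end of warmup
    T_d = int((1.0 - decay_frac) * T)   # start of decay
    if t < T_w:
        return eta_max * t / T_w
    if t < T_d:
        return eta_max                  # stable phase: constant LR
    progress = (t - T_d) / (T - T_d)
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))

# For a 1000-step run: warmup ends at step 50, decay begins at step 800.
print(wsd_lr(400, 1.0, 1000))   # 1.0 (stable phase)
print(wsd_lr(1000, 1.0, 1000))  # 0.0 (fully decayed)
```

Because the stable phase is flat, the run can be extended by simply continuing at $\eta_{\max}$ and deferring the decay, which is the flexibility the text refers to.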
| Schedule | Formula | Best For | Key Parameter |
|---|---|---|---|
| Constant | $\eta_t = \eta_0$ | Fine-tuning, short runs | $\eta_0$ |
| Step decay | $\eta_t = \eta_0 \, \gamma^{\lfloor t/s \rfloor}$ | CNNs (ResNet) | $\gamma$, steps at 30/60/90 epochs |
| Cosine annealing | $\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})(1 + \cos(\pi t / T))$ | Pretraining | Min LR, $T$ |
| Warmup + cosine | Linear warmup then cosine decay | LLM pretraining | Warmup steps $T_w$ |
| WSD | Warmup → constant → cosine decay | Flexible pretraining | Phase boundaries |
| 1-cycle | Ramp up then down | Fast convergence | Max LR, $T$ |
| Inverse sqrt | $\eta_t \propto \min(t^{-1/2},\, t \cdot T_w^{-3/2})$ | Original Transformer | $T_w$ |
Notation Summary
| Symbol | Meaning |
|---|---|
| $d$ | Number of model parameters |
| $\lambda_{\max}$ | Largest eigenvalue of the Hessian (sharpness) |
| $\eta_t$ | Learning rate (possibly time-varying) |
| $T_w$ | Number of warmup steps |
| $T$ | Total training steps |
| $\eta_{\min}$, $\eta_{\max}$ | Minimum and maximum learning rate |
| $\gamma$ | Decay factor for step schedule |
| $\kappa$ | Condition number |
| SAM | Sharpness-Aware Minimization |
| SWA | Stochastic Weight Averaging |
| EMA | Exponential Moving Average |
| NTK | Neural Tangent Kernel |