Non-Convex Optimization
Neural network loss functions are non-convex in their parameters. Despite the lack of convexity guarantees, stochastic gradient descent reliably finds good solutions. Understanding why requires studying the geometry of the loss landscape and the implicit biases of optimization algorithms.
The Non-Convex Landscape of Deep Learning
A network with parameters $\theta \in \mathbb{R}^d$ defines a loss surface $L(\theta)$. This surface has an astronomical number of critical points (where $\nabla L(\theta) = 0$), and most of them are saddle points, not minima.
Saddle Points vs. Local Minima
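A simple heuristic treats the sign of each of the $d$ Hessian eigenvalues at a critical point as an independent coin flip, which puts the probability of a local minimum near $2^{-d}$. Random-matrix models, in which the eigenvalues are correlated, give a much smaller probability of order $e^{-c d^2}$ for a positive constant $c$: the exponent grows with $d^2$, not $d$. The practical takeaway is the same and is supported empirically in deep networks: the overwhelming majority of high-loss critical points are saddle points rather than local minima, and minima become exponentially rarer as dimension grows (Dauphin et al., 2014).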
Loss Surface Geometry
Empirical and theoretical studies have revealed rich structure in neural network loss landscapes:
| Property | Observation | Implication |
|---|---|---|
| Flat vs. sharp minima | Small-batch SGD tends to converge to wide, flat valleys; large-batch training tends to find sharper minima (Keskar et al., 2017) | Flat minima tend to generalize better (robust to weight perturbation) |
| Mode connectivity | Good local minima are connected by paths with low loss (Draxler et al., 2018; Garipov et al., 2018) | The "loss landscape" is more like a connected valley than isolated basins |
| Linear mode connectivity | Networks trained from a shared early checkpoint can be linearly interpolated without a loss barrier (Frankle et al., 2020) | Simplifies model averaging and ensembling |
| Edge of stability | GD with a large LR $\eta$ evolves the Hessian so that its largest eigenvalue satisfies $\lambda_{\max} \approx 2/\eta$ (Cohen et al., 2021) | Training self-organizes to the stability boundary |
| Loss landscape simplification | BatchNorm, skip connections, and overparameterization smooth the landscape (Li et al., 2018) | Architecture design = implicit optimization design |
| Progressive sharpening | Training loss surface becomes sharper over training (Jastrzebski et al., 2020) | Motivates learning rate decay |
Step 1: At initialization the sharpness $\lambda_{\max}$ (the largest Hessian eigenvalue) is small. Then $\lambda_{\max} < 2/\eta$ for learning rate $\eta$, so we are comfortably inside the stable regime and the loss decreases monotonically.
Step 2: Over training, progressive sharpening drives $\lambda_{\max}$ upward. It does not blow up past $2/\eta$; instead it climbs toward the threshold and then hovers there. At $\lambda_{\max} = 2/\eta$ we have $\eta \lambda_{\max} = 2$, exactly the stability boundary.
Step 3: At the boundary the loss no longer decreases at every step. It oscillates from step to step yet still trends down when averaged over a window. If we now halved the learning rate to $\eta/2$, the new threshold would be $2/(\eta/2) = 4/\eta$, so the same $\lambda_{\max} \approx 2/\eta$ would again satisfy $\lambda_{\max} < 4/\eta$, returning us to the stable regime until sharpening catches up to the new boundary.
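The stability threshold in the steps above is $2/\eta$ for gradient descent with learning rate $\eta$, and it can be checked on the simplest possible case. The sketch below is an illustration under stated assumptions, not code from Cohen et al.: it uses a one-dimensional quadratic loss $\tfrac{1}{2}\lambda x^2$, whose curvature $\lambda$ plays the role of the sharpness. The function name `gd_on_quadratic` and the numeric values are hypothetical.

```python
def gd_on_quadratic(curvature, lr, steps=50, x0=1.0):
    """Run gradient descent on f(x) = 0.5 * curvature * x**2 and return all iterates."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        grad = curvature * x      # f'(x) = curvature * x
        x = x - lr * grad         # each step multiplies x by (1 - lr * curvature)
        xs.append(x)
    return xs


lr = 0.5                          # illustrative value; the threshold is 2 / lr = 4.0
stable = gd_on_quadratic(curvature=1.0, lr=lr)     # below the threshold: iterates shrink
boundary = gd_on_quadratic(curvature=4.0, lr=lr)   # at the threshold: sign flips, size unchanged
unstable = gd_on_quadratic(curvature=5.0, lr=lr)   # above the threshold: iterates grow
```

On a real network the curvature changes during training, which is what lets the sharpness settle near the threshold instead of diverging; a fixed quadratic only shows where the threshold lies.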
A smaller $\lambda_{\max}$ (a flatter minimum) means the loss is robust to small weight perturbations, such as those that correspond to the shift from the training loss surface to the test loss surface. This is formalized by PAC-Bayesian generalization bounds (McAllester, 1999), which relate generalization to the volume of parameter space around the solution $\theta^*$ with low loss.
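One way to see why curvature controls this robustness is a second-order Taylor expansion around a minimum. The sketch below assumes a twice-differentiable loss $L$, a minimum $\theta^*$ with Hessian $H$ whose largest eigenvalue is $\lambda_{\max}$, and a small perturbation $\epsilon$; these symbols are introduced here for illustration. The first-order term vanishes because the gradient is zero at $\theta^*$:

$$L(\theta^* + \epsilon) \approx L(\theta^*) + \tfrac{1}{2}\,\epsilon^\top H\,\epsilon \;\le\; L(\theta^*) + \tfrac{1}{2}\,\lambda_{\max}\,\|\epsilon\|^2$$

For a perturbation of fixed size, the worst-case increase in loss is therefore proportional to $\lambda_{\max}$, to second order.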
Why Deep Networks Train Despite Non-Convexity
Several factors explain the empirical success:
| Factor | Mechanism | Evidence |
|---|---|---|
| Overparameterization | Having far more parameters than training examples ($d \gg n$) creates many global minima; gradient descent reaches one (Du et al., 2019) | Networks with more parameters are often easier to train |
| SGD noise | Stochastic gradients escape saddle points and sharp minima | Small-batch training generalizes better (Keskar et al., 2017) |
| Skip connections | Residual connections create smooth loss surfaces | ResNets train where plain networks fail (He et al., 2016) |
| Normalization | BatchNorm/LayerNorm smooth the loss landscape | Enables much larger learning rates (Ioffe & Szegedy, 2015; Santurkar et al., 2018) |
| Initialization | He/Xavier initialization places the initial weights $\theta_0$ in a well-behaved region | Bad initialization causes vanishing/exploding gradients |
| Implicit regularization | SGD's trajectory has an inductive bias toward simple solutions | Minimum-norm solutions in linear models (Gunasekar et al., 2018) |
| Progressive learning | Warmup + LR decay matches training phases to landscape geometry | Standard in most large-scale training |
Escaping Bad Regions
| Strategy | How It Works | Used In |
|---|---|---|
| SGD noise | Gradient noise provides random perturbation | Standard training |
| Learning rate warmup | Gradually increases LR to avoid early sharp minima | Transformer pretraining |
| Cyclic LR / restarts | Periodic LR increases to escape local basins | SGDR (Loshchilov & Hutter, 2017) |
| Gradient clipping | Bounds gradient magnitude to prevent overshooting | RNNs, transformers |
| Sharpness-aware minimization (SAM) | Minimizes worst-case loss in a neighborhood | Improves generalization (Foret et al., 2021) |
| Stochastic weight averaging (SWA) | Averages weights along the trajectory | Better generalization, flatter minima |
| Exponential moving average (EMA) | Maintains running average of weights | Standard in modern training |
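Two of the strategies in the table, gradient clipping and the exponential moving average of weights, are simple enough to sketch directly. The code below is a minimal illustration on plain Python lists, not any framework's API: the function names `clip_by_global_norm` and `ema_update`, the `decay` values, and the example numbers are assumptions made for this sketch.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient values so that their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm or norm == 0.0:
        return list(grads)        # already small enough: leave unchanged
    scale = max_norm / norm       # shrink the length, keep the direction
    return [g * scale for g in grads]


def ema_update(ema_weights, weights, decay=0.99):
    """Return decay * ema + (1 - decay) * weights, elementwise."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)   # the original norm is 5.0
ema = ema_update([0.0, 0.0], [1.0, 2.0], decay=0.9)       # moves 10% of the way to the weights
```

Clipping changes only the length of the gradient, never its direction. The EMA copy is typically used for evaluation while the raw weights continue to be trained.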
Learning Rate Schedules
The learning rate is arguably the single most important hyperparameter. It controls the noise scale (roughly $\eta / B$ for learning rate $\eta$ and batch size $B$), the convergence speed, and the type of minimum found.
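Warmup starts with a small LR and linearly increases to the target over the first $T_w$ steps:
$$\eta_t = \eta_{\max} \cdot \frac{t}{T_w}, \qquad t \le T_w$$
Cosine decay smoothly decreases the LR from $\eta_{\max}$ to $\eta_{\min}$ over the remaining steps, up to the total step count $T$:
$$\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \cdot \frac{t - T_w}{T - T_w}\right)\right), \qquad T_w \le t \le T$$
Step 1 (inside warmup, $t < T_w$): use the warmup rule $\eta_t = \eta_{\max} \cdot t / T_w$. Here $t / T_w$ is the fraction of warmup completed, so $\eta_t$ is that same fraction of $\eta_{\max}$.
Step 2 (start of decay, $t = T_w$): the cosine argument is $0$, and $\cos 0 = 1$, so $\eta_t = \eta_{\min} + (\eta_{\max} - \eta_{\min}) = \eta_{\max}$.
The schedule is continuous: the cosine branch starts exactly at $\eta_{\max}$, where warmup ended.
Step 3 (midpoint of decay, $t = (T_w + T)/2$): now $t - T_w = (T - T_w)/2$, so the argument is $\pi/2$, and $\cos(\pi/2) = 0$. Thus $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min}) = \tfrac{1}{2}(\eta_{\max} + \eta_{\min})$.
At the halfway point of the decay phase the learning rate sits halfway (on a linear scale) between $\eta_{\max}$ and $\eta_{\min}$, as the cosine shape dictates. At the very end ($t = T$) the argument reaches $\pi$, $\cos \pi = -1$, and $\eta_t = \eta_{\min}$.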
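The schedule walked through above (linear warmup followed by cosine decay) fits in a few lines. This is a minimal sketch, not a particular library's scheduler: the function name `warmup_cosine_lr`, its parameter names, and the numeric values in the usage lines are assumptions chosen for illustration.

```python
import math


def warmup_cosine_lr(step, warmup_steps, total_steps, lr_max, lr_min=0.0):
    """Linear warmup from 0 to lr_max, then cosine decay from lr_max to lr_min."""
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Illustrative settings: 100 warmup steps out of 1100 total.
schedule = [warmup_cosine_lr(t, 100, 1100, lr_max=1e-3, lr_min=1e-4) for t in range(1101)]
```

Indexing `schedule` at steps 50, 100, 600, and 1100 gives the values at mid-warmup, the start of decay, the midpoint of decay, and the end of training.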
- Warmup phase ($t < T_w$): LR ramps from 0 to $\eta_{\max}$. Adam's second-moment estimates ($v_t$) are unreliable initially; warmup gives them time to stabilize. The loss drops rapidly during warmup.
- Stable phase ($t$ just past $T_w$): LR is near $\eta_{\max}$. The model explores the loss landscape, learning features and escaping shallow minima. Most learning happens here.
- Decay phase ($t \to T$): LR decreases toward $\eta_{\min}$. The optimizer refines its solution, settling into a flat minimum. The loss decreases slowly but generalization improves.
The WSD (Warmup-Stable-Decay) schedule makes this explicit (Hu et al., 2024): a short warmup, a long stable phase at roughly constant LR, then a final decay phase (typically the last 10% to 20% of training) that anneals the LR down. As rough illustrative splits, the warmup might take a few percent of training, the stable phase the bulk of it, and the decay the final stretch. This is increasingly popular because the stable phase can be extended indefinitely, so it does not require knowing the total training length in advance.
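A WSD schedule can be sketched as a three-branch function. The version below is an assumption-laden illustration, not the implementation from Hu et al.: it anneals linearly in the decay phase (other decay shapes are possible), and the function name `wsd_lr`, its parameters, and the phase boundaries in the usage lines are hypothetical.

```python
def wsd_lr(step, warmup_steps, decay_start, total_steps, lr_max, lr_min=0.0):
    """Warmup-Stable-Decay: linear warmup, constant plateau, then a linear anneal."""
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    if step < decay_start:
        return lr_max                       # stable phase: constant learning rate
    # Decay phase: move linearly from lr_max to lr_min (assumed decay shape).
    progress = min(1.0, (step - decay_start) / max(1, total_steps - decay_start))
    return lr_max + (lr_min - lr_max) * progress


# Illustrative split: 2% warmup, 88% stable, final 10% decay over 1000 steps.
lrs = [wsd_lr(t, warmup_steps=20, decay_start=900, total_steps=1000,
              lr_max=1.0, lr_min=0.1) for t in range(1001)]
```

Because the stable branch does not depend on `total_steps`, a run can be extended by moving `decay_start` and `total_steps` later without changing any learning rate already used.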
| Schedule | Formula | Best For | Key Parameter |
|---|---|---|---|
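| Constant | $\eta_t = \eta$ | Fine-tuning, short runs | $\eta$ |
| Step decay | $\eta_t = \eta_0 \gamma^{\lfloor t/s \rfloor}$ | CNNs (ResNet) | $\gamma$, steps at 30/60/90 epochs |
| Cosine annealing | $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})(1 + \cos(\pi t / T))$ | Pretraining | Min LR, $T$ |
| Warmup + cosine | Linear warmup then cosine | LLM pretraining | Warmup steps $T_w$ |
| WSD | Warmup → constant → cosine | Flexible pretraining | Phase boundaries |
| 1-cycle | Ramp up then down | Fast convergence | Max LR, $T$ |
| Inverse sqrt | $\eta_t \propto 1/\sqrt{t}$ after warmup | Original Transformer | $T_w$ |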
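The step decay and inverse square root rows of the table can be written out as follows. This is a sketch under stated assumptions: the function names are hypothetical, the default `gamma=0.1` is an illustrative choice (the table does not fix a value), and the inverse square root rule is normalized so that it peaks at `lr_peak` when warmup ends.

```python
def step_decay_lr(epoch, lr0, gamma=0.1, milestones=(30, 60, 90)):
    """Multiply lr0 by gamma once for every milestone epoch already reached."""
    drops = sum(1 for m in milestones if epoch >= m)
    return lr0 * gamma ** drops


def inverse_sqrt_lr(step, warmup_steps, lr_peak):
    """Linear warmup to lr_peak, then decay proportional to 1 / sqrt(step)."""
    step = max(1, step)                     # avoid division by zero at step 0
    return lr_peak * min(step / warmup_steps, (warmup_steps / step) ** 0.5)


lr_epoch_45 = step_decay_lr(45, lr0=0.1)            # one milestone (30) reached
lr_step_4000 = inverse_sqrt_lr(4000, 1000, 1.0)     # 4x past warmup: half of the peak
```

The two branches of `inverse_sqrt_lr` meet at `step == warmup_steps`, so the schedule is continuous there.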
Notation Summary
| Symbol | Meaning |
|---|---|
| $d$ | Number of model parameters |
| $\lambda_{\max}$ | Largest eigenvalue of the Hessian |
| $\eta$, $\eta_t$ | Learning rate (possibly time-varying) |
| $T_w$ | Number of warmup steps |
| $T$ | Total training steps |
| $\eta_{\min}$, $\eta_{\max}$ | Minimum and maximum learning rate |
| $\gamma$ | Decay factor for step schedule |
| SAM | Sharpness-Aware Minimization |
| SWA | Stochastic Weight Averaging |
| EMA | Exponential Moving Average |
| NTK | Neural Tangent Kernel |
References
- Jeremy M. Cohen, Simran Kaur, Yuanzhi Li, J. Zico Kolter, Ameet Talwalkar (2021). Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability. ICLR.
- Yann N. Dauphin, Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho, Surya Ganguli, Yoshua Bengio (2014). Identifying and attacking the saddle point problem in high-dimensional non-convex optimization. NeurIPS.
- Felix Draxler, Kambis Veschgini, Manfred Salmhofer, Fred Hamprecht (2018). Essentially No Barriers in Neural Network Energy Landscape. ICML.
- Simon S. Du, Xiyu Zhai, Barnabas Poczos, Aarti Singh (2019). Gradient Descent Provably Optimizes Over-parameterized Neural Networks. ICLR.
- Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR.
- Jonathan Frankle, Gintare Karolina Dziugaite, Daniel M. Roy, Michael Carbin (2020). Linear Mode Connectivity and the Lottery Ticket Hypothesis. ICML.
- Suriya Gunasekar, Jason D. Lee, Daniel Soudry, Nathan Srebro (2018). Characterizing Implicit Bias in Terms of Optimization Geometry. ICML.
- Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun (2016). Deep Residual Learning for Image Recognition. CVPR.
- Shengding Hu, Yuge Tu, Xu Han, Chaoqun He, Ganqu Cui, Xiang Long, Zhi Zheng, Yewei Fang, Yuxiang Huang, Weilin Zhao, Xinrong Zhang, Zheng Leng Thai, Kaihuo Zhang, Chongyi Wang, Yuan Yao, Chenyang Zhao, Jie Zhou, Jie Cai, Zhongwu Zhai, Ning Ding, Chao Jia, Guoyang Zeng, Dahai Li, Zhiyuan Liu, Maosong Sun (2024). MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies. arXiv preprint arXiv:2404.06395.
- Sergey Ioffe, Christian Szegedy (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML.
- Stanislaw Jastrzebski, Maciej Szymczak, Stanislav Fort, Devansh Arpit, Jacek Tabor, Kyunghyun Cho, Krzysztof Geras (2020). The Break-Even Point on Optimization Trajectories of Deep Neural Networks. ICLR.
- Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, Ping Tak Peter Tang (2017). On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. ICLR.
- Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, Tom Goldstein (2018). Visualizing the Loss Landscape of Neural Nets. NeurIPS.
- Ilya Loshchilov, Frank Hutter (2017). SGDR: Stochastic Gradient Descent with Warm Restarts. ICLR.
- David A. McAllester (1999). PAC-Bayesian Model Averaging. COLT.