Gradient Descent — Deep Dive
The Math Behind Every AI Model You’ve Used
Gradient descent is nearly 180 years old: Augustin-Louis Cauchy described the method of steepest descent in 1847. What’s remarkable isn’t the algorithm itself. It’s that this 19th-century calculus technique, applied to neural networks via backpropagation, powers systems that generate human language, diagnose cancer, and play StarCraft at superhuman levels.
This article assumes you understand loss functions and basic neural networks. If not, read those first.
Backpropagation: How Gradients Actually Get Computed
You can’t apply gradient descent to a neural network without first computing the gradient — which means computing how much each of the model’s potentially billions of parameters contributed to the current loss.
This is backpropagation, and it’s what made deep learning practical.
The algorithm is an application of the chain rule from calculus. For a network with layers $L_1, L_2, …, L_n$, the gradient of the loss $\mathcal{L}$ with respect to weights in layer $L_1$ is computed by multiplying partial derivatives through every layer:
$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial L_n} \cdot \frac{\partial L_n}{\partial L_{n-1}} \cdot \ldots \cdot \frac{\partial L_2}{\partial L_1} \cdot \frac{\partial L_1}{\partial W_1}$$
The key insight, popularized by Rumelhart, Hinton, and Williams’ 1986 paper: you can compute all these partial derivatives in a single backward pass through the network by caching intermediate values from the forward pass. The cost is $O(n)$ in the number of parameters, roughly a constant multiple of one forward pass. Estimating each parameter’s derivative separately would instead cost about $O(n^2)$, which would make training deep networks computationally infeasible.
Modern frameworks like PyTorch and JAX implement automatic differentiation: they record or trace the operations of the forward computation and derive the backward pass from them. You rarely write backprop code by hand.
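To make the chain rule concrete, here is a minimal hand-written sketch for a tiny scalar network. The architecture (one tanh unit, squared error) and all names are assumptions chosen for illustration, not something a framework would generate. The finite-difference helper is there only to check the result.

```python
import math

def forward_backward(w1, w2, x, y):
    """Tiny network: h = tanh(w1 * x), y_hat = w2 * h, loss = (y_hat - y)^2."""
    # Forward pass: cache the intermediate values the backward pass needs.
    h = math.tanh(w1 * x)
    y_hat = w2 * h
    loss = (y_hat - y) ** 2

    # Backward pass: apply the chain rule from the loss back to each weight.
    d_yhat = 2.0 * (y_hat - y)        # dL/dy_hat
    d_w2 = d_yhat * h                 # dL/dw2 = dL/dy_hat * dy_hat/dw2
    d_h = d_yhat * w2                 # dL/dh
    d_w1 = d_h * (1.0 - h * h) * x    # tanh'(z) = 1 - tanh(z)^2, reusing cached h
    return loss, d_w1, d_w2

def numeric_grad(f, w, eps=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)

loss, g1, g2 = forward_backward(0.5, -1.5, 2.0, 1.0)
```

Both gradients come out of one forward and one backward sweep, and the cached `h` is reused rather than recomputed. That reuse is what keeps the cost linear.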
The Update Rule
The parameter update at each step is:
$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)$$
Where:
- $\theta$ is the parameter vector
- $\eta$ is the learning rate
- $\nabla_\theta \mathcal{L}$ is the gradient of the loss with respect to parameters
Simple. Almost absurdly simple for what it accomplishes.
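As a rough sketch of the update rule in action, the loop below minimizes a one-parameter quadratic. The loss $(\theta - 3)^2$, the learning rate, and the function name are hypothetical choices for illustration.

```python
def gradient_descent(grad_fn, theta, lr, steps):
    """Repeatedly apply theta <- theta - lr * grad(theta)."""
    for _ in range(steps):
        theta = theta - lr * grad_fn(theta)
    return theta

# Hypothetical loss L(theta) = (theta - 3)^2, so dL/dtheta = 2 * (theta - 3).
theta_final = gradient_descent(lambda t: 2.0 * (t - 3.0), theta=0.0, lr=0.1, steps=100)
```

With these settings the distance to the minimum shrinks by a factor of $1 - 2\eta = 0.8$ on every step, so `theta_final` ends up very close to 3.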
Momentum: Giving Descent Memory
Plain gradient descent has no memory — each step only knows the current gradient. This makes it susceptible to oscillating in narrow valleys (ravines in the loss landscape where curvature is very different in different directions).
Momentum fixes this by adding a velocity term:
v = beta * v + (1 - beta) * grad
theta = theta - learning_rate * v
The velocity $v$ accumulates gradients over time. With $\beta = 0.9$ (typical), the effective gradient is a weighted average of the last ~10 gradient steps. This smooths oscillations and lets the optimizer move faster through flat regions.
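The "~10 steps" figure can be checked directly. The sketch below uses the same exponential-moving-average form as the update above. The helper names are assumptions for illustration.

```python
def momentum_step(theta, v, grad, lr=0.1, beta=0.9):
    """One momentum update in the exponential-moving-average form."""
    v = beta * v + (1.0 - beta) * grad
    theta = theta - lr * v
    return theta, v

def ema_weights(beta, n):
    """Weight the velocity places on the gradient from k steps ago, k = 0..n-1."""
    return [(1.0 - beta) * beta ** k for k in range(n)]

# Share of the total weight carried by the 10 most recent gradients.
recent_share = sum(ema_weights(0.9, 10))

# With a constant gradient, the velocity approaches that gradient.
theta, v = 0.0, 0.0
for _ in range(200):
    theta, v = momentum_step(theta, v, grad=2.0)
```

The ten most recent gradients carry $1 - 0.9^{10} \approx 65\%$ of the weight, so "the last ~10 steps" is better read as a time constant of $1/(1-\beta)$ than as a hard window.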
Nesterov Accelerated Gradient (NAG) is a variation that looks ahead: it computes the gradient at the anticipated next position rather than the current one. For smooth convex problems it has a provably better convergence rate than plain gradient descent, and it often helps in practice as well.
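A minimal sketch of the look-ahead idea follows. It assumes the classical-momentum formulation, where the velocity accumulates learning-rate-scaled gradients. That differs from the averaged form shown earlier by a constant factor. The test loss and hyperparameters are hypothetical.

```python
def nag_step(theta, v, grad_fn, lr=0.1, mu=0.9):
    """One Nesterov step: evaluate the gradient where momentum is about to take us."""
    lookahead = theta + mu * v             # position momentum alone would reach
    v = mu * v - lr * grad_fn(lookahead)   # gradient taken at the look-ahead point
    return theta + v, v

# Hypothetical convex loss L(theta) = theta^2, gradient 2 * theta.
theta, v = 5.0, 0.0
for _ in range(200):
    theta, v = nag_step(theta, v, lambda t: 2.0 * t)
```

The only change from ordinary momentum is the argument passed to `grad_fn`. Evaluating at `lookahead` lets the optimizer correct course before it overshoots.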
Adam: Why Everyone Uses It
Adam (Adaptive Moment Estimation), introduced by Kingma & Ba in a paper first posted in December 2014, has been cited over 200,000 times. That makes it one of the most cited papers in machine learning.
It maintains two running averages per parameter:
- m: first moment (mean of gradients) — like momentum
- v: second moment (mean of squared gradients), which tracks the uncentered variance of the gradient
m = beta1 * m + (1 - beta1) * grad # running mean of gradients (momentum)
v = beta2 * v + (1 - beta2) * grad**2 # running mean of squared gradients
m_hat = m / (1 - beta1**t) # bias correction; t is the step count, starting at 1
v_hat = v / (1 - beta2**t) # bias correction
theta = theta - lr * m_hat / (sqrt(v_hat) + epsilon) # per-parameter scaled step
The division by $\sqrt{\hat{v}}$ is the key innovation. Parameters that receive consistently large gradients get smaller effective learning rates. Parameters with rare, small gradients get larger effective rates. Each parameter gets its own dynamic learning rate, shaped by an exponentially weighted average of its past squared gradients.
Typical defaults: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, $lr = 0.001$. These defaults work surprisingly well across wildly different architectures.
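Here is a runnable sketch of one Adam step for a single scalar parameter, using the defaults listed above. The function name and the scalar simplification are assumptions for illustration. Real implementations operate on whole tensors.

```python
import math

def adam_step(theta, m, v, grad, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)   # undo the bias toward the zero initial state
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# First step from a zero state, with gradients four orders of magnitude apart.
small, _, _ = adam_step(0.0, 0.0, 0.0, grad=0.01, t=1)
large, _, _ = adam_step(0.0, 0.0, 0.0, grad=100.0, t=1)
```

Both parameters move by almost exactly `lr` on the first step despite the difference in gradient size. That is the per-parameter normalization at work.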
The weakness of Adam: It doesn’t generalize as well as SGD in some settings. A notable 2017 paper by Wilson et al. (“The Marginal Value of Adaptive Gradient Methods in Machine Learning”) showed that adaptive methods such as Adam can find solutions that generalize worse than those found by SGD, even when their training performance is as good or better. One common explanation is that SGD with momentum tends to find flatter minima. This is still debated, but it’s why some production training runs (particularly for image classification on ImageNet) still use SGD with momentum, not Adam.
AdamW: Adam with weight decay decoupled from the gradient update. This is now the default for large language model training. The L2 regularization in standard Adam interacts poorly with the adaptive learning rates — AdamW fixes this by applying weight decay directly to the parameters, not through the gradient.
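The difference between the two forms of regularization can be sketched for a scalar parameter. The helper names are hypothetical, and the loss gradient is set to zero so that only the regularizer acts.

```python
import math

def _adam_direction(m, v, grad, t, beta1, beta2, eps):
    """Shared Adam machinery: returns the normalized step direction and new state."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v

def adam_l2_step(theta, m, v, grad, t, lr=0.001, wd=0.01,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """L2 regularization: the penalty gradient wd * theta passes through Adam's scaling."""
    direction, m, v = _adam_direction(m, v, grad + wd * theta, t, beta1, beta2, eps)
    return theta - lr * direction, m, v

def adamw_step(theta, m, v, grad, t, lr=0.001, wd=0.01,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """Decoupled weight decay: shrink theta directly, outside the adaptive scaling."""
    direction, m, v = _adam_direction(m, v, grad, t, beta1, beta2, eps)
    return theta - lr * direction - lr * wd * theta, m, v

# Zero loss gradient isolates the effect of the regularizer.
l2_weak, _, _ = adam_l2_step(1.0, 0.0, 0.0, grad=0.0, t=1, wd=0.01)
l2_strong, _, _ = adam_l2_step(1.0, 0.0, 0.0, grad=0.0, t=1, wd=0.1)
w_weak, _, _ = adamw_step(1.0, 0.0, 0.0, grad=0.0, t=1, wd=0.01)
w_strong, _, _ = adamw_step(1.0, 0.0, 0.0, grad=0.0, t=1, wd=0.1)
```

In the L2 version, a tenfold stronger penalty produces almost the same step, because Adam normalizes the penalty gradient away. In the decoupled version, the shrinkage scales with `wd` as intended.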
Learning Rate Schedules
Static learning rates are rarely optimal. Common schedules:
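Cosine Annealing: $$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{t\pi}{T}\right)\right)$$ Here $t$ is the current step and $T$ is the total number of steps. The rate starts at $\eta_{max}$ and decreases smoothly to $\eta_{min}$. Widely used for LLM training.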
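The cosine formula translates directly into a few lines of code. The function and argument names below are assumptions, with `t` as the current step and `total_steps` as $T$.

```python
import math

def cosine_annealing(t, total_steps, lr_max, lr_min=0.0):
    """Learning rate at step t under a single cosine decay from lr_max to lr_min."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * t / total_steps))

schedule = [cosine_annealing(t, 1000, lr_max=1e-3, lr_min=1e-5) for t in range(1001)]
```

The schedule starts at `lr_max`, passes through the midpoint of the two rates halfway through training, and ends at `lr_min`.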
Warmup: Start with a very small learning rate for the first N steps, then ramp up. Critical for large models — jumping straight to a large learning rate can cause instability in early training when gradients are chaotic.
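One simple way to implement warmup is a linear ramp. The linear shape and the names below are assumptions for illustration. The text only requires that the rate start small and increase.

```python
def linear_warmup(t, warmup_steps, lr_peak):
    """Ramp linearly up to lr_peak over the first warmup_steps steps, then hold."""
    if t < warmup_steps:
        return lr_peak * (t + 1) / warmup_steps
    return lr_peak

warmup = [linear_warmup(t, warmup_steps=100, lr_peak=1e-3) for t in range(200)]
```

In practice the "hold" branch is usually replaced by a decay schedule, so the warmup ramp hands off to it at the peak rate.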
1-cycle policy: Increase then decrease. Proposed by Leslie Smith in 2018; empirically allows training with much higher peak learning rates and often reaches better minima faster.
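The Chinchilla paper (Hoffmann et al., 2022), which redefined how everyone thinks about LLM training compute budgets, used a cosine decay schedule with a 1% warmup. The schedule mattered: the authors report that a cosine cycle much longer than the actual training run leaves the learning rate too high at the end and measurably hurts the final loss.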
The Loss Landscape: What We Actually Know
The intuition of a simple bowl-shaped landscape is wrong for deep networks. Real loss landscapes are high-dimensional (billions of dimensions), non-convex, and poorly understood.
Key findings from research:
Flat minima generalize better. A minimum in a wide, flat valley is more robust than one in a narrow spike — a small perturbation of the weights doesn’t hurt performance much. SGD’s noise effectively prefers flatter minima; Adam sometimes finds sharper ones. This is the root of the Adam vs SGD generalization debate.
Loss of plasticity. Deep networks trained with standard gradient descent can gradually lose the ability to learn new tasks. The gradients become very small for early layers (related to the vanishing gradient problem), and those weights effectively freeze. Techniques like periodic reinitialization of some layers are an area of active research.
Neural collapse. Near the end of training in classification networks, the last-layer features of same-class examples collapse to a single point, and different-class clusters form a maximally-spaced geometric arrangement (simplex ETF). This was discovered empirically in 2020 and is still being understood theoretically.
Grokking. In 2022, Power et al. showed that small transformer models trained on modular arithmetic would first memorize the training data (low training loss, high test loss), and then — if training continued past apparent convergence — suddenly generalize completely. The gradient descent dynamics after apparent “convergence” are still producing useful internal structure. This has significant implications for early stopping heuristics.
Second-Order Methods: Why They’re Not Used
First-order methods (all of the above) only use gradient information — the first derivative. Second-order methods use the Hessian (second derivatives), which captures curvature information and allows for much better step sizing.
Newton’s method update: $$\theta_{t+1} = \theta_t - H^{-1} \nabla \mathcal{L}$$
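Where $H$ is the Hessian matrix. Near a minimum, this converges in far fewer steps than gradient descent, because the inverse Hessian rescales each direction of the step by the local curvature.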
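For a quadratic loss, a single Newton step lands exactly on the minimum. The sketch below shows this for two parameters with an explicit 2x2 inverse. The loss and the starting point are hypothetical.

```python
def newton_step(theta, grad, hessian):
    """One Newton update for two parameters, inverting the 2x2 Hessian explicitly."""
    (a, b), (c, d) = hessian
    det = a * d - b * c
    # H^{-1} @ grad, using the closed-form inverse of a 2x2 matrix.
    step0 = (d * grad[0] - b * grad[1]) / det
    step1 = (-c * grad[0] + a * grad[1]) / det
    return [theta[0] - step0, theta[1] - step1]

# Hypothetical ill-conditioned quadratic: L = 0.5 * (x^2 + 100 * y^2).
theta = [4.0, -2.0]
grad = [theta[0], 100.0 * theta[1]]
hessian = [[1.0, 0.0], [0.0, 100.0]]
theta_new = newton_step(theta, grad, hessian)
```

Plain gradient descent on the same loss needs a learning rate below 0.02 to stay stable in the steep direction, which makes progress in the shallow direction slow.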
The problem: for a model with $n$ parameters, the Hessian is an $n \times n$ matrix. GPT-4 is reported, though not officially confirmed, to have ~1.8 trillion parameters. At that size the Hessian would have $3.24 \times 10^{24}$ entries. Storing it, let alone inverting it, is infeasible.
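The arithmetic behind that entry count is short enough to spell out. The 2-bytes-per-entry storage format is an assumption for illustration.

```python
n_params = 1_800_000_000_000       # ~1.8 trillion, the figure quoted above
hessian_entries = n_params ** 2    # an n x n matrix
bytes_per_entry = 2                # assumption: 16-bit floats
hessian_bytes = hessian_entries * bytes_per_entry
hessian_yottabytes = hessian_bytes / 10**24
```

Under that assumption the matrix alone would occupy about 6.5 yottabytes, before any work is done to invert it.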
K-FAC (Kronecker-Factored Approximate Curvature) approximates the Fisher information matrix using Kronecker products to keep it tractable. It’s been shown to converge in fewer steps than Adam, but the per-step overhead is large enough that wall-clock time savings are modest. Used in some production training at DeepMind.
Sharpness-Aware Minimization (SAM) takes a different approach: before each update, it perturbs the weights to the point of maximum loss in a small neighborhood, then computes the gradient there. This explicitly seeks flat minima. Google Research published this in 2020; it’s used in training some state-of-the-art image models.
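A minimal sketch of one SAM step follows. It uses the usual first-order approximation, in which the "point of maximum loss" is estimated by stepping a distance `rho` along the normalized gradient. The quadratic loss and the hyperparameters are hypothetical.

```python
import math

def sam_step(theta, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: take the gradient at an adversarially perturbed point."""
    g = grad_fn(theta)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12   # guard against division by zero
    # Approximate worst-case neighbor: move rho along the gradient direction.
    perturbed = [t + rho * gi / norm for t, gi in zip(theta, g)]
    g_sam = grad_fn(perturbed)
    # Apply the perturbed-point gradient to the original weights.
    return [t - lr * gi for t, gi in zip(theta, g_sam)]

def quad_grad(theta):
    """Gradient of the hypothetical loss L = 0.5 * (x^2 + 10 * y^2)."""
    return [theta[0], 10.0 * theta[1]]

theta = sam_step([1.0, 0.0], quad_grad)
```

Each step costs two gradient evaluations instead of one, which is the main practical price of the method.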
Distributed Training Complications
At scale, gradient descent runs across thousands of GPUs. This introduces new problems:
Gradient staleness: In asynchronous distributed training, workers compute gradients on potentially outdated model weights. Synchronous training (waiting for all workers) avoids this but is limited by the slowest worker.
Gradient compression: Communicating full gradient vectors across thousands of GPUs is bandwidth-intensive. Techniques like gradient quantization (1-bit SGD) and error feedback compression reduce communication by 10-100x with minimal accuracy loss.
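The idea behind sign-based compression with error feedback can be sketched on a plain list of floats. This is a simplified illustration rather than any particular published implementation, and the names are assumptions.

```python
def compress_with_error_feedback(grad, residual):
    """Sign-quantize (grad + residual); return the compressed vector and new residual."""
    corrected = [g + r for g, r in zip(grad, residual)]
    # One shared magnitude is sent alongside one sign bit per entry.
    scale = sum(abs(c) for c in corrected) / len(corrected)
    compressed = [scale if c >= 0 else -scale for c in corrected]
    # Whatever the quantizer lost is carried into the next step.
    new_residual = [c - q for c, q in zip(corrected, compressed)]
    return compressed, new_residual

grad = [0.5, -0.1, 0.2, -0.4]
compressed, residual = compress_with_error_feedback(grad, [0.0] * 4)
```

The residual keeps quantization error from being discarded: it is added back before the next compression, so the error does not accumulate unboundedly over training.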
Large-batch training: Larger batch sizes allow more parallelism, but they change the effective learning dynamics. The linear scaling rule (Goyal et al., 2017) says that when you multiply batch size by $k$, you should multiply the learning rate by $k$. It works up to a point; very large batches tend to converge to sharper minima.
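The linear scaling rule is a one-line calculation. The base values in the example (learning rate 0.1 at batch size 256) are used here as illustrative reference settings, and the function name is an assumption.

```python
def scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule: scale the learning rate by the same factor as the batch size."""
    return base_lr * new_batch / base_batch

lr_8k = scaled_lr(base_lr=0.1, base_batch=256, new_batch=8192)
```

Here the batch grows by $k = 32$, so the learning rate grows from 0.1 to 3.2. A rate that large is typically reached through a warmup ramp rather than used from the first step.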
One Thing to Remember
The reason gradient descent works on trillion-parameter models isn’t magic — it’s the chain rule, applied backwards, in a single pass. But the reason we’re still using first-order methods in 2026 isn’t because they’re optimal; it’s because second-order methods can’t scale to the size of problems we’re solving. The field is still looking for something better.