Model Pruning — Deep Dive
Optimal Brain Surgeon: Second-Order Pruning
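LeCun et al. (1990) derived the Optimal Brain Damage (OBD) criterion from a second-order Taylor expansion of the loss, keeping only the diagonal of the Hessian. Hassibi & Stork (1993) extended this to Optimal Brain Surgeon (OBS), which uses the full Hessian and solves the resulting constrained quadratic problem exactly.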
The loss change when weight $w_q$ is pruned (set to zero) and the remaining weights are adjusted optimally:
$$\delta \mathcal{L} = \frac{w_q^2}{2[H^{-1}]_{qq}}$$
Where $H = \nabla^2 \mathcal{L}$ is the full Hessian of the loss. OBS prunes the weight that minimizes $\delta \mathcal{L}$, then updates the remaining weights to compensate, with $\mathbf{e}_q$ the unit vector that selects weight $q$:
$$\delta \mathbf{w} = -\frac{w_q}{[H^{-1}]_{qq}} H^{-1} \mathbf{e}_q$$
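Without this update, simply zeroing $w_q$ would raise the loss by $\frac{1}{2} H_{qq} w_q^2$ under the same quadratic approximation. The compensating update lowers that to $w_q^2 / (2[H^{-1}]_{qq})$, which is never larger because $[H^{-1}]_{qq} \geq 1/H_{qq}$ for a positive-definite $H$. The increase is not zero, but it is the smallest achievable when higher-order terms are ignored. The update spreads the effect of removing $w_q$ across every weight coupled to it through $H^{-1}$.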
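The two formulas above can be checked numerically on a toy quadratic loss. The sketch below is a minimal illustration, assuming a small random positive-definite Hessian and weights sitting at a local minimum; the helper names `obs_prune_one` and `loss_increase` are hypothetical and not taken from the original papers.

```python
import numpy as np

def obs_prune_one(w, H_inv):
    """Return (q, delta_w, predicted_increase) for the cheapest weight to prune."""
    diag = np.diag(H_inv)
    saliency = w**2 / (2.0 * diag)               # delta L for every candidate q
    q = int(np.argmin(saliency))
    # H^{-1} e_q is simply column q of H^{-1}
    delta_w = -(w[q] / diag[q]) * H_inv[:, q]
    return q, delta_w, saliency[q]

rng = np.random.default_rng(0)
n = 6
A = rng.normal(size=(n, n))
H = A @ A.T + n * np.eye(n)    # toy positive-definite Hessian (assumption)
w = rng.normal(size=n)         # weights at a minimum of the toy loss (assumption)

def loss_increase(dw):
    # Exact for a quadratic loss whose gradient is zero at w
    return 0.5 * dw @ H @ dw

H_inv = np.linalg.inv(H)
q, delta_w, predicted = obs_prune_one(w, H_inv)
w_obs = w + delta_w            # w_obs[q] is zero up to rounding

naive = np.zeros(n)
naive[q] = -w[q]               # zero w_q without compensating the others

print("pruned index:", q)
print("OBS increase:  ", loss_increase(delta_w), "(predicted", predicted, ")")
print("naive increase:", loss_increase(naive))
```

On a real network the Hessian is neither this small nor exactly positive definite, so this only demonstrates the algebra, not a practical pruning routine.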
Practical limitation: the full Hessian has $n^2$ entries and inverting it costs $O(n^3)$, which is infeasible for millions of parameters. Block-diagonal approximations (one Hessian block per layer, or per output row within a layer) make this tractable. GPTQ builds on the same layer-wise OBS framework, applied to quantization rather than pruning.
SparseGPT (Frantar & Alistarh, 2023) applies OBS to LLM pruning layer-by-layer. It uses the Hessian of each layer’s output reconstruction error, built from calibration inputs as $H = XX^T$ plus a small damping term, and the same Cholesky-based handling of $H^{-1}$ as GPTQ. This enables 50% pruning of 175B parameter models in ~4 GPU hours, with performance similar to GPTQ quantization at comparable compression.
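To make the layer-wise idea concrete, here is a minimal sketch that prunes one weight per output row of a single linear layer, using a Hessian built from calibration inputs. It is not the SparseGPT algorithm (which prunes many weights per row and reuses a Cholesky factorization); the random calibration data, the damping value, and the one-weight-per-row setup are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_samples = 8, 4, 64
X = rng.normal(size=(d_in, n_samples))   # calibration inputs (hypothetical data)
W = rng.normal(size=(d_out, d_in))       # layer weights, output = W @ X

damp = 1e-2                              # small damping term (assumed value)
H = X @ X.T + damp * np.eye(d_in)        # same Hessian for every output row
H_inv = np.linalg.inv(H)
diag = np.diag(H_inv)

W_obs = W.copy()
W_naive = W.copy()
for r in range(d_out):
    saliency = W[r] ** 2 / diag          # constant factor 1/2 dropped, argmin unchanged
    q = int(np.argmin(saliency))
    W_obs[r] += -(W[r, q] / diag[q]) * H_inv[:, q]   # OBS compensation for row r
    W_obs[r, q] = 0.0                    # clear floating-point residue
    W_naive[r, q] = 0.0                  # same weight removed, no compensation

def recon_error(W_new):
    """Squared error between the original and pruned layer outputs."""
    return np.linalg.norm(W @ X - W_new @ X) ** 2

print("with compensation:   ", recon_error(W_obs))
print("without compensation:", recon_error(W_naive))
```

Because every row of the layer sees the same inputs, one $d_{in} \times d_{in}$ inverse serves all rows, which is what keeps the per-layer cost manageable.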
Movement Pruning: Fine-Tuned Model Considerations
Magnitude pruning works well for pretrained models but is suboptimal for fine-tuned models. The reason: during fine-tuning, the weights that matter for the specific task move away from their pretrained values in task-useful directions. Weights that were large after pretraining but barely moved during fine-tuning may matter less than smaller weights that moved a lot.
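Movement pruning (Sanh et al., 2020): Score each weight by how it moved during fine-tuning. The score accumulates the product of the negative gradient and the weight over training steps $t$, with learning rate $\alpha$:
$$\text{importance}(w_i) = -\alpha \sum_{t} \left(\frac{\partial \mathcal{L}}{\partial w_i}\right)^{(t)} w_i^{(t)}$$
Weights that move away from zero (a positive weight that keeps growing, or a negative weight that keeps becoming more negative) accumulate high scores and are kept; weights that shrink toward zero accumulate low scores and are pruned. A simpler proxy is the displacement from the pretrained value:
$$\text{importance}(w_i) = \left| w_i^{(T)} - w_i^{(0)} \right|$$
Where $w_i^{(T)}$ is the fine-tuned weight and $w_i^{(0)}$ is the pretrained weight.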
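As a rough illustration, movement scores in the sense of Sanh et al. can be accumulated as the running sum of negative gradient times weight during fine-tuning. The sketch below does this for a toy linear regression trained with plain gradient descent; the data, the "pretrained" starting weights, and the learning rate are made up, and the forward pass is not masked as it would be in the actual method.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 200, 10
X = rng.normal(size=(n_samples, n_features))
w_target = np.zeros(n_features)
w_target[:3] = [2.0, -1.5, 1.0]           # only three features matter for the toy task
y = X @ w_target

w = rng.normal(scale=0.5, size=n_features)   # stand-in for pretrained weights
w0 = w.copy()
scores = np.zeros(n_features)
lr = 0.01

for step in range(300):
    grad = X.T @ (X @ w - y) / n_samples     # gradient of 0.5 * mean squared error
    scores += -lr * grad * w                 # positive when w moves away from zero
    w -= lr * grad

displacement = np.abs(w - w0)                # the simpler proxy score

keep = 3
print("top by movement score:", sorted(np.argsort(scores)[-keep:].tolist()))
print("top by magnitude:     ", sorted(np.argsort(np.abs(w))[-keep:].tolist()))
print("top by displacement:  ", sorted(np.argsort(displacement)[-keep:].tolist()))
```

For plain gradient descent each increment equals the weight change times the current weight, so the accumulated score tracks roughly half the change in $w_i^2$: weights that grew in magnitude score high regardless of how large they started.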
Results on SQuAD: movement pruning to 97% sparsity retains 92% of F1, outperforming magnitude pruning at the same sparsity by 5+ F1 points.
Soft movement pruning: Gives each weight a learned score $s_i$ and prunes against a global threshold $\tau$ instead of keeping a fixed top fraction. A smooth version of the resulting mask is:
$$w_i^{\text{masked}} = w_i \cdot \sigma(s_i - \tau)$$
The scores $s_i$ are learned jointly with the weights. L1 regularization on $\sigma(s_i)$ pushes many gates toward 0, and the corresponding weights are pruned. This allows end-to-end training of the pruning decision.
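The gating formula and its regularizer are simple enough to write out directly. The following sketch shows only the forward gate and the pull of the L1 penalty on the scores, not a full training loop; the function names, example values, and step size are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_weights(w, s, tau):
    """Soft-masked weights: w_i * sigmoid(s_i - tau)."""
    return w * sigmoid(s - tau)

def gate_penalty(s, tau, lam):
    """L1 penalty on the gates; gates are positive, so L1 is their sum."""
    return lam * np.sum(sigmoid(s - tau))

def gate_penalty_grad(s, tau, lam):
    """Gradient of the penalty with respect to the scores."""
    g = sigmoid(s - tau)
    return lam * g * (1.0 - g)

w = np.array([0.8, -0.1, 0.05, -1.2])
s = np.array([2.0, -1.0, 0.0, 3.0])
tau, lam, lr = 0.0, 0.1, 1.0

gates_before = sigmoid(s - tau)
s_after = s - lr * gate_penalty_grad(s, tau, lam)   # one step on the penalty alone
gates_after = sigmoid(s_after - tau)

print("masked weights:", gated_weights(w, s, tau))
print("gates before:  ", gates_before)
print("gates after:   ", gates_after)
```

The penalty gradient is strictly positive for every score, so on its own it lowers every gate; during real training only the gates whose weights reduce the task loss receive an opposing gradient large enough to stay open.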
N:M Sparsity: Hardware-Friendly Semi-Structured Pruning
NVIDIA’s A100 GPU introduced native acceleration for 2:4 sparsity (also called “semi-structured sparsity”): in every group of 4 consecutive weights, at most 2 are non-zero. The pattern is regular enough for hardware to exploit, yet leaves freedom over which 2 weights in each group survive.
Sparse Tensor Core: The Sparse Tensor Core takes a 2:4 sparse matrix in compressed form and uses its metadata array to select the matching elements of the dense operand, so it skips the zeros and performs half as many multiply-accumulates as the dense equivalent. Result: up to 2x theoretical speedup over dense FP16 for supported operations.
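$$\text{sparse\_A} = (A_{\text{nz}}, \text{metadata})$$
Where $A_{\text{nz}}$ holds the non-zero values and the metadata encodes the positions of the two non-zero elements in each group of 4. The compressed format stores only the non-zero values (50% storage) plus a 2-bit index per non-zero. For FP16 values that is 4 bits of metadata per 64-bit group, or 6.25% of the dense size.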
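The values-plus-metadata idea can be sketched in a few lines. This is a conceptual layout only, assuming FP16 values and one 2-bit position per kept element; the real hardware and library formats pack and order the data differently, and the function names here are hypothetical.

```python
import numpy as np

def compress_2to4(row):
    """Split a 2:4-sparse 1-D array into (values, positions) per group of 4."""
    groups = row.reshape(-1, 4)
    # Positions of the two largest-magnitude entries, in ascending position order
    idx = np.sort(np.argsort(np.abs(groups), axis=1)[:, 2:], axis=1)
    values = np.take_along_axis(groups, idx, axis=1)
    return values, idx.astype(np.uint8)      # each position fits in 2 bits

def decompress_2to4(values, idx):
    """Rebuild the dense array from values and positions."""
    groups = np.zeros((values.shape[0], 4), dtype=values.dtype)
    np.put_along_axis(groups, idx.astype(np.int64), values, axis=1)
    return groups.reshape(-1)

row = np.array([0.0, 1.5, 0.0, -2.0, 0.7, 0.0, 0.0, 0.3], dtype=np.float16)
values, idx = compress_2to4(row)
restored = decompress_2to4(values, idx)

dense_bits = row.size * 16       # FP16 dense storage
value_bits = values.size * 16    # half of the dense values
meta_bits = idx.size * 2         # one 2-bit position per kept value

print("positions:", idx.tolist())
print("values share of dense:  ", value_bits / dense_bits)
print("metadata share of dense:", meta_bits / dense_bits)
```

For this FP16 example the compressed form takes 56.25% of the dense size: 50% for values and 6.25% for positions.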
Finding 2:4 sparse networks (NVIDIA Ampere NM Sparsity, 2020): Fine-tune the dense model with 2:4 masking applied at each step. The training procedure:
- At each weight update, for each group of 4 weights, zero out the 2 with lowest magnitude
- Apply gradient updates only to the surviving (non-zero) weights
- After fine-tuning, the 2:4 mask is fixed
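The steps above can be sketched as a mask function plus a masked update. This is a minimal NumPy illustration of the procedure as described here, assuming the last dimension is a multiple of 4 and using a random array in place of a real gradient; it is not NVIDIA's tooling, and `mask_2to4` and `masked_sgd_step` are hypothetical names.

```python
import numpy as np

def mask_2to4(W):
    """Boolean mask keeping the 2 largest-magnitude weights in each group of 4."""
    groups = np.abs(W).reshape(-1, 4)
    order = np.argsort(groups, axis=1)              # ascending magnitude
    mask = np.ones_like(groups, dtype=bool)
    np.put_along_axis(mask, order[:, :2], False, axis=1)   # drop the 2 smallest
    return mask.reshape(W.shape)

def masked_sgd_step(W, grad, lr):
    """Re-select the 2:4 mask, zero the pruned weights, update only the survivors."""
    mask = mask_2to4(W)
    W = W * mask
    W = W - lr * grad * mask
    return W, mask

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
grad = rng.normal(size=(4, 8))     # stand-in for a backpropagated gradient
W_new, mask = masked_sgd_step(W, grad, lr=0.1)

nonzeros_per_group = (W_new.reshape(-1, 4) != 0).sum(axis=1)
print(nonzeros_per_group)
```

Once fine-tuning ends, the final `mask` would be stored and reused unchanged at inference time.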
Accuracy degradation: ResNet-50 at 2:4 sparsity: 0.3% top-1 loss. BERT at 2:4 sparsity: 1.0% F1 loss on SQuAD. This modest accuracy cost is often acceptable in exchange for up to 2x faster inference on supported operations.
Pruning and Neural Tangent Kernels
The Neural Tangent Kernel (NTK) framework (Jacot et al., 2018) analyzes infinite-width networks trained with gradient descent. At infinite width, neural networks are equivalent to kernel machines with a specific kernel — the NTK.
In the NTK regime, training dynamics simplify: the NTK converges to a deterministic kernel and stays constant during training ($\Theta(x, x') = J(x) J(x')^T$ where $J$ is the Jacobian of the network output with respect to the parameters). Predictions converge to the kernel regression solution.
Pruning in the NTK framework: A pruned network with mask $m$ has Jacobian $J_m = J \odot m$. The pruned network’s NTK is $\Theta_m(x, x') = J_m(x) J_m(x')^T$. Since $\Theta - \Theta_m$ is positive semidefinite, pruning can only remove spectral mass from the kernel: it reduces the kernel’s expressiveness, and typically its effective rank.
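A finite-width version of these kernels is easy to compute for a tiny network. The sketch below builds the empirical NTK of a one-hidden-layer tanh network and the masked kernel from $J \odot m$, following the definition in the text; the architecture, the random mask, and the entropy-based `effective_rank` measure are assumptions, and a pruned network's true Jacobian would be re-evaluated at the pruned weights.

```python
import numpy as np

def jacobian(X, W, a):
    """Row i is d f(x_i) / d theta for f(x) = a . tanh(W x), theta = (W flattened, a)."""
    H = np.tanh(X @ W.T)                                   # (n, hidden)
    dW = (a * (1.0 - H**2))[:, :, None] * X[:, None, :]    # (n, hidden, d)
    return np.concatenate([dW.reshape(len(X), -1), H], axis=1)

def effective_rank(K):
    """Exponential of the entropy of the normalized eigenvalue spectrum."""
    ev = np.clip(np.linalg.eigvalsh(K), 0.0, None)
    p = ev / ev.sum()
    p = p[p > 0]
    return float(np.exp(-(p * np.log(p)).sum()))

rng = np.random.default_rng(0)
n, d, hidden = 12, 3, 16
X = rng.normal(size=(n, d))
W = rng.normal(size=(hidden, d)) / np.sqrt(d)
a = rng.normal(size=hidden) / np.sqrt(hidden)

J = jacobian(X, W, a)                  # shape (n, n_params)
m = rng.random(J.shape[1]) < 0.3       # random mask keeping roughly 30% of parameters
J_m = J * m

ntk = J @ J.T
ntk_m = J_m @ J_m.T

print("trace full / pruned:", np.trace(ntk), np.trace(ntk_m))
print("effective rank full / pruned:", effective_rank(ntk), effective_rank(ntk_m))
```

A random mask is used only to show the mechanics; a mask chosen to keep the dominant eigen-directions would be expected to preserve far more of the kernel at the same sparsity.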
For lottery ticket winning tickets: in the NTK regime, a winning ticket can be read as a sparse mask that preserves the kernel’s dominant eigenvectors. Under this view, the Lottery Ticket Hypothesis amounts to finding sparse masks whose NTK approximates that of the full network.
This offers a theoretical interpretation: winning tickets work because they preserve the kernel structure needed for fast convergence, not just because they happen to have a good initialization.
Pruning + Quantization: Combining Compression Methods
Pruning and quantization are complementary:
- Pruning reduces the number of non-zero weights
- Quantization reduces the bits per weight
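Combined: “sparse quantization” stores weights in INT4 and eliminates ~80% of them. Relative to dense FP16 that is 4x from quantization times 5x from sparsity, about 20x before index overhead; higher sparsity pushes the ratio further at a growing accuracy cost.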
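Practical implementation: apply magnitude pruning first, then quantize the remaining weights with GPTQ or AWQ. Zero weights are not stored; only the INT4 values plus position metadata are kept. With a 2:4 pattern the position fits in 2 bits per kept weight, while unstructured sparsity needs a bitmask or index list instead.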
Approximate points on the accuracy vs. model size trade-off for a given architecture (sizes relative to dense FP16):
- Dense FP16: maximum accuracy, maximum size
- Dense INT8: ~1% accuracy loss, 2x smaller
- 50% sparse FP16: ~1% accuracy loss, 2x smaller
- Dense INT4: ~2% accuracy loss, 4x smaller
- 50% sparse INT4: ~3% accuracy loss, 8x smaller
- 90% sparse INT4: ~5% accuracy loss, 20-40x smaller
Combining the two methods reaches points on the Pareto frontier that neither method reaches alone.
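The size figures in the list follow from simple bit counting. The helper below is a back-of-the-envelope sketch, assuming a dense FP16 baseline and an optional fixed number of index bits per kept weight; `compression_ratio` and its parameters are hypothetical names, and real formats add further overhead such as quantization scales.

```python
def compression_ratio(sparsity, value_bits, index_bits=0.0, dense_bits=16):
    """Dense size divided by sparse, quantized size.

    index_bits is position metadata charged per kept weight (assumed encoding).
    """
    kept = 1.0 - sparsity
    bits_per_original_weight = kept * (value_bits + index_bits)
    return dense_bits / bits_per_original_weight

configs = {
    "dense INT8":                   (0.0, 8, 0),
    "50% sparse FP16":              (0.5, 16, 0),
    "dense INT4":                   (0.0, 4, 0),
    "50% sparse INT4":              (0.5, 4, 0),
    "50% sparse INT4 + 2-bit idx":  (0.5, 4, 2),
    "90% sparse INT4":              (0.9, 4, 0),
}
for name, (sparsity, vbits, ibits) in configs.items():
    print(f"{name}: {compression_ratio(sparsity, vbits, ibits):.2f}x")
```

Charging 2 index bits per kept weight drops the 50% sparse INT4 case from 8x to about 5.3x, which is why position metadata matters more as the values themselves get smaller.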
One thing to remember: The theoretical foundation of pruning — from optimal brain damage to neural tangent kernels — reveals that neural networks are massively over-parameterized relative to the information they encode, and systematically exploiting this over-parameterization is the key to efficient deployment.