Knowledge Distillation — Deep Dive
Why Soft Labels Work: Information Theory View
In information-theoretic terms, a one-hot label conveys at most $\log_2(C)$ bits (where $C$ is the number of classes): it only identifies which class is correct. A probability distribution over $C$ classes has an entropy between 0 and $\log_2(C)$ bits, and the entropy is close to 0 when the distribution is sharply peaked, so raw bit counts alone do not explain why soft labels help.
However, the information content relevant for generalization is different. The probability ratios between similar classes encode structural information about the data manifold. For handwritten digits: the fact that a specific “2” is 4% likely to be “8” encodes that 2s and 8s are visually similar — information about the geometry of handwritten digit space.
Hinton et al. (2015) called this “dark knowledge”: information not directly useful for the current training example but important for generalization. The theoretical justification is that minimizing cross-entropy against the soft targets is equivalent to minimizing the KL divergence from the teacher’s distribution to the student’s (the superscript $T$ denotes a softmax taken at temperature $T$):
$$\mathcal{L}_{soft} = KL(p_{teacher}^T \,\|\, p_{student}^T) = -\sum_i p_{teacher,i}^T \log p_{student,i}^T - H(p_{teacher}^T)$$
Here $H(p_{teacher}^T)$, the entropy of the teacher distribution, is a constant that does not depend on the student’s parameters, so it does not affect the gradients.
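The relationship between the KL divergence and the cross-entropy can be checked numerically. The sketch below uses only the standard library; the two distributions are made-up values standing in for softened teacher and student outputs, and the function names are illustrative.

```python
import math

def entropy(p):
    """Shannon entropy of distribution p, in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    """Cross-entropy of q relative to the target distribution p."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_divergence(p, q):
    """KL(p || q), with p as the target (teacher) distribution."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical softened outputs over three classes (assumed values).
teacher = [0.90, 0.04, 0.06]
student = [0.70, 0.10, 0.20]

kl = kl_divergence(teacher, student)
ce = cross_entropy(teacher, student)
h = entropy(teacher)

# The two printed numbers agree up to floating-point error:
# KL differs from cross-entropy only by the teacher's entropy.
print(kl, ce - h)
```

Because the teacher's entropy does not involve the student, minimizing either quantity with respect to the student gives the same gradients.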
Temperature Analysis
The temperature parameter $T$ controls the information flow from teacher to student.
At $T \rightarrow 0$: the soft targets approach hard labels (all probability mass on the argmax). No dark knowledge transferred.
At $T \rightarrow \infty$: all soft targets approach $1/C$ (uniform distribution). All class structure is lost.
There’s an optimal $T$ for each teacher-student pair that maximizes useful information transfer. Empirically, $T \in [2, 6]$ works well for classification. The right value depends on how peaked the teacher’s distributions are — a confident teacher benefits from higher temperature to reveal class structure.
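The two limits can be seen directly by softening one set of logits at several temperatures. This is a minimal sketch using only the standard library; the logits are hypothetical and `softmax_t` is an illustrative helper name.

```python
import math

def softmax_t(logits, T):
    """Softmax of logits / T, shifted by the max for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes (assumed values).
logits = [8.0, 4.0, 1.0]

for T in (0.1, 1.0, 4.0, 100.0):
    # Low T concentrates mass on the argmax; high T flattens toward 1/C.
    print(T, [round(p, 3) for p in softmax_t(logits, T)])
```

Intermediate temperatures are where the relative sizes of the non-argmax probabilities become visible to the student.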
Why $T^2$ scales the soft loss: The gradient of the soft loss with respect to the student’s pre-softmax logits carries an explicit $\frac{1}{T}$ factor:
$$\frac{\partial \mathcal{L}_{soft}}{\partial z_i} = \frac{1}{T}(p_{student,i}^T - p_{teacher,i}^T)$$
In addition, at high temperature the difference between the two softened distributions itself shrinks roughly in proportion to $1/T$, so the gradient magnitudes scale as about $1/T^2$ overall (Hinton et al., 2015). Multiplying the soft loss by $T^2$ compensates for this, keeping its contribution comparable to the hard-label loss as $T$ changes.
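A small numerical experiment shows how the soft-loss gradient shrinks with temperature and what the $T^2$ factor does to it. This is a sketch under stated assumptions: the logits are made up, the soft loss is taken to be the KL divergence between softened teacher and student distributions, and all function names are illustrative.

```python
import math

def softmax_t(logits, T):
    """Softmax of logits / T, shifted by the max for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_loss(student_logits, teacher_logits, T):
    """KL(teacher || student) between temperature-softened distributions."""
    p = softmax_t(teacher_logits, T)
    q = softmax_t(student_logits, T)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def soft_grad(student_logits, teacher_logits, T):
    """Analytic gradient of soft_loss with respect to the student logits."""
    p = softmax_t(teacher_logits, T)
    q = softmax_t(student_logits, T)
    return [(qi - pi) / T for pi, qi in zip(p, q)]

def norm(v):
    return math.sqrt(sum(x * x for x in v))

# Hypothetical logits (assumed values).
student = [2.0, 1.0, 0.0]
teacher = [3.0, 0.0, 0.0]

for T in (1.0, 4.0, 50.0, 100.0):
    g = soft_grad(student, teacher, T)
    # The raw norm keeps shrinking as T grows; the T^2-scaled norm levels off.
    print(T, norm(g), T * T * norm(g))
```

At large $T$ the scaled norm settles to a value that no longer depends on $T$, which is the behaviour the $T^2$ weighting is meant to produce.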
Self-Distillation
Furlanello et al. (2018), in “Born-Again Networks”, showed a surprising result: distilling a model into another model of the same size can improve performance. A student with an identical architecture often ends up slightly better than its teacher.
The proposed mechanism: the teacher’s soft labels act somewhat like label smoothing, preventing overconfident predictions. Additionally, training on teacher outputs rather than one-hot labels may provide a smoother loss landscape.
Successive self-distillation: Train model $G_1$ on hard labels. Distill $G_1 \rightarrow G_2$ (same architecture). Distill $G_2 \rightarrow G_3$. Ensemble multiple generations. Later generations typically match or improve on earlier ones, though the gains tend to taper off after a few generations.
This has practical implications: even without model compression goals, distillation can improve model quality — treating the original model as a teacher and retraining from scratch with soft labels.
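The generation-by-generation procedure is mostly control flow, which the sketch below tries to make explicit. It is not the authors' code: `train` is a hypothetical caller-supplied function that trains a fresh model from scratch, using hard labels when the teacher is `None` and the teacher's soft labels otherwise, and a model is assumed to be any callable returning class probabilities.

```python
from typing import Callable, List, Optional

Probs = List[float]
Model = Callable[[object], Probs]  # maps one input to class probabilities

def born_again(train: Callable[[Optional[Model]], Model],
               generations: int) -> List[Model]:
    """Train G_1 on hard labels, then distill G_k into a fresh G_{k+1}."""
    models: List[Model] = []
    teacher: Optional[Model] = None  # None means "use hard labels only"
    for _ in range(generations):
        student = train(teacher)     # same architecture, trained from scratch
        models.append(student)
        teacher = student            # the new model teaches the next one
    return models

def ensemble(models: List[Model], x: object) -> Probs:
    """Average the class probabilities predicted by all generations."""
    preds = [m(x) for m in models]
    return [sum(col) / len(preds) for col in zip(*preds)]
```

Keeping every generation, rather than only the last, is what makes the final ensembling step possible.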
Data-Free Distillation
A fundamental limitation: standard distillation requires the original training data to generate teacher activations. This is often unavailable due to privacy constraints, data licensing issues, or proprietary training data.
Data-free knowledge distillation generates synthetic data that activates the teacher’s knowledge. Two main approaches:
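Inversion-based (DeepInversion, Yin et al. 2020): Optimize synthetic inputs directly, starting from noise, so that the teacher classifies them confidently and their feature statistics match the running statistics stored in the teacher’s batch normalization layers. The student is then trained on these synthesized examples.
GAN-based (DFAD, Fang et al. 2019): Train a generator adversarially against the student. The generator tries to produce examples on which teacher and student disagree, while the student is trained to match the teacher on them, so the generator keeps supplying the hardest training signal. DAFL (Chen et al. 2019) is a related generator-based method in which the fixed teacher plays the role of the discriminator.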
These methods achieve within 2–5% of standard distillation on benchmarks like CIFAR-100, which is remarkable given no original data access.
LLM-Specific Distillation
Distilling large language models involves additional complexities:
Sequence-level distillation (Kim & Rush, 2016): Rather than matching token-level logits, generate sequences from the teacher and fine-tune the student on them. The student learns from teacher-generated examples, not just teacher probabilities. This is the basis for the synthetic-data approaches used in modern LLM training.
Speculative decoding (Leviathan et al., 2022): Use a small draft model to propose several tokens, then verify them with the large model. This is not distillation in the traditional sense, but it achieves a reported 2–3x inference speedup because the large model checks all proposed tokens in a single parallel forward pass instead of generating them one at a time.
State Space Model distillation: Transformer LLMs can be distilled into Mamba-style SSMs, where the student architecture is fundamentally different (no attention). This typically relies on feature-level distillation, matching hidden state representations rather than just output logits.
Sparse distillation for large vocabularies: For large-vocabulary LLMs, the teacher’s distribution at each token position can have thousands of non-negligible probabilities. Sparse distillation (keeping the top-k soft targets and zeroing the rest) reduces computation and storage while preserving most of the probability mass.
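The top-k idea from the last item can be sketched in a few lines. The function name and example distribution are illustrative, and renormalizing the kept entries so they sum to 1 is an assumption made here; the text above only specifies zeroing the rest.

```python
def top_k_targets(probs, k):
    """Keep the k largest teacher probabilities, zero the rest, renormalize.

    Returns the sparse target distribution and the probability mass retained
    before renormalization. Renormalizing is an assumption of this sketch.
    """
    if not 1 <= k <= len(probs):
        raise ValueError("k must be between 1 and the vocabulary size")
    # Indices of the k largest probabilities.
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    mass = sum(probs[i] for i in keep)
    sparse = [0.0] * len(probs)
    for i in keep:
        sparse[i] = probs[i] / mass
    return sparse, mass

# Hypothetical teacher distribution over a five-token vocabulary.
teacher_probs = [0.50, 0.20, 0.15, 0.10, 0.05]
sparse, retained = top_k_targets(teacher_probs, k=2)
print(sparse, retained)
```

The retained mass gives a quick check on how much of the teacher's distribution a given k throws away; in practice only the k indices and values would need to be stored per token.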
Practical Considerations at Scale
For distilling 70B+ parameter models to smaller models:
Layer mapping: With a 32-layer teacher and 12-layer student, which teacher layers should correspond to which student layers? Uniform mapping (every 2.67 teacher layers → 1 student layer) works but isn’t optimal. Learned layer selection (choosing which teacher layers to distill from) improves results.
Projection layers: Teacher and student hidden dimensions typically differ (e.g., 4096 vs. 1024). Linear projections handle the dimension mismatch for feature-level distillation, but these add parameters.
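Computational cost: Distillation requires a teacher forward pass in addition to the student’s forward and backward passes. For a 70B teacher and 7B student, the two forward passes together cost about 11x a student forward pass alone. Counting the student’s backward pass as well (roughly twice the cost of its forward pass), each batch costs roughly 4x what training the student alone would. Precomputing and caching teacher outputs makes this practical for static datasets.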
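The layer-mapping and cost figures above can be reproduced with a little arithmetic. This is a rough sketch: the function names are illustrative, rounding to the nearest teacher layer is one possible reading of "uniform mapping", and the cost model assumes the common approximation of about 2 FLOPs per parameter per token for a forward pass and about 4 for a backward pass.

```python
import math

def uniform_layer_map(n_teacher, n_student):
    """Map each student layer (1-based) to the nearest evenly spaced teacher layer."""
    return [math.floor(s * n_teacher / n_student + 0.5)
            for s in range(1, n_student + 1)]

def flops_per_token(n_params, backward):
    """Approximate FLOPs per token: 2N forward, plus 4N if a backward pass runs."""
    return 2 * n_params + (4 * n_params if backward else 0)

teacher_params = 70e9
student_params = 7e9

# Forward passes only: teacher + student, relative to the student alone.
forward_ratio = (
    flops_per_token(teacher_params, backward=False)
    + flops_per_token(student_params, backward=False)
) / flops_per_token(student_params, backward=False)

# Full distillation step relative to a full student-only training step.
training_ratio = (
    flops_per_token(teacher_params, backward=False)
    + flops_per_token(student_params, backward=True)
) / flops_per_token(student_params, backward=True)

print(uniform_layer_map(32, 12))
print(forward_ratio, training_ratio)
```

Under these assumptions the 11x figure corresponds to forward passes only, while the ratio for a complete training step is closer to 4x because the student's backward pass is paid in both cases.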
One thing to remember: Distillation’s power comes from the teacher’s probability distributions encoding the geometry of the problem space — the more information you can extract from those distributions (through temperature, feature matching, and relational approaches), the better the student can learn.
See Also
- Model Pruning: How AI models lose weight without losing intelligence, by removing the neurons that don't actually do anything useful to make models faster and smaller.
- Model Quantization: How AI models get shrunk to run on your phone, using the precision-tradeoff trick that makes 70 billion parameter models fit in consumer hardware.
- Speculative Decoding: The clever trick that makes large AI models generate text 2-4x faster, using a small 'draft' model to guess tokens that a big model then quickly verifies.
- Activation Functions: Why neural networks need these tiny mathematical functions, and how ReLU's simplicity accidentally made deep learning possible.
- AI Agents Architecture: How AI systems go from answering questions to actually doing things, with the design patterns that turn language models into autonomous agents that browse, code, and plan.