Knowledge Distillation in Python — Core Concepts

The Problem Distillation Solves

Large models perform well because they have the capacity to learn complex patterns. But their size makes them impractical for real-time inference, mobile deployment, or cost-sensitive serving. Training a smaller model from scratch on the same data typically produces worse results — small models lack the capacity to extract the same patterns directly from raw data.

Knowledge distillation bridges this gap. The insight, first formalized by Hinton, Vinyals, and Dean in 2015, is that a large model’s output contains more useful information than the original training labels.

Hard Labels vs Soft Targets

Consider an image classifier trained to recognize animals. The training data says a particular image is a “cat” — that’s a hard label: [1, 0, 0, 0] for [cat, dog, bird, fish].

But the teacher model produces soft probabilities: [0.85, 0.10, 0.04, 0.01]. This tells us the image is mostly cat, but somewhat dog-like, barely bird-like, and definitely not fish. This “dark knowledge” encodes relationships between classes that hard labels don’t capture.

The student learning from soft targets gets two signals:

  1. What the correct answer is
  2. What the incorrect answers’ relative likelihoods are

That second signal is valuable because it teaches the student about the structure of the problem, not just individual answers.
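To make the second signal concrete, here is a small sketch comparing two hypothetical student predictions. Both give “cat” the same 0.85, but they rank the wrong classes differently. The student probabilities are invented for illustration; the teacher’s soft targets are the example values above. Cross-entropy against the hard label cannot tell the two students apart, while cross-entropy against the soft targets can.

```python
import math

hard_label = [1.0, 0.0, 0.0, 0.0]        # one-hot "cat" for [cat, dog, bird, fish]
teacher_soft = [0.85, 0.10, 0.04, 0.01]  # the teacher's soft targets from the example

# Two hypothetical students: equally confident in "cat", but student_b
# thinks the image looks more like a fish than a dog.
student_a = [0.85, 0.10, 0.04, 0.01]
student_b = [0.85, 0.01, 0.04, 0.10]

def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_i target_i * log(predicted_i)."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

for name, pred in [("student_a", student_a), ("student_b", student_b)]:
    print(name,
          "hard-label CE:", round(cross_entropy(hard_label, pred), 4),
          "soft-target CE:", round(cross_entropy(teacher_soft, pred), 4))
```

Both students get the same hard-label loss, because only the “cat” entry contributes. Under the soft targets, student_b is penalized more because it inverts the teacher’s dog-versus-fish ranking.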

Temperature Scaling

A well-trained teacher’s predictions are often very peaked: after the softmax, 99.9% cat and 0.1% spread across everything else. That makes the soft targets almost as uninformative as hard labels.

Temperature fixes this. Before computing probabilities, divide the logits by a temperature value T:

softmax(logits / T)
  • T = 1: Standard softmax (peaked, confident)
  • T = 3-5: Softer distribution that reveals more about relationships between classes
  • T = 10+: Very smooth and close to uniform; the differences between classes become small, so the signal tends to wash out

Higher temperature “softens” the teacher’s predictions, making the dark knowledge more visible. T = 3 to 5 is a common starting range, though the best value depends on the task and is usually tuned.

Both teacher and student use the same temperature during distillation. At inference time, the student uses T = 1 (standard predictions).
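Here is a minimal pure-Python sketch of temperature scaling. The teacher logits are hypothetical values for [cat, dog, bird, fish]. The same function serves both cases: T > 1 during distillation and T = 1 at inference.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T), computed stably by subtracting the max."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Hypothetical teacher logits for [cat, dog, bird, fish].
teacher_logits = [9.0, 3.0, 1.0, -2.0]

for T in (1, 4, 20):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>2}", [round(p, 3) for p in probs], "entropy:", round(entropy(probs), 3))
```

Raising T never changes which class ranks first; it only redistributes probability mass toward the lower-ranked classes. That is why the student can use T = 1 at inference without changing its predicted class.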

The Distillation Loss

The student optimizes a combined loss:

Loss = α × T² × KL(soft_teacher ‖ soft_student) + (1 − α) × CE(student, true_label)
  • KL divergence from the teacher’s temperature-softened distribution to the student’s, both computed at the same T
  • Cross-entropy between the student’s ordinary (T = 1) predictions and the true labels
  • α controls the balance (typically 0.5 to 0.9, putting more weight on the teacher signal)

The KL term teaches the student to mimic the teacher. The CE term keeps it grounded in actual correctness. The KL term is multiplied by T² because the gradients of the softened term shrink roughly as 1/T². Without this correction, raising the temperature would weaken the teacher signal relative to the CE term.
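The combined loss can be sketched for a single example in plain Python, with no framework. The logits, T = 4 and α = 0.7 are hypothetical values chosen within the ranges above. In PyTorch this is commonly written with `torch.nn.functional.kl_div`, which expects log-probabilities as its input. The sketch also computes the gradient of the soft term, (q − p) / T, to show why the T² factor is needed.

```python
import math

def softmax(logits, T=1.0):
    """Numerically stable softmax(logits / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), with p as the target distribution."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, true_index, T=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student at T=1, label)."""
    soft_teacher = softmax(teacher_logits, T)
    soft_student = softmax(student_logits, T)
    kl_term = kl_divergence(soft_teacher, soft_student)
    ce_term = -math.log(softmax(student_logits)[true_index])  # ordinary T = 1 predictions
    return alpha * T ** 2 * kl_term + (1 - alpha) * ce_term

def soft_term_grad(student_logits, teacher_logits, T):
    """Gradient of KL(teacher_T || student_T) w.r.t. the student logits: (q_i - p_i) / T."""
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

# Hypothetical logits for [cat, dog, bird, fish]; the true class is "cat" (index 0).
teacher = [9.0, 3.0, 1.0, -2.0]
student = [4.0, 2.0, 1.5, 0.0]
print("loss:", round(distillation_loss(student, teacher, true_index=0), 4))

# Without the T^2 factor, the soft-term gradient shrinks roughly like 1/T^2.
for T in (2, 5, 10, 20):
    g = soft_term_grad(student, teacher, T)
    norm = math.sqrt(sum(x * x for x in g))
    print(f"T={T:>2}  raw grad norm={norm:.5f}  times T^2={norm * T * T:.4f}")
```

At high temperatures the raw gradient norm falls by about 4× each time T doubles, while the T²-scaled norm stays roughly constant. The scaled norm is what keeps α meaningful as T changes.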

Distillation Variants

Response-Based Distillation

The student learns from the teacher’s final output predictions. This is the classic approach described above.
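As an end-to-end illustration of response-based distillation, the sketch below trains a tiny linear student with plain SGD on the combined loss. The setup is entirely hypothetical: the teacher is a fixed 3-class linear scorer (`TEACHER_W`), the inputs are random 2-D points plus a bias feature, and the teacher’s own argmax stands in for the true labels. The gradient with respect to the student logits follows from the loss: α·T·(q_T − p_T) + (1 − α)·(q − one_hot).

```python
import math
import random

def softmax(logits, T=1.0):
    """Numerically stable softmax(logits / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def linear_logits(W, x):
    """Logits of a linear model with one weight row per class."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def predict(W, x):
    """Index of the highest-scoring class."""
    z = linear_logits(W, x)
    return max(range(len(z)), key=z.__getitem__)

def total_loss(W, data, T, alpha):
    """Mean of alpha*T^2*KL(teacher_T || student_T) + (1-alpha)*CE(student at T=1, label)."""
    loss = 0.0
    for x, p_t, y in data:
        z = linear_logits(W, x)
        q_T = softmax(z, T)
        kl = sum(p * math.log(p / q) for p, q in zip(p_t, q_T) if p > 0)
        ce = -math.log(softmax(z)[y])
        loss += alpha * T ** 2 * kl + (1 - alpha) * ce
    return loss / len(data)

def train_student(data, num_classes, dim, T=4.0, alpha=0.7, lr=0.1, epochs=200):
    """Plain SGD on the combined loss, starting from all-zero weights."""
    W = [[0.0] * dim for _ in range(num_classes)]
    for _ in range(epochs):
        for x, p_t, y in data:
            z = linear_logits(W, x)
            q_T, q_1 = softmax(z, T), softmax(z)
            # d(loss)/dz_k = alpha*T*(q_T[k] - p_t[k]) + (1-alpha)*(q_1[k] - onehot[k])
            grad = [alpha * T * (q_T[k] - p_t[k])
                    + (1 - alpha) * (q_1[k] - (1.0 if k == y else 0.0))
                    for k in range(num_classes)]
            for k in range(num_classes):
                for j in range(dim):
                    W[k][j] -= lr * grad[k] * x[j]
    return W

# Hypothetical teacher: a fixed 3-class linear scorer over (x1, x2, bias).
TEACHER_W = [[2.0, 0.0, 0.0], [-1.0, 1.7, 0.0], [-1.0, -1.7, 0.0]]
T, ALPHA = 4.0, 0.7

random.seed(0)
data = []
for _ in range(60):
    x = [random.uniform(-1, 1), random.uniform(-1, 1), 1.0]
    soft_target = softmax(linear_logits(TEACHER_W, x), T)
    label = predict(TEACHER_W, x)  # stand-in for the true label
    data.append((x, soft_target, label))

W0 = [[0.0] * 3 for _ in range(3)]
W = train_student(data, num_classes=3, dim=3, T=T, alpha=ALPHA)
accuracy = sum(predict(W, x) == y for x, _, y in data) / len(data)
print("loss before:", round(total_loss(W0, data, T, ALPHA), 4))
print("loss after: ", round(total_loss(W, data, T, ALPHA), 4))
print("train accuracy:", accuracy)
```

A real setup would replace the linear models with neural networks and the hand-derived gradient with autograd. The structure stays the same: soft targets from the teacher, one combined loss, ordinary gradient descent on the student.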

Feature-Based Distillation

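The student learns to match the teacher’s intermediate representations: hidden layer activations, attention maps, or feature embeddings. This can give the student a richer signal than the outputs alone. It does require the two networks’ representations to line up, either through matching layer shapes or through an extra learned projection from the student’s feature space into the teacher’s.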
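Here is a minimal sketch of a feature-matching loss in the spirit of FitNets-style hint training. A projection matrix `W` maps each student feature vector into the teacher’s feature space, and the loss is the mean squared error between them. The feature values and both projection matrices are hypothetical. In practice `W` is learned jointly with the student rather than fixed.

```python
def linear_projection(W, h):
    """Map a student feature vector h into the teacher's feature space with matrix W."""
    return [sum(w * hj for w, hj in zip(row, h)) for row in W]

def feature_matching_loss(student_feats, teacher_feats, W):
    """Mean squared error between projected student features and teacher features."""
    total, count = 0.0, 0
    for h_s, h_t in zip(student_feats, teacher_feats):
        projected = linear_projection(W, h_s)
        total += sum((a - b) ** 2 for a, b in zip(projected, h_t))
        count += len(h_t)
    return total / count

# Hypothetical batch: 2-dim student features, 3-dim teacher features.
student_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
teacher_feats = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0], [1.0, 3.0, 3.0]]

# Hypothetical 3x2 projections; training would learn W instead of fixing it.
W_aligned = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]
W_random = [[0.5, -0.5], [1.0, 1.0], [0.0, 0.0]]

print("aligned:", feature_matching_loss(student_feats, teacher_feats, W_aligned))
print("random: ", feature_matching_loss(student_feats, teacher_feats, W_random))
```

The projection is what lets a 2-dimensional student layer be compared with a 3-dimensional teacher layer. This is one way to handle the architectural-alignment requirement.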

Relation-Based Distillation

The student learns the relationships between examples as encoded by the teacher. If the teacher sees samples A and B as similar, the student should too.
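One way to express “if the teacher sees A and B as similar, the student should too” is to compare pairwise distances within a batch. The sketch below normalizes each model’s pairwise Euclidean distances by their mean, then takes the squared error between the two sets. This is similar in spirit to the distance-wise term of Relational Knowledge Distillation, but simplified to plain squared error. The embeddings are hypothetical.

```python
import math
from itertools import combinations

def normalized_pairwise_distances(embeddings):
    """Euclidean distance for every pair (i < j), divided by the mean distance."""
    dists = [math.dist(a, b) for a, b in combinations(embeddings, 2)]
    mean = sum(dists) / len(dists)
    return [d / mean for d in dists]

def relational_distance_loss(student_emb, teacher_emb):
    """Mean squared error between the two models' normalized distance structures."""
    s = normalized_pairwise_distances(student_emb)
    t = normalized_pairwise_distances(teacher_emb)
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)

# Hypothetical embeddings for the same 4 samples.
teacher_emb = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [3.0, 3.0, 1.0]]
student_scaled = [[2.0 * v for v in e] for e in teacher_emb]   # same geometry, twice the scale
student_other = [[0.0, 0.0], [3.0, 0.0], [0.0, 0.5], [1.0, 1.0]]  # 2-dim, different geometry

print("scaled copy:", relational_distance_loss(student_scaled, teacher_emb))
print("different:  ", relational_distance_loss(student_other, teacher_emb))
```

Only distances are compared, so the student’s embedding size may differ from the teacher’s. Normalizing by the mean distance makes the loss ignore overall scale.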

When Distillation Works Best

Scenario | Compression | Accuracy retention (vs. teacher)
Large → Medium (same architecture) | 2-5× | 98-99%
Large → Small (different architecture) | 10-50× | 93-97%
Ensemble → Single model | 3-10× | 96-99%
Large language model → Small LM | 10-100× | 85-95%

Distillation works best when:

  • The teacher has substantially more capacity than the task strictly needs
  • The student architecture is well-suited to the task
  • Sufficient unlabeled data is available (teacher can label it)
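The last point can be sketched as building a “transfer set”: run the teacher over unlabeled inputs and store its temperature-softened predictions as targets. Here `toy_teacher` is a hypothetical stand-in for a trained model, meaning any function from an input to a list of logits.

```python
import math

def softmax(logits, T=1.0):
    """Numerically stable softmax(logits / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def build_transfer_set(teacher_logits_fn, unlabeled_inputs, T=4.0):
    """Pair each unlabeled input with the teacher's temperature-softened prediction."""
    return [(x, softmax(teacher_logits_fn(x), T)) for x in unlabeled_inputs]

def toy_teacher(x):
    """Hypothetical teacher: maps a scalar input to three class logits."""
    return [2.0 * x, -x, 0.5]

transfer_set = build_transfer_set(toy_teacher, [-1.0, 0.0, 0.5, 2.0], T=4.0)
for x, soft_target in transfer_set:
    print(x, [round(p, 3) for p in soft_target])
```

These examples have no true label, so one simple option is to train on them with only the KL term. This is equivalent to setting α = 1 for the unlabeled portion of the data.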

Common Misconception

“The student can never be as good as the teacher.” While the student typically doesn’t fully match the teacher, it sometimes matches or exceeds the teacher on specific subsets or metrics. A common explanation is that distillation acts as a strong regularizer: the student receives a smoother, less noisy signal than the raw labels provide. In some cases, a distilled student outperforms a model of the same size trained from scratch by a large margin.

The one thing to remember: Knowledge distillation transfers a teacher’s expertise through soft probability targets at elevated temperature, where the “dark knowledge” in the teacher’s non-top predictions teaches the student about class relationships — enabling dramatic model compression with minimal accuracy loss.
