Knowledge Distillation in Python — Deep Dive
Basic Response-Based Distillation in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Response-based distillation loss (Hinton et al., 2015).

    Blends a temperature-softened KL term between student and teacher
    distributions with ordinary cross-entropy on the hard labels.

    Args:
        temperature: Softening temperature ``T`` applied to both logit sets.
        alpha: Weight of the soft-target (KL) term; the hard-label
            cross-entropy gets ``1 - alpha``.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """Return ``alpha * T^2 * KL + (1 - alpha) * CE`` for one batch.

        Class scores are read along dim 1 of both logit tensors; ``labels``
        are the class-index targets passed to ``F.cross_entropy``.
        """
        T = self.temperature
        log_p_student = F.log_softmax(student_logits / T, dim=1)
        p_teacher = F.softmax(teacher_logits / T, dim=1)

        # The T^2 factor keeps soft-target gradient magnitudes comparable
        # across temperatures.
        soft_term = T ** 2 * F.kl_div(log_p_student, p_teacher, reduction="batchmean")
        hard_term = F.cross_entropy(student_logits, labels)

        weight = self.alpha
        return weight * soft_term + (1 - weight) * hard_term
Complete Training Loop
def distill(teacher, student, train_loader, val_loader,
            epochs=50, temperature=4.0, alpha=0.7, lr=0.001,
            checkpoint_path="best_student.pth"):
    """Train ``student`` to mimic a frozen ``teacher`` with response distillation.

    Each epoch does one pass over ``train_loader`` minimizing
    ``DistillationLoss(temperature, alpha)`` with Adam, then steps a per-epoch
    cosine learning-rate schedule, and scores the student with
    ``evaluate(student, val_loader)``. Whenever validation accuracy improves,
    it saves the student's ``state_dict`` to ``checkpoint_path``.

    Args:
        teacher: Model providing soft targets; only called under
            ``torch.no_grad()`` and kept in eval mode.
        student: Model being trained. Batches are moved to the device of its
            first parameter (the teacher is assumed to live there too).
        train_loader: Iterable of ``(inputs, labels)`` batches.
        val_loader: Validation data passed to ``evaluate``.
        epochs: Number of passes over ``train_loader``; also the cosine period.
        temperature: Softening temperature for the KL term.
        alpha: Weight of the KL term; ``1 - alpha`` weights cross-entropy.
        lr: Adam learning rate.
        checkpoint_path: File the best-so-far student weights are saved to.

    Returns:
        ``student`` as it stands after the final epoch. The best-validation
        weights are only on disk at ``checkpoint_path``; they are not reloaded.
    """
    teacher.eval()  # frozen teacher: e.g. dropout off, BatchNorm uses running stats
    device = next(student.parameters()).device
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = DistillationLoss(temperature, alpha)
    best_accuracy = 0

    for epoch in range(epochs):
        # evaluate() is defined elsewhere and may switch the student to eval
        # mode, so re-enter training mode at the start of every epoch.
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Teacher predictions need no gradient.
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)

            loss = criterion(student_logits, teacher_logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        train_acc = correct / total

        val_acc = evaluate(student, val_loader)
        print(f"Epoch {epoch+1}/{epochs} | "
              f"Loss: {total_loss/len(train_loader):.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f}")

        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(student.state_dict(), checkpoint_path)

    print(f"Best validation accuracy: {best_accuracy:.4f}")
    return student
Temperature Selection
Temperature is usually the most sensitive distillation hyperparameter, so it is worth tuning directly. A simple grid search over candidate values:
def evaluate_temperatures(teacher, student_factory, train_loader, val_loader,
                          temperatures=(1, 2, 3, 4, 5, 7, 10, 20),
                          epochs=30, alpha=0.7):
    """Grid-search the distillation temperature.

    For each value in ``temperatures``, a fresh student from
    ``student_factory()`` is moved to CUDA, distilled with ``distill``, and
    scored with ``evaluate(student, val_loader)``. The score is taken on the
    student object as ``distill`` leaves it (its final-epoch weights).

    Note: with ``distill``'s default checkpoint path ("best_student.pth"),
    every run overwrites the same checkpoint file.

    Args:
        teacher: Frozen teacher model.
        student_factory: Zero-argument callable returning a new student.
        train_loader: Training batches passed to ``distill``.
        val_loader: Validation data for ``distill`` and ``evaluate``.
        temperatures: Candidate temperatures (any iterable; a tuple by
            default to avoid a shared mutable default).
        epochs: Training epochs per candidate.
        alpha: Soft-target weight passed to ``distill``.

    Returns:
        Dict mapping each temperature to its validation accuracy; empty if
        ``temperatures`` is empty.
    """
    results = {}
    for temp in temperatures:
        student = student_factory().cuda()
        distill(teacher, student, train_loader, val_loader,
                epochs=epochs, temperature=temp, alpha=alpha)
        accuracy = evaluate(student, val_loader)
        results[temp] = accuracy
        print(f"T={temp}: {accuracy:.4f}")

    # max() on an empty dict raises ValueError; nothing to report in that case.
    if not results:
        return results

    best_temp = max(results, key=results.get)
    print(f"\nBest temperature: {best_temp} ({results[best_temp]:.4f})")
    return results
Rules of thumb (starting points for the search above, not guarantees):
- Simple tasks (few classes, clear boundaries): T = 2–3
- Complex tasks (many classes, subtle differences): T = 4–8
- Very large teacher with overconfident outputs: T = 10–20
- Hierarchical classes (dog breeds, plant species): a higher T tends to work better, because softening exposes how the teacher ranks closely related classes
Feature-Based Distillation
Match intermediate representations, not just outputs:
class FeatureDistillationLoss(nn.Module):
    """Match a student's intermediate feature maps to a teacher's.

    For each stage, a learnable 1x1 convolution projects the student map to
    the teacher's channel count. Both maps are flattened over their spatial
    dims, each channel's spatial vector is L2-normalized, and the MSE between
    them is averaged over stages.

    Args:
        teacher_channels: Channel count of each teacher feature map.
        student_channels: Channel count of the matching student feature map,
            in the same stage order.

    Raises:
        ValueError: If the channel lists differ in length or are empty.
    """

    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        teacher_channels = list(teacher_channels)
        student_channels = list(student_channels)
        # zip() would silently drop unmatched stages while forward() still
        # divides by the number of alignment layers; fail loudly instead.
        if len(teacher_channels) != len(student_channels):
            raise ValueError(
                f"got {len(teacher_channels)} teacher stages but "
                f"{len(student_channels)} student stages"
            )
        if not teacher_channels:
            raise ValueError("at least one feature stage is required")
        # Alignment layers: map student features to the teacher's channel count.
        self.align = nn.ModuleList([
            nn.Conv2d(s_ch, t_ch, kernel_size=1)
            for s_ch, t_ch in zip(student_channels, teacher_channels)
        ])

    def forward(self, student_features, teacher_features):
        """Return the stage-averaged normalized MSE as a scalar tensor.

        Feature maps are ``(batch, channels, H, W)`` (required by the 1x1
        ``Conv2d``). If a projected student map's spatial size differs from
        the teacher's, it is resized bilinearly to match.
        """
        loss = 0
        for align, s_feat, t_feat in zip(
            self.align, student_features, teacher_features
        ):
            aligned = align(s_feat)
            if aligned.shape[-2:] != t_feat.shape[-2:]:
                aligned = F.interpolate(
                    aligned, size=t_feat.shape[-2:],
                    mode="bilinear", align_corners=False,
                )
            # The teacher is a fixed target; keep gradients out of it.
            target = t_feat.detach()
            s_norm = F.normalize(aligned.flatten(2), dim=2)
            t_norm = F.normalize(target.flatten(2), dim=2)
            loss = loss + F.mse_loss(s_norm, t_norm)
        return loss / len(self.align)
# Extract intermediate features during forward pass
class FeatureExtractor(nn.Module):
    """Wrap a model and capture outputs of named submodules via forward hooks.

    Calling the extractor returns ``(model_output, features)``, where
    ``features`` lists the captured output of each name in ``layer_names``
    (in that order) from the same forward pass.

    Args:
        model: The network to wrap.
        layer_names: Submodule names as reported by ``model.named_modules()``.

    Raises:
        KeyError: If a name is not a submodule of ``model``. All names are
            validated before any hook is registered.
    """

    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        # Materialize so a one-shot iterable is not exhausted here.
        self.layer_names = list(layer_names)
        self.features = {}
        self._hook_handles = []

        modules = dict(model.named_modules())  # build the lookup once
        missing = [name for name in self.layer_names if name not in modules]
        if missing:
            raise KeyError(f"not submodules of the wrapped model: {missing}")
        for name in self.layer_names:
            handle = modules[name].register_forward_hook(self._hook(name))
            self._hook_handles.append(handle)

    def _hook(self, name):
        """Build a forward hook that stores the module's output under ``name``."""
        def fn(module, inputs, output):
            self.features[name] = output
        return fn

    def forward(self, x):
        # Drop captures from the previous call so a layer that did not run
        # this time cannot leak a stale tensor into the result.
        self.features.clear()
        output = self.model(x)
        return output, [self.features[n] for n in self.layer_names]

    def remove_hooks(self):
        """Remove every forward hook this extractor registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
Attention Transfer
Distill attention maps from teacher to student (Zagoruyko & Komodakis, 2017):
def attention_map(feature_map, p=2):
    """Compute a normalized spatial attention map from a feature map.

    Following Zagoruyko & Komodakis (2017), activations are raised to
    ``|A|^p`` and aggregated across the channel dimension (dim 1). The spatial
    dims are then flattened and each map is L2-normalized. A channel mean is
    used rather than a sum; the two differ only by the constant factor C,
    which the normalization removes.

    Args:
        feature_map: Tensor with channels on dim 1, e.g. ``(batch, C, H, W)``.
        p: Exponent applied to absolute activations (the paper's default is 2).

    Returns:
        Tensor of shape ``(batch, 1, H*W)`` with unit L2 norm along the last dim.
    """
    # |A|^p; for p=2 this equals A^2 exactly.
    energy = feature_map.abs().pow(p).mean(dim=1, keepdim=True)
    return F.normalize(energy.flatten(2), dim=2)
class AttentionTransferLoss(nn.Module):
    """Attention-transfer loss between paired student/teacher feature maps.

    Each (student, teacher) pair is reduced with ``attention_map``. The mean
    squared difference of the two maps is taken per pair, summed over all
    pairs, and multiplied by ``beta``.

    Args:
        beta: Scale applied to the summed per-pair losses.
    """

    def __init__(self, beta=1000):
        super().__init__()
        self.beta = beta

    def forward(self, student_features, teacher_features):
        per_pair = (
            (attention_map(s_map) - attention_map(t_map)).pow(2).mean()
            for s_map, t_map in zip(student_features, teacher_features)
        )
        return self.beta * sum(per_pair)
Self-Distillation
A model can distill knowledge from itself — training a new version using its own predictions:
def self_distill(model_factory, train_loader, val_loader,
                 generations=3, temperature=3.0, epochs=100, alpha=0.7):
    """Iteratively distill a model into fresh copies of its own architecture.

    Generation 0 is a ``model_factory()`` instance trained with
    ``train_standard``. Each later generation is a new ``model_factory()``
    instance distilled from the previous one with ``distill``, and then
    becomes the teacher for the next generation.

    Args:
        model_factory: Zero-argument callable returning a new model.
        train_loader: Training data for every generation.
        val_loader: Validation data used for training and reporting.
        generations: Number of distillation rounds after the baseline.
        temperature: Distillation temperature for every round.
        epochs: Training epochs for the baseline and for each round.
        alpha: Soft-target weight passed to ``distill``.

    Returns:
        The last generation's model (the baseline when ``generations`` is 0).
    """
    teacher = model_factory().cuda()
    train_standard(teacher, train_loader, val_loader, epochs=epochs)
    teacher_acc = evaluate(teacher, val_loader)
    print(f"Generation 0 (baseline): {teacher_acc:.4f}")

    for gen in range(1, generations + 1):
        student = model_factory().cuda()  # same architecture as the teacher
        distill(teacher, student, train_loader, val_loader,
                epochs=epochs, temperature=temperature, alpha=alpha)
        student_acc = evaluate(student, val_loader)
        print(f"Generation {gen}: {student_acc:.4f}")
        teacher = student  # this student teaches the next generation

    return teacher
Self-distillation has been reported to improve accuracy by roughly 0.5–2% even with an identical architecture (e.g. Born-Again Networks, Furlanello et al., 2018): the teacher's soft targets act as a form of label smoothing and regularization.
Distillation for Object Detection
class DetectionDistillationLoss(nn.Module):
    """Distillation for object detection (classification + regression heads).

    The total loss is the standard detection loss plus:
    - ``cls_weight`` times a temperature-scaled KL term between the student's
      and teacher's class distributions, and
    - ``reg_weight`` times the MSE between student and teacher box outputs.

    Args:
        temperature: Softening temperature for the classification KL term.
        cls_weight: Weight of the classification distillation term.
        reg_weight: Weight of the regression distillation term.
    """

    def __init__(self, temperature=3.0, cls_weight=1.0, reg_weight=0.5):
        super().__init__()
        self.temperature = temperature
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight

    def forward(self, student_cls, student_reg,
                teacher_cls, teacher_reg, targets):
        """Return ``det_loss + cls_weight * cls_loss + reg_weight * reg_loss``.

        Class scores are on the last dim of ``*_cls``; any leading dims (e.g.
        ``(batch, anchors, classes)`` — assumed layout, confirm against the
        heads) are treated as independent predictions. ``student_reg`` and
        ``teacher_reg`` must have matching shapes.
        """
        T = self.temperature
        num_classes = student_cls.size(-1)

        # Flatten to (num_predictions, num_classes): "batchmean" divides by
        # dim 0 only, so on (batch, anchors, C) input the KL would be summed
        # over anchors and scale with their count. For 2-D input this is a no-op.
        s_rows = student_cls.reshape(-1, num_classes)
        t_rows = teacher_cls.detach().reshape(-1, num_classes)
        soft_s = F.log_softmax(s_rows / T, dim=-1)
        soft_t = F.softmax(t_rows / T, dim=-1)
        cls_loss = F.kl_div(soft_s, soft_t, reduction="batchmean") * T ** 2

        # Regression distillation (L2 between box predictions); the teacher
        # is a fixed target.
        reg_loss = F.mse_loss(student_reg, teacher_reg.detach())

        # Standard detection loss (helper defined elsewhere).
        det_loss = compute_detection_loss(student_cls, student_reg, targets)
        return det_loss + self.cls_weight * cls_loss + self.reg_weight * reg_loss
LLM Distillation Patterns
Distilling large language models into smaller ones:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
def distill_lm_batch(teacher, student, tokenizer, texts,
                     temperature=2.0, max_length=512, alpha=0.7):
    """Compute a distillation loss for a causal language model on one batch.

    The batch is tokenized with padding/truncation and moved to the student's
    device. The loss combines:
    - a token-level KL divergence between temperature-softened student and
      teacher distributions, averaged over non-padding positions and scaled
      by ``T^2``;
    - the student's next-token cross-entropy on the same texts.

    Teacher and student must produce logits of the same shape (same
    tokenizer/vocabulary), or the KL term fails.

    Args:
        teacher: Frozen teacher LM; called under ``torch.no_grad()``.
        student: Student LM being trained.
        tokenizer: Tokenizer shared by both models; must support padding.
        texts: List of input strings.
        temperature: Softening temperature for the KL term.
        max_length: Truncation length in tokens.
        alpha: Weight of the KL term; the LM loss gets ``1 - alpha``.

    Returns:
        Scalar loss tensor ``alpha * kl_loss + (1 - alpha) * lm_loss``.
    """
    device = next(student.parameters()).device
    encodings = tokenizer(
        texts, return_tensors="pt", padding=True,
        truncation=True, max_length=max_length
    ).to(device)
    attention_mask = encodings["attention_mask"]

    with torch.no_grad():
        teacher_logits = teacher(**encodings).logits
    student_logits = student(**encodings).logits

    # Token-level KL divergence; logits are (batch, seq_len, vocab_size).
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    mask = attention_mask.unsqueeze(-1)
    kl_loss = F.kl_div(soft_student, soft_teacher, reduction="none")
    # Sum over vocab, average over real (non-padding) tokens.
    kl_loss = (kl_loss * mask).sum() / mask.sum()
    kl_loss = kl_loss * temperature ** 2

    # Next-token LM loss. Padding is masked via attention_mask rather than
    # pad_token_id: many causal-LM tokenizers reuse EOS as the pad token, and
    # ignoring pad_token_id would also drop genuine EOS targets.
    shift_logits = student_logits[:, :-1, :].contiguous()
    shift_labels = encodings["input_ids"][:, 1:].masked_fill(
        attention_mask[:, 1:] == 0, -100
    )
    lm_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
    return alpha * kl_loss + (1 - alpha) * lm_loss
Practical Distillation Results
Real-world compression achieved through distillation:
| Teacher → Student | Task | Teacher Acc | Student Acc | Size Reduction (params) |
|---|---|---|---|---|
| ResNet-152 → ResNet-18 | ImageNet | 78.3% | 73.1% (vs 69.8% from scratch) | ~5.1× (60.2M → 11.7M) |
| BERT-Large → BERT-Tiny | GLUE SST-2 | 94.9% | 90.4% (vs 87.1% from scratch) | ~77× (340M → 4.4M) |
| EfficientNet-B7 → MobileNet V3 | ImageNet | 84.3% | 76.6% (vs 75.2% from scratch) | ~12× (66M → 5.4M) |
| GPT-2 → DistilGPT-2 | WikiText-103 perplexity (lower is better) | 16.3 | 21.1 | ~1.5× (124M → 82M) |
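The consistent pattern: distilled students outperform identically-sized models trained from scratch — by about 1.4–3.3 accuracy points in the classification rows above. The benefit is often reported to be larger when the student is much smaller than the teacher, though it varies by task and setup.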
Distillation + Pruning + Quantization Pipeline
# Illustrative pipeline only: train_large_model, iterative_prune, quantize_int8,
# small_student, train_data and val_data are placeholders not defined here.
# Step 1: Train teacher
teacher = train_large_model(train_data)  # 250M params, 85% acc
# Step 2: Distill to student (distill() also requires validation data)
student = distill(teacher, small_student, train_data, val_data)  # 25M params, 82% acc
# Step 3: Prune student
pruned = iterative_prune(student, train_data, sparsity=0.8)  # 5M effective params, 81% acc
# Step 4: Quantize
# NOTE: ~6 MB assumes the pruned (zeroed) weights are stored in a sparse or
# compressed format; a dense int8 copy of all 25M parameters is ~25 MB.
quantized = quantize_int8(pruned)  # ~6 MB on disk, 80.5% acc
# Result: 250M params → 6 MB deployable model with 80.5% accuracy
# vs teacher's 85% at ~1 GB (250M fp32 params × 4 bytes)
The one thing to remember: Production knowledge distillation in Python combines response-level soft target training (with carefully tuned temperature), feature-level intermediate matching, and integration with pruning and quantization — where the cumulative effect of these techniques can compress a model 100×+ while preserving 90-95% of the teacher’s capability.
See Also
- Python Hyperparameter Tuning — learn why adjusting the dials on a computer's learning recipe makes predictions way better.
- Python Model Compression Methods — all the ways Python developers shrink massive AI models to fit on phones and tiny devices, like packing for a trip with a carry-on bag.
- Python Model Pruning Techniques — why cutting away parts of an AI's brain can make it faster without making it dumber.
- Python Neural Architecture Search — how AI designs its own brain structure, like a robot architect building the perfect house by trying thousands of floor plans.
- Python PyTorch Quantization — how shrinking numbers inside an AI model makes it run faster on phones and cheaper servers without losing much accuracy.