Knowledge Distillation in Python — Deep Dive

Basic Response-Based Distillation in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Response-based (Hinton-style) knowledge-distillation objective.

    The loss is a convex blend of two terms:

    * a soft term: KL divergence between the teacher's and the student's
      class distributions, both softened by ``temperature`` and scaled by
      ``temperature ** 2``;
    * a hard term: ordinary cross-entropy of the student's raw logits
      against the ground-truth ``labels``.

    Args:
        temperature: Softening temperature ``T`` applied to both logit sets.
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """Return ``alpha * soft + (1 - alpha) * hard`` as a scalar tensor.

        Class probabilities are taken over ``dim=1`` of the logits.
        """
        T = self.temperature
        student_log_probs = F.log_softmax(student_logits / T, dim=1)
        teacher_probs = F.softmax(teacher_logits / T, dim=1)

        # "batchmean" divides the summed KL by the batch size. The T**2 factor
        # keeps the soft-term gradient scale comparable across temperatures.
        kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        soft_term = kl * (T ** 2)

        hard_term = F.cross_entropy(student_logits, labels)

        weight = self.alpha
        return weight * soft_term + (1 - weight) * hard_term

Complete Training Loop

def distill(teacher, student, train_loader, val_loader,
            epochs=50, temperature=4.0, alpha=0.7, lr=0.001,
            checkpoint_path="best_student.pth"):
    """Train ``student`` to mimic a frozen ``teacher`` with response-based KD.

    Each epoch makes one pass over ``train_loader``, minimising
    ``DistillationLoss(temperature, alpha)`` with Adam. The learning rate
    follows a cosine schedule that is stepped once per epoch. The student is
    then scored with ``evaluate(student, val_loader)``. Its ``state_dict`` is
    written to ``checkpoint_path`` whenever validation accuracy improves.

    Args:
        teacher: Trained model. It is put in eval mode and only run under
            ``torch.no_grad``. Assumed to be on the same device as
            ``student`` (TODO confirm at call sites).
        student: Model being trained. Batches are moved to the device of its
            first parameter.
        train_loader: Iterable of ``(inputs, labels)`` batches. ``len()`` of it
            is used to report the mean loss per batch.
        val_loader: Passed through to ``evaluate``.
        epochs: Number of passes over ``train_loader``; also the cosine period.
        temperature: Softening temperature forwarded to ``DistillationLoss``.
        alpha: Soft/hard mixing weight forwarded to ``DistillationLoss``.
        lr: Adam learning rate.
        checkpoint_path: File the best-so-far student weights are saved to.

    Returns:
        ``student`` holding the weights from the *last* epoch. The best-epoch
        weights exist only on disk at ``checkpoint_path``.
    """
    # Follow the student's device instead of hard-coding .cuda().
    device = next(student.parameters()).device
    teacher.eval()  # Teacher is frozen

    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = DistillationLoss(temperature, alpha)

    best_accuracy = 0

    for epoch in range(epochs):
        # evaluate() is defined elsewhere and may leave the student in eval
        # mode. Re-enter training mode every epoch so dropout/BatchNorm
        # behave as training layers, not only during the first epoch.
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Teacher predictions need no gradient.
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            student_logits = student(inputs)
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = student_logits.detach().max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        # Guard the averages against an empty loader instead of dividing by zero.
        train_acc = correct / total if total else 0.0
        mean_loss = total_loss / max(len(train_loader), 1)

        # Validate
        val_acc = evaluate(student, val_loader)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Loss: {mean_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f}")

        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(student.state_dict(), checkpoint_path)

    print(f"Best validation accuracy: {best_accuracy:.4f}")
    return student

Temperature Selection

Temperature is usually the most influential distillation hyperparameter, so it is worth tuning explicitly. A simple grid search over candidate values:

def evaluate_temperatures(teacher, student_factory, train_loader, val_loader,
                          temperatures=(1, 2, 3, 4, 5, 7, 10, 20),
                          epochs=30, alpha=0.7):
    """Grid search over distillation temperatures.

    For each candidate temperature:

    * build a fresh student with ``student_factory()``;
    * distill it from ``teacher`` with ``distill``;
    * score it with ``evaluate``.

    The temperature with the highest score is printed.

    Args:
        teacher: Trained teacher model. New students are placed on the device
            of its first parameter.
        student_factory: Zero-argument callable returning a new student model.
        train_loader: Training batches, forwarded to ``distill``.
        val_loader: Validation batches, forwarded to ``distill`` and ``evaluate``.
        temperatures: Candidate temperatures. An immutable tuple is used so
            the default is never shared-and-mutated between calls.
        epochs: Epochs per run, forwarded to ``distill``.
        alpha: Soft/hard mixing weight, forwarded to ``distill``.

    Returns:
        dict mapping each temperature to the validation accuracy of the
        student's final-epoch weights.

    Raises:
        ValueError: if ``temperatures`` is empty. This is raised before any
            training starts.
    """
    temperatures = list(temperatures)
    if not temperatures:
        raise ValueError("temperatures must contain at least one value")

    # Build students next to the teacher instead of hard-coding .cuda().
    device = next(teacher.parameters()).device
    results = {}

    for temp in temperatures:
        student = student_factory().to(device)
        # NOTE: distill() writes its best checkpoint to the same default path
        # on every run, so only the last run's file survives.
        distill(teacher, student, train_loader, val_loader,
                epochs=epochs, temperature=temp, alpha=alpha)
        # Scores the last-epoch weights held in memory. The best-epoch weights
        # are only stored on disk by distill().
        accuracy = evaluate(student, val_loader)
        results[temp] = accuracy
        print(f"T={temp}: {accuracy:.4f}")

    best_temp = max(results, key=results.get)
    print(f"\nBest temperature: {best_temp} ({results[best_temp]:.4f})")
    return results

Rules of thumb from published experiments (use them to choose the grid above, not as guarantees):

  • Simple tasks (few classes, clear boundaries): T = 2-3
  • Complex tasks (many classes, subtle differences): T = 4-8
  • Very large teacher with overconfident outputs: T = 10-20
  • When classes are hierarchical (dog breeds, plant species): Higher T works better

Feature-Based Distillation

Match intermediate representations, not just outputs:

class FeatureDistillationLoss(nn.Module):
    """Feature-level ("hint") distillation loss over several layer pairs.

    For every (student, teacher) feature-map pair, the student map is first
    projected to the teacher's channel count by a learnable 1x1 convolution.
    Both maps are then processed as follows:

    * flattened from dim 2 onward (presumably the spatial dims of
      (N, C, H, W) maps; TODO confirm at call sites);
    * L2-normalised along that flattened axis;
    * compared with mean squared error.

    The per-pair losses are averaged.

    Args:
        teacher_channels: Channel count of each teacher feature map.
        student_channels: Channel count of each matching student feature map.
    """

    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # One 1x1 projection per layer pair: student channels -> teacher channels.
        projections = []
        for in_ch, out_ch in zip(student_channels, teacher_channels):
            projections.append(nn.Conv2d(in_ch, out_ch, kernel_size=1))
        self.align = nn.ModuleList(projections)

    @staticmethod
    def _pair_loss(projected, teacher_map):
        """MSE between two feature maps after L2-normalising each channel's values."""
        student_unit = F.normalize(projected.flatten(2), dim=2)
        teacher_unit = F.normalize(teacher_map.flatten(2), dim=2)
        return F.mse_loss(student_unit, teacher_unit)

    def forward(self, student_features, teacher_features):
        """Return the mean hint loss over the aligned layer pairs.

        Inputs are zipped with the projections, so surplus entries are
        ignored. The total is always divided by the number of projection
        layers.
        """
        per_pair = (
            self._pair_loss(project(s_map), t_map)
            for project, s_map, t_map in zip(
                self.align, student_features, teacher_features
            )
        )
        return sum(per_pair) / len(self.align)


# Extract intermediate features during forward pass
class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected sub-modules.

    A forward hook is registered on every module named in ``layer_names``,
    using the names reported by ``model.named_modules()``. Calling the wrapper
    returns ``(model_output, [captured outputs in layer_names order])``.

    Args:
        model: The module to run.
        layer_names: Names of sub-modules whose outputs should be captured.

    Raises:
        KeyError: at construction if a name is not a sub-module of ``model``.
            Also raised from ``forward`` if a requested layer did not run
            during that call.
    """

    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        # Materialise once so a generator argument is not exhausted by __init__.
        self.layer_names = list(layer_names)
        self.features = {}
        # Keep the hook handles so the hooks can be detached later.
        self._handles = []

        # Build the name -> module map once instead of once per layer.
        modules = dict(model.named_modules())
        for name in self.layer_names:
            layer = modules[name]  # KeyError for an unknown layer name
            self._handles.append(layer.register_forward_hook(self._hook(name)))

    def _hook(self, name):
        """Return a forward hook that stores the module's output under ``name``."""
        def fn(module, input, output):
            self.features[name] = output
        return fn

    def forward(self, x):
        """Run the wrapped model; return its output and the captured features."""
        # Drop the previous call's captures. Otherwise a layer that did not run
        # this time would silently hand back a stale tensor, and old
        # activations would stay referenced.
        self.features.clear()
        output = self.model(x)
        missing = [n for n in self.layer_names if n not in self.features]
        if missing:
            raise KeyError(f"layers did not run during forward: {missing}")
        return output, [self.features[n] for n in self.layer_names]

    def remove_hooks(self):
        """Detach all registered forward hooks from the wrapped model."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

Attention Transfer

Distill attention maps from teacher to student (Zagoruyko & Komodakis, 2017):

def attention_map(feature_map, p=2):
    """Compute a spatial attention map from a feature map.

    Follows Zagoruyko & Komodakis (2017). The map is the channel-wise mean of
    ``|activation| ** p`` (dim 1 is reduced), flattened from dim 2 onward and
    L2-normalised along that flattened axis.

    Because the result is L2-normalised, taking the mean or the sum over
    channels yields the same map: the two differ only by the constant factor
    C, which the normalisation cancels.

    Args:
        feature_map: Tensor whose dim 1 holds channels and dims 2+ the
            positions to attend over. Presumably (N, C, H, W); TODO confirm.
        p: Exponent applied to the absolute activations. The default ``2``
            reproduces the original squared-activation map.

    Returns:
        Tensor of shape ``(N, 1, prod(feature_map.shape[2:]))`` with unit L2
        norm along the last dim (zero vectors stay zero).
    """
    energy = feature_map.abs().pow(p).mean(dim=1, keepdim=True)
    return F.normalize(energy.flatten(2), dim=2)

class AttentionTransferLoss(nn.Module):
    """Attention-transfer loss between matching student/teacher feature maps.

    For each (student, teacher) pair, both maps are reduced with
    ``attention_map``, and the mean squared difference of the two attention
    maps is taken. The per-pair values are summed and scaled by ``beta``.

    Args:
        beta: Multiplier on the summed loss (default 1000). The per-pair terms
            are means of small squared differences, so they are tiny unscaled.
    """

    def __init__(self, beta=1000):
        super().__init__()
        self.beta = beta

    def forward(self, student_features, teacher_features):
        """Return ``beta * sum_i mean((att(s_i) - att(t_i)) ** 2)``.

        Pairs are zipped, so surplus entries in the longer sequence are
        ignored. Empty inputs give ``beta * 0``.
        """
        squared_gaps = (
            (attention_map(s_map) - attention_map(t_map)).pow(2).mean()
            for s_map, t_map in zip(student_features, teacher_features)
        )
        return self.beta * sum(squared_gaps)

Self-Distillation

A model can distill knowledge from itself — training a new version using its own predictions:

def self_distill(model_factory, train_loader, val_loader,
                 generations=3, temperature=3.0):
    """Iterative self-distillation over several generations.

    A baseline model is trained normally (generation 0). Each later
    generation is a fresh model of the same architecture, distilled from the
    previous generation, which then becomes the next teacher. Every
    generation's validation accuracy is printed.

    Args:
        model_factory: Zero-argument callable returning a new, untrained model.
        train_loader: Training batches.
        val_loader: Validation batches.
        generations: Number of distilled generations after the baseline.
        temperature: Distillation temperature forwarded to ``distill``.

    Returns:
        The model from the final generation (the baseline if
        ``generations`` is 0).
    """
    current = model_factory().cuda()
    train_standard(current, train_loader, val_loader, epochs=100)
    baseline_acc = evaluate(current, val_loader)
    print(f"Generation 0 (baseline): {baseline_acc:.4f}")

    for generation in range(1, generations + 1):
        successor = model_factory().cuda()  # Same architecture as its teacher
        distill(current, successor, train_loader, val_loader,
                epochs=100, temperature=temperature)
        successor_acc = evaluate(successor, val_loader)
        print(f"Generation {generation}: {successor_acc:.4f}")

        # The freshly distilled model teaches the next generation.
        current = successor

    return current

Self-distillation often improves accuracy by 0.5-2% — the soft targets act as label smoothing and regularization, even with the same architecture.

Distillation for Object Detection

class DetectionDistillationLoss(nn.Module):
    """Distillation loss for detectors with classification and box-regression heads.

    Adds two distillation terms to the student's ordinary detection loss:

    * temperature-softened KL between teacher and student class scores
      (softmax over the last dim), scaled by ``temperature ** 2``;
    * MSE between the teacher's and the student's box predictions.

    Args:
        temperature: Softening temperature for the classification term.
        cls_weight: Weight of the classification distillation term.
        reg_weight: Weight of the regression distillation term.
    """

    def __init__(self, temperature=3.0, cls_weight=1.0, reg_weight=0.5):
        super().__init__()
        self.temperature = temperature
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight

    def forward(self, student_cls, student_reg,
                teacher_cls, teacher_reg, targets):
        """Return ``det + cls_weight * kd_cls + reg_weight * kd_reg`` (scalar tensor)."""
        T = self.temperature

        # NOTE(review): "batchmean" divides by dim 0 only. If the class scores
        # are (batch, anchors, classes), the KL is summed over anchors; confirm
        # the intended scaling.
        kd_cls = F.kl_div(
            F.log_softmax(student_cls / T, dim=-1),
            F.softmax(teacher_cls / T, dim=-1),
            reduction="batchmean",
        ) * T ** 2

        # L2 between box predictions (over all entries; no foreground masking).
        kd_reg = F.mse_loss(student_reg, teacher_reg)

        # The student's supervised detection loss, computed by an external helper.
        supervised = compute_detection_loss(student_cls, student_reg, targets)

        return supervised + self.cls_weight * kd_cls + self.reg_weight * kd_reg

LLM Distillation Patterns

Distilling large language models into smaller ones:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

def distill_lm_batch(teacher, student, tokenizer, texts,
                     temperature=2.0, max_length=512, alpha=0.7):
    """Compute the distillation loss for one batch of texts.

    The texts are tokenized with padding/truncation and run through both
    models: the teacher under ``torch.no_grad``, the student with gradients.
    The result combines two terms:

    * token-level KL(teacher || student) at ``temperature``, averaged over
      non-padding tokens and scaled by ``temperature ** 2``;
    * the student's next-token cross-entropy on the same batch.

    Args:
        teacher: Causal LM whose output exposes ``.logits``. It must share the
            student's tokenizer and device (assumed; TODO confirm).
        student: Causal LM being trained. The batch is moved to the device of
            its first parameter.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` with a ``.to(device)`` method (e.g. a Hugging
            Face tokenizer).
        texts: List of strings forming the batch.
        temperature: Softening temperature for the KL term.
        max_length: Truncation length in tokens.
        alpha: Weight of the KL term; the LM term gets ``1 - alpha``. The
            default 0.7 reproduces the original fixed mix.

    Returns:
        Scalar tensor ``alpha * kl + (1 - alpha) * lm`` (the student's graph is
        attached). If no valid next-token target exists in the batch, the LM
        term is NaN.
    """
    # Follow the student's device instead of hard-coding "cuda".
    device = next(student.parameters()).device
    encodings = tokenizer(
        texts, return_tensors="pt", padding=True,
        truncation=True, max_length=max_length,
    ).to(device)
    attention_mask = encodings["attention_mask"]

    with torch.no_grad():
        teacher_logits = teacher(**encodings).logits

    student_logits = student(**encodings).logits

    # Token-level KL divergence; logits are (batch, seq_len, vocab_size).
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    token_kl = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=-1)

    # Average over real (non-padding) tokens only.
    token_mask = attention_mask.to(token_kl.dtype)
    kl_loss = (token_kl * token_mask).sum() / token_mask.sum().clamp_min(1.0)
    kl_loss = kl_loss * temperature ** 2

    # Standard next-token LM loss. Padding is identified from the attention
    # mask rather than ``pad_token_id``. The old approach dropped genuine
    # tokens that share the pad id (e.g. pad_token == eos_token) and failed
    # when pad_token_id is None. A target is kept only if both the predicting
    # position and the target position are real tokens.
    shift_logits = student_logits[:, :-1, :]
    valid = attention_mask[:, :-1].bool() & attention_mask[:, 1:].bool()
    shift_labels = encodings["input_ids"][:, 1:].masked_fill(~valid, -100)
    lm_loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )

    return alpha * kl_loss + (1 - alpha) * lm_loss

Practical Distillation Results

Real-world compression achieved through distillation:

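| Teacher → Student | Task / metric | Teacher | Student (distilled vs. from scratch) | Size reduction (parameters) |
|---|---|---|---|---|
| ResNet-152 → ResNet-18 | ImageNet top-1 | 78.3% | 73.1% (vs 69.8%) | ~5× (60M → 11.7M) |
| BERT-Large → BERT-Tiny | GLUE SST-2 accuracy | 94.9% | 90.4% (vs 87.1%) | ~77× (340M → 4.4M) |
| EfficientNet-B7 → MobileNetV3-Large | ImageNet top-1 | 84.3% | 76.6% (vs 75.2%) | ~12× (66M → 5.4M) |
| GPT-2 → DistilGPT-2 | Perplexity (lower is better) | 20.5 | 36.7 (vs 45.3) | ~1.5× (124M → 82M) |

Size reduction is the ratio of parameter counts (DistilGPT-2 was distilled from the 124M-parameter GPT-2, not GPT-3). Accuracy and perplexity figures are indicative and depend heavily on the training recipe and evaluation setup.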

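The consistent pattern: distilled students outperform identically-sized models trained from scratch, by roughly 1.5–3.5 accuracy points in the classification examples above and with markedly lower perplexity for the language model. The gap is often reported to widen as the student gets smaller relative to the teacher.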

Distillation + Pruning + Quantization Pipeline

# Illustrative end-to-end pipeline. train_large_model, small_student,
# iterative_prune, quantize_int8, train_data and val_data are placeholders
# that are not defined in this article.

# Step 1: Train teacher
teacher = train_large_model(train_data)  # 250M params, 85% acc

# Step 2: Distill to student. distill() takes the training AND validation
# loaders positionally (it validates every epoch and checkpoints the best student).
student = distill(teacher, small_student, train_data, val_data)  # 25M params, 82% acc

# Step 3: Prune student
pruned = iterative_prune(student, train_data, sparsity=0.8)  # 5M effective params, 81% acc

# Step 4: Quantize
quantized = quantize_int8(pruned)  # ~6 MB on disk, 80.5% acc

# Result: 250M params → 6 MB deployable model with 80.5% accuracy
# vs teacher's 85% at ~1 GB

The one thing to remember: Production knowledge distillation in Python combines response-level soft target training (with carefully tuned temperature), feature-level intermediate matching, and integration with pruning and quantization — where the cumulative effect of these techniques can compress a model 100×+ while preserving 90-95% of the teacher’s capability.

Tags: python · machine-learning · model-optimization

See Also