Model Compression Methods in Python — Deep Dive
The Full Compression Pipeline
This deep dive implements a complete compression pipeline: starting from a large trained model and producing a deployment-ready artifact. We’ll measure size and accuracy at each stage.
Stage 0: Baseline Model
import torch
import torch.nn as nn
import torchvision.models as models
import os
# Load an ImageNet-pretrained ResNet-50; it is the uncompressed baseline and
# later serves as the distillation teacher.
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Switch to inference mode (affects layers such as BatchNorm and Dropout).
teacher.eval()
def model_size_mb(model):
    """Return the serialized size of ``model``'s state dict in MiB.

    The state dict is written with ``torch.save`` into an in-memory buffer
    rather than a fixed path under ``/tmp``. That avoids clashes between
    concurrent callers and cannot leave a stray file behind on failure.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        Size of the serialized state dict in mebibytes, as a float.
    """
    import io  # local import keeps this block self-contained

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / (1024 * 1024)
def count_parameters(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())
def count_nonzero(model):
    """Count non-zero and total parameter elements of ``model``.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        A ``(nonzero, total)`` tuple of ints. Both are 0 when the model
        exposes no parameters.
    """
    total = 0
    nonzero = 0
    for p in model.parameters():
        total += p.numel()
        nonzero += torch.count_nonzero(p).item()
    return nonzero, total
# Report the baseline: parameter count in millions and serialized size.
print(f"Teacher: {count_parameters(teacher)/1e6:.1f}M params, {model_size_mb(teacher):.1f} MB")
# Example output:
# Teacher: 25.6M params, 97.8 MB
Stage 1: Knowledge Distillation
Distill ResNet-50 into MobileNet V3 Small:
import torch.nn.functional as F
# Student: MobileNet V3 Small (2.5M params vs 25.6M)
# No `weights` argument is passed, so the student is not loaded from a
# pretrained checkpoint; it is meant to be trained via distillation.
student = models.mobilenet_v3_small(num_classes=1000)
class DistillationTrainer:
    """Train a student network to mimic a frozen teacher (knowledge distillation).

    The loss is ``alpha * soft_loss + (1 - alpha) * hard_loss``. ``soft_loss``
    is the KL divergence between temperature-softened student and teacher
    distributions, scaled by ``temperature ** 2``. ``hard_loss`` is the usual
    cross-entropy against the ground-truth labels.

    Args:
        teacher: Trained model providing soft targets; put in eval mode.
        student: Model to be trained.
        temperature: Softmax temperature applied to both sets of logits.
        alpha: Weight of the soft (distillation) loss term.
        device: Device both models and every batch are moved to.
        lr: Adam learning rate.
        epochs: ``T_max`` of the cosine schedule, i.e. the number of
            ``train_epoch`` calls over which the learning rate is annealed.
    """

    def __init__(self, teacher, student, temperature=4.0, alpha=0.7,
                 device="cuda", lr=1e-3, epochs=100):
        self.device = torch.device(device)
        self.teacher = teacher.to(self.device).eval()
        self.student = student.to(self.device)
        self.temperature = temperature
        self.alpha = alpha
        self.optimizer = torch.optim.Adam(self.student.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=epochs
        )

    def _distillation_loss(self, student_logits, teacher_logits, labels):
        """Combine the soft (teacher-matching) and hard (label) losses."""
        t = self.temperature
        # The T^2 factor keeps the soft-loss gradient magnitude comparable
        # across temperatures.
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / t, dim=1),
            F.softmax(teacher_logits / t, dim=1),
            reduction="batchmean",
        ) * (t ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_epoch(self, dataloader):
        """Run one pass over ``dataloader`` and step the LR schedule once.

        Args:
            dataloader: Iterable yielding ``(inputs, labels)`` batches.

        Returns:
            Mean loss per batch as a float.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.student.train()
        total_loss = 0.0
        num_batches = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # The teacher only supplies targets; no gradients are needed.
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)
            student_logits = self.student(inputs)

            loss = self._distillation_loss(student_logits, teacher_logits, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Counting batches (instead of len(dataloader)) also supports loaders
        # without __len__ and gives a clear error for an empty one.
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")

        self.scheduler.step()
        return total_loss / num_batches
# After distillation:
# NOTE(review): this listing never constructs DistillationTrainer or calls
# train_epoch, so the figures below assume that training has been run.
print(f"Student: {count_parameters(student)/1e6:.1f}M params, {model_size_mb(student):.1f} MB")
# Example output:
# Student: 2.5M params, 9.7 MB
# Compression: 10.1×, accuracy: ~67% (vs teacher's 80%)
Stage 2: Structured Pruning
import torch.nn.utils.prune as prune
def apply_structured_pruning(model, amount=0.4):
    """Apply L2-norm structured pruning to every ``nn.Conv2d`` in ``model``.

    For each convolution, ``amount`` of the output channels (``dim=0``) with
    the smallest L2 norm are masked to zero via
    ``torch.nn.utils.prune.ln_structured``. The pruning is applied as a
    re-parametrization (``weight_orig`` + ``weight_mask``) until
    ``prune.remove`` is called on the module.

    Args:
        model: Module to prune; modified in place.
        amount: Fraction (float) or absolute number (int) of channels to
            prune per layer, as accepted by ``ln_structured``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
    return model
def get_sparsity(model):
    """Return the fraction of exactly-zero elements over all parameters.

    Every parameter tensor is counted, not only the pruned layers, so the
    result is a whole-model figure.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the model has no parameters.
    """
    zeros, total = 0, 0
    for p in model.parameters():
        zeros += (p == 0).sum().item()
        total += p.numel()
    # A module without parameters has nothing to be sparse about; avoid 0/0.
    if total == 0:
        return 0.0
    return zeros / total
# Mask 40% of the output channels of every Conv2d layer (in place).
student = apply_structured_pruning(student, amount=0.4)

# Fine-tune after pruning so the remaining channels can compensate.
# The pruning masks stay attached during training, so pruned channels
# remain zero.
finetune_optimizer = torch.optim.SGD(
    student.parameters(), lr=1e-3, momentum=0.9
)
for _ in range(10):
    student.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.cuda(), labels.cuda()
        loss = F.cross_entropy(student(inputs), labels)
        finetune_optimizer.zero_grad()
        loss.backward()
        finetune_optimizer.step()

# Make pruning permanent: fold each mask into `weight` and drop the
# `weight_orig` / `weight_mask` re-parametrization.
for module in student.modules():
    if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask"):
        prune.remove(module, "weight")

print(f"After pruning: sparsity={get_sparsity(student):.1%}, {model_size_mb(student):.1f} MB")
# The file size is unchanged (~9.7 MB) because the zeros are still stored.
# get_sparsity() counts every parameter, including the unpruned Linear and
# BatchNorm layers, so the printed figure is below the 40% per-layer target.
Stage 3: Weight Clustering
Group weights into K clusters to enable compact storage:
from sklearn.cluster import KMeans
import numpy as np
def cluster_weights(model, num_clusters=256):
    """Replace each weight tensor's non-zero values with k-means centroids.

    Parameters with fewer than two dimensions (biases, norm scales) are left
    untouched. Exact zeros are excluded from clustering and stay zero, so
    sparsity from an earlier pruning step is preserved.

    Args:
        model: Module whose parameters are modified in place.
        num_clusters: Upper bound on the number of centroids per tensor.

    Returns:
        Dict mapping parameter name to a codebook with keys ``"centroids"``
        (1-D array), ``"labels"`` (cluster index of each non-zero weight, in
        flattened order), ``"shape"`` (original tensor shape) and
        ``"nonzero_mask"`` (flat boolean mask of the clustered positions).
    """
    codebooks = {}
    for name, param in model.named_parameters():
        if param.dim() < 2:  # Skip biases and 1D params
            continue

        # flatten() returns a copy, so the parameter is only changed when the
        # clustered values are written back below.
        weights = param.data.cpu().numpy().flatten()

        # Skip zero weights (from pruning)
        nonzero_mask = weights != 0
        if not nonzero_mask.any():
            continue
        nonzero_weights = weights[nonzero_mask].reshape(-1, 1)

        # k-means cannot form more clusters than there are distinct values;
        # asking for more only yields duplicate centroids and a warning.
        k = min(num_clusters, np.unique(nonzero_weights).size)
        kmeans = KMeans(n_clusters=k, n_init=1, max_iter=50)
        kmeans.fit(nonzero_weights)

        # Replace every non-zero weight with the centroid of its cluster.
        weights[nonzero_mask] = kmeans.cluster_centers_[kmeans.labels_].flatten()
        param.data = torch.tensor(
            weights.reshape(param.shape),
            dtype=param.dtype,
            device=param.device,
        )

        codebooks[name] = {
            "centroids": kmeans.cluster_centers_.flatten(),
            "labels": kmeans.labels_,
            "shape": param.shape,
            "nonzero_mask": nonzero_mask,
        }
    return codebooks
# Cluster the student's weights in place and keep the per-tensor codebooks.
codebooks = cluster_weights(student, num_clusters=32)
# Storage calculation: 32 clusters → 5 bits per weight (log2(32) = 5)
# vs 32 bits per weight → 6.4× compression on non-zero weights
# (this ratio ignores the space needed for the centroids and the zero mask)
Stage 4: Quantization
import torch.ao.quantization as quant
def quantize_model(model, calibration_loader):
    """Post-training static INT8 quantization using FX graph mode.

    The model is traced with ``prepare_fx``, which fuses supported patterns
    and inserts observers and quantize/dequantize boundaries automatically.
    The manual eager-mode recipe used before could not work here: it asked
    ``fuse_modules`` to fuse Conv + BatchNorm + Hardswish, which is not a
    supported fusion pattern, and it added no quant/dequant stubs, so the
    converted model would have received float inputs.

    Args:
        model: Float model to quantize. As before, it is switched to eval
            mode and moved to CPU in place; the quantized result is built
            from a deep copy, so its weights are not modified.
        calibration_loader: Iterable of ``(inputs, labels)`` batches of
            representative data; the labels are ignored.

    Returns:
        The quantized model (a ``torch.fx.GraphModule``) on CPU.

    Raises:
        ValueError: If ``calibration_loader`` yields no batches.
    """
    import copy

    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model.eval().cpu()
    float_model = copy.deepcopy(model)

    batches = iter(calibration_loader)
    try:
        first_inputs, _ = next(batches)
    except StopIteration:
        raise ValueError("calibration_loader yielded no batches") from None
    first_inputs = first_inputs.cpu()

    # Configure quantization for the x86 backend and insert observers.
    qconfig_mapping = get_default_qconfig_mapping("x86")
    prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=(first_inputs,))

    # Calibrate with representative data, including the batch already
    # drawn as the example input.
    with torch.no_grad():
        prepared(first_inputs)
        for inputs, _ in batches:
            prepared(inputs.cpu())

    # Convert to quantized model
    return convert_fx(prepared)
# Quantize the (pruned, clustered) student to INT8.
# NOTE(review): `calibration_loader` is not defined in this listing; it must
# yield (inputs, labels) batches of representative data.
quantized_student = quantize_model(student, calibration_loader)
print(f"Quantized: {model_size_mb(quantized_student):.1f} MB")
# Example output:
# Quantized: ~2.5 MB
Stage 5: Export for Deployment
# Option A: TorchScript
# Compile the INT8 model and save it to a single file.
scripted = torch.jit.script(quantized_student)
scripted.save("compressed_model.pt")
# Option B: ONNX export
# Example input used for the export: one 3-channel 224x224 image.
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
student, # Use float model for ONNX (quantize in ONNX Runtime)
dummy,
"compressed_model.onnx",
opset_version=13,
input_names=["input"],
output_names=["output"]
)
# Option C: TFLite (via ONNX → TF → TFLite)
# Best for mobile and microcontroller deployment
ONNX Runtime Quantization
For framework-agnostic INT8 deployment:
from onnxruntime.quantization import (
quantize_static,
quantize_dynamic,
CalibrationDataReader,
QuantFormat,
QuantType
)
class ImageNetCalibrationReader(CalibrationDataReader):
    """Feed preprocessed calibration images to ONNX Runtime, one at a time.

    All samples are loaded eagerly in ``__init__``. ``get_next`` then hands
    them out in order and returns ``None`` once they are exhausted.

    Args:
        calibration_dir: Location passed to ``load_images``.
        num_samples: Maximum number of images to use for calibration.
    """

    def __init__(self, calibration_dir, num_samples=100):
        self.data = self._load_data(calibration_dir, num_samples)
        self.idx = 0

    def get_next(self):
        """Return the next ``{"input": sample}`` feed dict, or ``None`` at the end."""
        if self.idx >= len(self.data):
            return None
        # The key must match the input name given at ONNX export ("input").
        sample = {"input": self.data[self.idx]}
        self.idx += 1
        return sample

    def _load_data(self, path, n):
        """Load and preprocess at most ``n`` calibration images from ``path``."""
        # NOTE(review): `load_images` and `preprocess` are not defined in this
        # listing; presumably they return images and model-ready arrays.
        return [preprocess(img) for img in load_images(path)[:n]]
# Static quantization (best accuracy)
# Activation ranges are measured on the samples served by the calibration
# reader; weights and activations are both stored as signed INT8.
quantize_static(
model_input="compressed_model.onnx",
model_output="compressed_model_int8.onnx",
calibration_data_reader=ImageNetCalibrationReader("cal_data/"),
quant_format=QuantFormat.QDQ,
weight_type=QuantType.QInt8,
activation_type=QuantType.QInt8
)
# Dynamic quantization (no calibration needed)
# Only the weight type is specified here; no calibration reader is passed.
quantize_dynamic(
model_input="compressed_model.onnx",
model_output="compressed_model_dynamic.onnx",
weight_type=QuantType.QInt8
)
Low-Rank Factorization
Decompose large weight matrices into products of smaller matrices:
import torch
from torch import nn
def low_rank_decomposition(weight, rank):
    """Factor ``weight`` into two thin matrices using a truncated SVD.

    With ``weight = U @ diag(S) @ Vh``, the top ``rank`` singular triplets are
    kept and the singular values are split evenly between the two factors:
    ``A = U_k * sqrt(S_k)`` and ``B = sqrt(S_k) * Vh_k``. ``A @ B`` is then
    the best rank-``rank`` approximation of ``weight`` in the Frobenius norm.

    Args:
        weight: 2-D tensor of shape ``(out_features, in_features)``.
        rank: Number of singular values to keep. Values larger than
            ``min(weight.shape)`` are effectively clamped by the slicing.

    Returns:
        ``(A, B)`` with shapes ``(out_features, rank)`` and
        ``(rank, in_features)``.
    """
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # Keep top-k singular values. Scaling by broadcasting is equivalent to
    # multiplying by diag(sqrt(S_k)) without building rank x rank matrices.
    sqrt_s = torch.sqrt(S[:rank])
    A = U[:, :rank] * sqrt_s                 # (out_features, rank)
    B = sqrt_s.unsqueeze(1) * Vh[:rank, :]   # (rank, in_features)
    return A, B
def replace_linear_with_low_rank(model, rank_ratio=0.5):
    """Recursively replace ``nn.Linear`` children with low-rank pairs.

    Each eligible ``Linear(in_f, out_f)`` becomes
    ``Sequential(Linear(in_f, rank, bias=False), Linear(rank, out_f))`` whose
    weights come from ``low_rank_decomposition``. A layer is left untouched
    when the computed rank would be 0 or when the factorization would not
    reduce the parameter count.

    Args:
        model: Module whose descendants are rewritten in place. ``model``
            itself is never replaced, even if it is an ``nn.Linear``.
        rank_ratio: Fraction of ``min(in_features, out_features)`` kept as
            the rank.

    Returns:
        The same ``model`` object.
    """
    # Iterate over a snapshot because children are reassigned inside the loop.
    for name, module in list(model.named_children()):
        if not isinstance(module, nn.Linear):
            replace_linear_with_low_rank(module, rank_ratio)
            continue

        in_f, out_f = module.in_features, module.out_features
        rank = int(min(in_f, out_f) * rank_ratio)

        # A rank of 0 would pass the "saves parameters" check below but
        # produce a zero-width bottleneck that discards the layer's input.
        if rank < 1:
            continue

        # Only decompose if it saves parameters
        original_params = in_f * out_f
        new_params = in_f * rank + rank * out_f
        if new_params >= original_params:
            continue

        A, B = low_rank_decomposition(module.weight.data, rank)

        # Replace with two smaller linear layers: x -> B x -> A (B x) + bias
        has_bias = module.bias is not None
        replacement = nn.Sequential(
            nn.Linear(in_f, rank, bias=False),
            nn.Linear(rank, out_f, bias=has_bias),
        )
        replacement[0].weight.data = B
        replacement[1].weight.data = A
        if has_bias:
            replacement[1].bias.data = module.bias.data

        setattr(model, name, replacement)
        savings = (1 - new_params / original_params) * 100
        print(f"{name}: rank {rank}, {savings:.0f}% parameter reduction")
    return model
Comprehensive Benchmarking Framework
import time
import numpy as np
class CompressionBenchmark:
    """Measure accuracy, size, sparsity and latency of a model.

    The model is not moved by this class; it must already live on ``device``.

    Args:
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        device: Device that batches and the latency input are moved to.
    """

    def __init__(self, val_loader, device="cuda"):
        self.val_loader = val_loader
        self.device = device

    def evaluate(self, model, name="model", input_shape=(1, 3, 224, 224),
                 latency_runs=100):
        """Complete evaluation: accuracy, size, latency.

        Args:
            model: Model to evaluate; switched to eval mode.
            name: Label used in the printed report and the result dict.
            input_shape: Shape of the random tensor used for latency timing.
            latency_runs: Number of timed forward passes (at least 1).

        Returns:
            Dict with keys ``name``, ``accuracy``, ``size_mb``, ``params_m``,
            ``nonzero_m``, ``sparsity``, ``latency_ms``, ``latency_p95_ms``.
            ``accuracy`` is NaN when ``val_loader`` yields no samples.

        Raises:
            ValueError: If ``latency_runs`` is smaller than 1.
        """
        if latency_runs < 1:
            raise ValueError("latency_runs must be at least 1")

        model.eval()

        # Accuracy
        accuracy = self._measure_accuracy(model)

        # Size
        size_mb = model_size_mb(model)
        nonzero, total_params = count_nonzero(model)
        # Models that expose no nn.Parameter (e.g. quantized ones holding
        # packed weights) would otherwise divide by zero here.
        sparsity = 1 - nonzero / total_params if total_params else 0.0

        # Latency
        times = self._measure_latency(model, input_shape, latency_runs)

        result = {
            "name": name,
            "accuracy": accuracy,
            "size_mb": size_mb,
            "params_m": total_params / 1e6,
            "nonzero_m": nonzero / 1e6,
            "sparsity": sparsity,
            "latency_ms": np.mean(times) * 1000,
            "latency_p95_ms": np.percentile(times, 95) * 1000,
        }

        print(f"\n{'='*50}")
        print(f" {name}")
        print(f"{'='*50}")
        print(f" Accuracy: {accuracy:.2%}")
        print(f" Size: {size_mb:.1f} MB")
        print(f" Params: {total_params/1e6:.1f}M ({nonzero/1e6:.1f}M non-zero)")
        print(f" Sparsity: {sparsity:.1%}")
        print(f" Latency: {np.mean(times)*1000:.2f}ms (P95: {np.percentile(times,95)*1000:.2f}ms)")
        return result

    def _measure_accuracy(self, model):
        """Return top-1 accuracy over ``val_loader`` (NaN if it is empty)."""
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
        return correct / total if total else float("nan")

    def _measure_latency(self, model, input_shape, runs):
        """Return per-call wall-clock times in seconds for ``runs`` passes."""
        # Compare the device *type* so "cuda:0" and torch.device objects are
        # also synchronized.
        on_cuda = torch.device(self.device).type == "cuda"
        dummy = torch.randn(*input_shape).to(self.device)
        times = []
        with torch.no_grad():
            model(dummy)  # Warmup
            for _ in range(runs):
                if on_cuda:
                    # Drain pending kernels so they are not billed to this run.
                    torch.cuda.synchronize()
                start = time.perf_counter()
                model(dummy)
                if on_cuda:
                    torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
        return times
# Run complete pipeline benchmark
# NOTE(review): `val_loader` is not defined in this listing.
benchmark = CompressionBenchmark(val_loader)
results = []
results.append(benchmark.evaluate(teacher, "Teacher (ResNet-50)"))
# NOTE(review): by this point `student` has been pruned (Stage 2) and
# clustered (Stage 3) in place, so this row does not measure the purely
# distilled model unless a copy was taken earlier.
results.append(benchmark.evaluate(student, "Distilled (MobileNet V3)"))
# NOTE(review): `pruned_student` is never assigned in this listing; Stage 2
# rebinds `student` to the pruned model.
results.append(benchmark.evaluate(pruned_student, "Distilled + Pruned"))
# NOTE(review): quantize_model() moves the model to CPU, while `benchmark`
# uses the default device "cuda"; confirm the devices match before running.
results.append(benchmark.evaluate(quantized_student, "Distilled + Pruned + Quantized"))
Expected Pipeline Results
| Stage | Size | Accuracy | Latency | Compression |
|---|---|---|---|---|
| Teacher (ResNet-50) | 97.8 MB | 80.4% | 12.3ms | 1× |
| Distilled (MobileNet V3) | 9.7 MB | 67.5% | 2.1ms | 10× |
| + Structured Pruning | 9.7 MB* | 66.8% | 1.4ms | 10× |
| + INT8 Quantization | 2.5 MB | 66.2% | 0.8ms | 39× |
| + Weight Clustering | ~1.5 MB | 65.9% | 0.8ms | 65× |
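*The pruned model has the same file size because the zeroed channels are still stored densely; the size only shrinks once the model is converted to a sparse format or the architecture is rebuilt without the pruned channels. Note that the code above applies weight clustering (Stage 3) before INT8 quantization (Stage 4), whereas the table lists clustering last.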
The one thing to remember: A production compression pipeline in Python chains distillation → pruning → clustering → quantization and measures accuracy, size, and latency after every stage. The cumulative effect can reach 40–100× compression, but the combination of techniques and the aggressiveness of each stage must be tuned to your deployment target’s hardware capabilities and your minimum accuracy threshold.
See Also
- Python Hyperparameter Tuning — Learn why adjusting the dials on a computer's learning recipe makes predictions way better.
- Python Knowledge Distillation — How a big expert AI teaches a tiny student AI to be almost as smart — like a professor writing a cheat sheet for an exam.
- Python Model Pruning Techniques — Why cutting away parts of an AI's brain can make it faster without making it dumber.
- Python Neural Architecture Search — How AI designs its own brain structure — like a robot architect building the perfect house by trying thousands of floor plans.
- Python PyTorch Quantization — How shrinking numbers inside an AI model makes it run faster on phones and cheaper servers without losing much accuracy.