Model Compression Methods in Python — Deep Dive

The Full Compression Pipeline

This deep dive implements a complete compression pipeline: starting from a large trained model and producing a deployment-ready artifact. We’ll measure size and accuracy at each stage.

Stage 0: Baseline Model

import torch
import torch.nn as nn
import torchvision.models as models
import os

# Load a pre-trained ResNet-50
# (the IMAGENET1K_V2 checkpoint is fetched by torchvision if not cached).
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# The teacher is only used for inference in this pipeline, so put
# BatchNorm/Dropout layers into evaluation mode up front.
teacher.eval()

def model_size_mb(model):
    """Return the serialized size of ``model.state_dict()`` in MiB.

    The state dict is written with ``torch.save`` into a private temporary
    directory and the resulting file is measured, so the figure reflects the
    on-disk checkpoint size (tensor data plus serialization overhead).

    Args:
        model: Object exposing ``state_dict()`` (typically an ``nn.Module``).

    Returns:
        float: File size in mebibytes (bytes / 1024**2).
    """
    # Function-scope import keeps this snippet self-contained.
    import tempfile

    # A per-call temporary directory replaces the fixed, shared
    # "/tmp/temp_model.pth" path: concurrent runs cannot clobber each other,
    # an unrelated file with that name is never overwritten or deleted, and
    # the scratch file is cleaned up even when torch.save raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp_model.pth")
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / (1024 * 1024)

def count_parameters(model):
    """Return the total number of scalar elements over ``model.parameters()``."""
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    return n_elements

def count_nonzero(model):
    """Count non-zero and total parameter elements of ``model``.

    Returns:
        tuple: ``(nonzero, total)`` where ``total`` is the number of scalar
        elements over ``model.parameters()`` and ``nonzero`` is how many of
        them compare unequal to zero.
    """
    # Materialize once so the two passes below see the same tensors.
    tensors = list(model.parameters())
    element_count = sum(t.numel() for t in tensors)
    nonzero_count = sum((t != 0).sum().item() for t in tensors)
    return nonzero_count, element_count

print(f"Teacher: {count_parameters(teacher)/1e6:.1f}M params, {model_size_mb(teacher):.1f} MB")
# Teacher: 25.6M params, 97.8 MB

Stage 1: Knowledge Distillation

Distill ResNet-50 into MobileNet V3 Small:

import torch.nn.functional as F

# Student: MobileNet V3 Small (2.5M params vs 25.6M)
# No `weights=` argument is passed, so the student is built without loading a
# pretrained checkpoint; it only becomes useful after the distillation
# training below has actually been run.
student = models.mobilenet_v3_small(num_classes=1000)

class DistillationTrainer:
    """Train a student network to imitate a frozen teacher.

    The per-batch loss is

        alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
        + (1 - alpha) * cross_entropy(student, labels)

    where ``T`` is ``temperature``.

    Args:
        teacher: Trained network; moved to ``device`` and put in eval mode.
        student: Network to train; moved to ``device``.
        temperature: Softening temperature applied to both sets of logits.
        alpha: Weight of the soft (distillation) term; the hard-label term
            gets ``1 - alpha``.
        device: Device to train on. ``None`` (default) selects CUDA when it
            is available and falls back to CPU otherwise.
        t_max: ``T_max`` of the cosine learning-rate schedule, i.e. the
            number of ``train_epoch`` calls over which the rate is annealed.
    """

    def __init__(self, teacher, student, temperature=4.0, alpha=0.7,
                 device=None, t_max=100):
        if device is None:
            # Previously hard-coded to .cuda(), which fails on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.teacher = teacher.to(self.device).eval()
        self.student = student.to(self.device)
        self.temperature = temperature
        self.alpha = alpha
        self.optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-3)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=t_max
        )

    def train_epoch(self, dataloader):
        """Run one pass over ``dataloader`` and return the mean batch loss.

        Args:
            dataloader: Iterable of ``(inputs, labels)`` batches.

        Returns:
            float: Mean loss over the batches seen, or ``0.0`` when the
            iterable yields no batches (the scheduler is not stepped then).
        """
        self.student.train()
        total_loss = 0.0
        num_batches = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # The teacher only provides targets; no gradients are needed.
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)

            student_logits = self.student(inputs)

            # Soft-target term. kl_div expects log-probabilities as input and
            # probabilities as target; T**2 keeps the gradient magnitude of
            # this term comparable across temperatures.
            soft_loss = F.kl_div(
                F.log_softmax(student_logits / self.temperature, dim=1),
                F.softmax(teacher_logits / self.temperature, dim=1),
                reduction="batchmean"
            ) * (self.temperature ** 2)

            hard_loss = F.cross_entropy(student_logits, labels)
            loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: avoids ZeroDivisionError on an empty
        # loader and works for iterables that do not define __len__.
        if num_batches == 0:
            return 0.0

        self.scheduler.step()
        return total_loss / num_batches

# After distillation:
print(f"Student: {count_parameters(student)/1e6:.1f}M params, {model_size_mb(student):.1f} MB")
# Student: 2.5M params, 9.7 MB
# Compression: 10.1×, accuracy: ~67% (vs teacher's 80%)

Stage 2: Structured Pruning

import torch.nn.utils.prune as prune

def apply_structured_pruning(model, amount=0.4):
    """Attach an L2 structured-pruning mask to every ``nn.Conv2d`` in ``model``.

    For each convolution, ``amount`` of the slices along dimension 0 of
    ``weight`` (the output channels) with the smallest L2 norm are masked.
    The model is modified in place and also returned.

    Args:
        model: Network whose ``nn.Conv2d`` submodules are pruned.
        amount: Passed straight to ``prune.ln_structured`` (default 0.4,
            i.e. 40% of the channels of each layer).

    Returns:
        The same ``model`` object.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    for conv in conv_layers:
        prune.ln_structured(conv, name="weight", amount=amount, n=2, dim=0)
    return model

def get_sparsity(model):
    """Return the fraction of parameter elements that are exactly zero.

    All tensors returned by ``model.parameters()`` are counted (weights,
    biases, normalization parameters), not only pruned layers.

    Returns:
        float: ``zeros / total`` in ``[0, 1]``; ``0.0`` when the model
        exposes no parameters at all (e.g. ``nn.Identity``).
    """
    zeros, total = 0, 0
    for p in model.parameters():
        zeros += (p == 0).sum().item()
        total += p.numel()
    # Guard against ZeroDivisionError for parameter-free modules.
    if total == 0:
        return 0.0
    return zeros / total

# apply_structured_pruning modifies the module in place and returns it, so
# from here on `student` refers to the pruned network.
student = apply_structured_pruning(student, amount=0.4)

# Fine-tune after pruning
# The optimizer is created after the masks are attached so that it picks up
# whatever parameters the pruned modules expose at this point.
finetune_optimizer = torch.optim.SGD(
    student.parameters(), lr=1e-3, momentum=0.9
)

# NOTE(review): `train_loader` is not defined in this snippet; supply a
# DataLoader that yields (inputs, labels) batches. The loop also assumes the
# student already lives on a CUDA device.
for epoch in range(10):
    student.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.cuda(), labels.cuda()
        loss = F.cross_entropy(student(inputs), labels)
        finetune_optimizer.zero_grad()
        loss.backward()
        finetune_optimizer.step()

# Make pruning permanent
# Only Conv2d layers that still carry a `weight_mask` are touched, so this
# loop is safe to run on layers that were never pruned.
for name, module in student.named_modules():
    if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask"):
        prune.remove(module, "weight")

# get_sparsity counts every parameter tensor, not just the pruned Conv2d weights.
print(f"After pruning: sparsity={get_sparsity(student):.1%}, {model_size_mb(student):.1f} MB")
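# After pruning: 40% of the output channels of each Conv2d layer are zero. The
# overall sparsity printed above is lower than 40%, because get_sparsity also
# counts biases, BatchNorm and Linear parameters, none of which are pruned.
# File size stays at ~9.7 MB (same storage — zeros are still stored).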

Stage 3: Weight Clustering

Group weights into K clusters to enable compact storage:

from sklearn.cluster import KMeans
import numpy as np

def cluster_weights(model, num_clusters=256, random_state=None):
    """Replace each weight tensor's non-zero values with k-means centroids.

    Every parameter with at least two dimensions is processed independently
    and overwritten in place. Exact zeros (e.g. from pruning) are left
    untouched so existing sparsity is preserved.

    Args:
        model: Network whose parameters are clustered in place.
        num_clusters: Upper bound on the number of centroids per tensor.
        random_state: Forwarded to ``KMeans`` to make the clustering
            reproducible. ``None`` (default) keeps the previous unseeded
            behavior.

    Returns:
        dict: Maps parameter name to a dict with
            ``"centroids"``    – 1-D array of centroid values,
            ``"labels"``       – centroid index for each non-zero weight,
            ``"shape"``        – original parameter shape,
            ``"nonzero_mask"`` – flat boolean mask of the non-zero positions.
        Parameters that were skipped (1-D, or all zeros) have no entry.
    """
    codebooks = {}

    for name, param in model.named_parameters():
        if param.dim() < 2:  # Skip biases and other 1-D parameters
            continue

        # flatten() returns a copy, so edits below do not alias param.data.
        weights = param.data.cpu().numpy().flatten()

        # Leave exact zeros (from pruning) alone.
        nonzero_mask = weights != 0
        if not nonzero_mask.any():
            continue

        nonzero_weights = weights[nonzero_mask].reshape(-1, 1)

        # Never request more clusters than there are distinct values:
        # surplus clusters cannot be populated and only trigger a
        # convergence warning from scikit-learn.
        n_distinct = np.unique(nonzero_weights).size
        k = min(num_clusters, n_distinct)
        kmeans = KMeans(
            n_clusters=k, n_init=1, max_iter=50, random_state=random_state
        )
        kmeans.fit(nonzero_weights)

        # Map every non-zero weight to the centroid of its cluster.
        clustered = kmeans.cluster_centers_[kmeans.labels_].flatten()
        weights[nonzero_mask] = clustered
        param.data = torch.tensor(
            weights.reshape(param.shape),
            dtype=param.dtype,
            device=param.device
        )

        codebooks[name] = {
            "centroids": kmeans.cluster_centers_.flatten(),
            "labels": kmeans.labels_,
            "shape": param.shape,
            "nonzero_mask": nonzero_mask
        }

    return codebooks

codebooks = cluster_weights(student, num_clusters=32)

# Storage calculation: 32 clusters → 5 bits per weight
# vs 32 bits per weight → 6.4× compression on non-zero weights

Stage 4: Quantization

import torch.ao.quantization as quant

def quantize_model(model, calibration_loader, backend="x86"):
    """Post-training static INT8 quantization using FX graph mode.

    The previous eager-mode recipe tried to fuse Conv + BatchNorm + Hardswish
    (``features.0.*``), which is not a supported fusion pattern, and it
    required manually inserted Quant/DeQuant stubs. FX graph mode traces the
    model, fuses the supported patterns automatically and inserts the
    quantize/dequantize steps itself.

    Args:
        model: Float model to quantize. It is switched to eval mode and moved
            to CPU in place (as before); the quantization passes themselves
            run on a deep copy, so the float weights are left untouched.
        calibration_loader: Iterable of ``(inputs, target)`` batches of
            representative data; only ``inputs`` is used.
        backend: Name passed to ``get_default_qconfig_mapping``
            (default ``"x86"``, the value that was previously hard-coded).

    Returns:
        The converted INT8 model (a ``torch.fx.GraphModule``).

    Raises:
        ValueError: If ``calibration_loader`` yields no batches; observers
            would otherwise be converted without having seen any data.
    """
    # Function-scope imports keep this snippet self-contained.
    import copy

    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # In-place, as in the original: later stages rely on the float model
    # being on CPU and in eval mode after this call.
    model.eval().cpu()

    batches = iter(calibration_loader)
    try:
        first_inputs, _ = next(batches)
    except StopIteration:
        raise ValueError(
            "calibration_loader yielded no batches; static quantization "
            "needs at least one batch of representative data"
        ) from None

    # Trace, fuse and insert observers on a copy of the float model.
    qconfig_mapping = get_default_qconfig_mapping(backend)
    prepared = prepare_fx(
        copy.deepcopy(model), qconfig_mapping, example_inputs=(first_inputs,)
    )

    # Calibrate with representative data (including the batch already drawn).
    with torch.no_grad():
        prepared(first_inputs)
        for inputs, _ in batches:
            prepared(inputs)

    # Convert to the quantized model.
    return convert_fx(prepared)

quantized_student = quantize_model(student, calibration_loader)
print(f"Quantized: {model_size_mb(quantized_student):.1f} MB")
# Quantized: ~2.5 MB

Stage 5: Export for Deployment

# Option A: TorchScript
# Serializes the INT8 model produced in Stage 4 to a single file.
scripted = torch.jit.script(quantized_student)
scripted.save("compressed_model.pt")

# Option B: ONNX export
# `dummy` is created on the CPU, so `student` must be on the CPU as well;
# quantize_model above leaves it there via `model.eval().cpu()`.
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    student,  # Use float model for ONNX (quantize in ONNX Runtime)
    dummy,
    "compressed_model.onnx",
    opset_version=13,
    # "input" is also the feed name used by the calibration reader below.
    input_names=["input"],
    output_names=["output"]
)

# Option C: TFLite (via ONNX → TF → TFLite)
# Best for mobile and microcontroller deployment

ONNX Runtime Quantization

For framework-agnostic INT8 deployment:

from onnxruntime.quantization import (
    quantize_static,
    quantize_dynamic,
    CalibrationDataReader,
    QuantFormat,
    QuantType
)

class ImageNetCalibrationReader(CalibrationDataReader):
    """Serve preprocessed calibration images one sample at a time.

    All samples are loaded eagerly in the constructor; ``get_next`` then walks
    through them with the cursor ``idx`` and returns ``None`` once they are
    exhausted.
    """

    def __init__(self, calibration_dir, num_samples=100):
        self.data = self._load_data(calibration_dir, num_samples)
        self.idx = 0

    def get_next(self):
        """Return the next feed dict ``{"input": sample}``, or ``None`` at the end."""
        position = self.idx
        if position < len(self.data):
            feed = {"input": self.data[position]}
            self.idx = position + 1
            return feed
        return None

    def _load_data(self, path, n):
        """Load the images under ``path`` and preprocess the first ``n`` of them."""
        # NOTE(review): `load_images` and `preprocess` are not defined in this
        # snippet; they are assumed to be provided by the surrounding project.
        images = load_images(path)
        return [preprocess(image) for image in images[:n]]

# Static quantization (best accuracy)
# Reads the float model written by torch.onnx.export above and uses the
# calibration reader to collect activation statistics before writing the
# INT8 model. Both weights and activations are stored as signed 8-bit.
quantize_static(
    model_input="compressed_model.onnx",
    model_output="compressed_model_int8.onnx",
    calibration_data_reader=ImageNetCalibrationReader("cal_data/"),
    quant_format=QuantFormat.QDQ,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8
)

# Dynamic quantization (no calibration needed)
# Starts from the same float model; only the weight type is specified here.
quantize_dynamic(
    model_input="compressed_model.onnx",
    model_output="compressed_model_dynamic.onnx",
    weight_type=QuantType.QInt8
)

Low-Rank Factorization

Decompose large weight matrices into products of smaller matrices:

import torch
from torch import nn

def low_rank_decomposition(weight, rank):
    """Factor a 2-D weight matrix into two thin factors via truncated SVD.

    With ``weight = U @ diag(S) @ Vh``, the top ``rank`` singular triplets
    are kept and the singular values are split evenly between both factors:

        A = U[:, :rank] * sqrt(S[:rank])          # (out_features, rank)
        B = sqrt(S[:rank])[:, None] * Vh[:rank]   # (rank, in_features)

    so that ``A @ B`` is the rank-``rank`` truncation of ``weight``.

    Args:
        weight: 2-D tensor of shape ``(out_features, in_features)``.
        rank: Number of singular values to keep. Values larger than
            ``min(weight.shape)`` are clamped by the slicing.

    Returns:
        tuple: ``(A, B)`` as described above.
    """
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # Broadcasting scales the columns of U and the rows of Vh directly,
    # instead of materializing dense diagonal matrices and multiplying.
    sqrt_s = torch.sqrt(S[:rank])
    A = U[:, :rank] * sqrt_s
    B = sqrt_s.unsqueeze(1) * Vh[:rank, :]

    return A, B

def replace_linear_with_low_rank(model, rank_ratio=0.5):
    """Recursively replace ``nn.Linear`` layers with two-layer low-rank versions.

    Each eligible ``Linear(in_f, out_f)`` becomes
    ``Sequential(Linear(in_f, rank, bias=False), Linear(rank, out_f))`` whose
    weights come from ``low_rank_decomposition`` of the original weight. The
    original bias, if any, is carried over to the second layer.

    A layer is left untouched when the factorization would not reduce its
    parameter count, or when the computed rank is below 1 (which would
    otherwise produce a zero-width bottleneck that discards the weights).

    Args:
        model: Module to rewrite in place.
        rank_ratio: Fraction of ``min(in_features, out_features)`` kept as
            the rank (truncated to an integer).

    Returns:
        The same ``model`` object.
    """
    # Iterate over a snapshot because children are replaced during the loop.
    for name, module in list(model.named_children()):
        if not isinstance(module, nn.Linear):
            replace_linear_with_low_rank(module, rank_ratio)
            continue

        in_f, out_f = module.in_features, module.out_features
        rank = int(min(in_f, out_f) * rank_ratio)

        # A rank of 0 would "save" every parameter by destroying the layer.
        if rank < 1:
            continue

        # Only decompose if it saves parameters
        original_params = in_f * out_f
        new_params = in_f * rank + rank * out_f
        if new_params >= original_params:
            continue

        A, B = low_rank_decomposition(module.weight.data, rank)

        # Replace with two smaller linear layers: x -> B x -> A (B x) + bias
        replacement = nn.Sequential(
            nn.Linear(in_f, rank, bias=False),
            nn.Linear(rank, out_f, bias=module.bias is not None)
        )
        replacement[0].weight.data = B
        replacement[1].weight.data = A
        if module.bias is not None:
            replacement[1].bias.data = module.bias.data

        setattr(model, name, replacement)
        savings = (1 - new_params / original_params) * 100
        print(f"{name}: rank {rank}, {savings:.0f}% parameter reduction")

    return model

Comprehensive Benchmarking Framework

import time
import numpy as np

class CompressionBenchmark:
    """Measure accuracy, size, sparsity and single-image latency of a model.

    Args:
        val_loader: Iterable of ``(inputs, labels)`` batches used for the
            top-1 accuracy measurement.
        device: Device (string or ``torch.device``) on which models are
            evaluated and timed.
    """

    def __init__(self, val_loader, device="cuda"):
        self.val_loader = val_loader
        self.device = device

    def evaluate(self, model, name="model"):
        """Complete evaluation: accuracy, size, latency.

        The model is moved to the benchmark device and switched to eval mode
        (both in place), evaluated on ``val_loader``, measured with
        ``model_size_mb`` / ``count_nonzero``, and timed over 100 forward
        passes of a single 1x3x224x224 random input. A summary is printed.

        Args:
            model: Model to benchmark.
            name: Label used in the printed report and the result dict.

        Returns:
            dict: Keys ``name``, ``accuracy``, ``size_mb``, ``params_m``,
            ``nonzero_m``, ``sparsity``, ``latency_ms``, ``latency_p95_ms``.
        """
        device = torch.device(self.device)
        # Inputs are sent to `device`, so the model has to live there too.
        model.to(device)
        model.eval()

        # Accuracy
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
        # An empty loader yields 0.0 instead of ZeroDivisionError.
        accuracy = correct / total if total else 0.0

        # Size
        size_mb = model_size_mb(model)
        nonzero, total_params = count_nonzero(model)
        # Models that expose no nn.Parameter report 0 total parameters;
        # avoid dividing by zero for them.
        sparsity = 1 - nonzero / total_params if total_params else 0.0

        # Latency
        dummy = torch.randn(1, 3, 224, 224).to(device)
        # device.type also matches "cuda:0" and torch.device objects, which a
        # plain string comparison with "cuda" would miss.
        on_cuda = device.type == "cuda"
        times = []
        with torch.no_grad():
            model(dummy)  # Warmup (inside no_grad: no autograd graph is built)
            if on_cuda:
                # Make sure warmup kernels are done before the clock starts.
                torch.cuda.synchronize()
            for _ in range(100):
                start = time.perf_counter()
                model(dummy)
                if on_cuda:
                    torch.cuda.synchronize()
                times.append(time.perf_counter() - start)

        latency_ms = np.mean(times) * 1000
        latency_p95_ms = np.percentile(times, 95) * 1000

        result = {
            "name": name,
            "accuracy": accuracy,
            "size_mb": size_mb,
            "params_m": total_params / 1e6,
            "nonzero_m": nonzero / 1e6,
            "sparsity": sparsity,
            "latency_ms": latency_ms,
            "latency_p95_ms": latency_p95_ms,
        }

        print(f"\n{'='*50}")
        print(f"  {name}")
        print(f"{'='*50}")
        print(f"  Accuracy:  {accuracy:.2%}")
        print(f"  Size:      {size_mb:.1f} MB")
        print(f"  Params:    {total_params/1e6:.1f}M ({nonzero/1e6:.1f}M non-zero)")
        print(f"  Sparsity:  {sparsity:.1%}")
        print(f"  Latency:   {latency_ms:.2f}ms (P95: {latency_p95_ms:.2f}ms)")

        return result

# Run complete pipeline benchmark
benchmark = CompressionBenchmark(val_loader)
# The converted INT8 model runs on CPU, so it gets a CPU benchmark of its
# own; its latency is therefore not directly comparable with the GPU rows.
cpu_benchmark = CompressionBenchmark(val_loader, device="cpu")

# quantize_model left `student` on the CPU; put both float models on the
# benchmark device before evaluating them.
teacher.to(benchmark.device)
student.to(benchmark.device)

results = []
results.append(benchmark.evaluate(teacher, "Teacher (ResNet-50)"))
# NOTE(review): `student` was pruned (Stage 2) and clustered (Stage 3) in
# place, so this row and the next one measure the same weights. To report a
# genuine pre-pruning row, keep a `copy.deepcopy(student)` snapshot before
# Stage 2 and evaluate that snapshot here.
results.append(benchmark.evaluate(student, "Distilled (MobileNet V3)"))
# `pruned_student` was never defined; Stage 2 rebinds the pruned model to
# `student`.
results.append(benchmark.evaluate(student, "Distilled + Pruned"))
results.append(cpu_benchmark.evaluate(quantized_student, "Distilled + Pruned + Quantized"))

Expected Pipeline Results

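| Stage | Size | Accuracy | Latency | Compression |
|---|---|---|---|---|
| Teacher (ResNet-50) | 97.8 MB | 80.4% | 12.3ms | — |
| Distilled (MobileNet V3) | 9.7 MB | 67.5% | 2.1ms | 10× |
| + Structured Pruning | 9.7 MB* | 66.8% | 1.4ms | 10× |
| + INT8 Quantization | 2.5 MB | 66.2% | 0.8ms | 39× |
| + Weight Clustering | ~1.5 MB | 65.9% | 0.8ms | 65× |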

*Pruned model has same file size until converted to sparse format or architecture is rebuilt.

The one thing to remember: A production compression pipeline in Python chains distillation → pruning → quantization, measuring accuracy, size, and latency at each stage — where the cumulative effect achieves 40-100× compression, but the specific combination and aggressiveness at each stage must be tuned to your deployment target’s hardware capabilities and your minimum accuracy threshold.

Tags: python · machine-learning · model-optimization

See Also