Model Pruning Techniques in Python — Deep Dive

PyTorch Unstructured Pruning

PyTorch provides pruning utilities in torch.nn.utils.prune:

Basic Magnitude Pruning

import torch
import torch.nn.utils.prune as prune
import torchvision  # needed for the pretrained ResNet-50 below

# IMAGENET1K_V1 are the weights the deprecated `pretrained=True` flag loaded.
model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)

# Prune 50% of weights in a specific layer (by L1 magnitude)
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name="weight", amount=0.5)

# After pruning, `weight_orig` is the trainable parameter holding the original
# values and `weight_mask` is a 0/1 buffer; `weight` becomes a plain attribute
# recomputed as weight_orig * weight_mask before every forward pass.
print(conv.weight_orig.shape)  # Original weights
print(conv.weight_mask.shape)  # Binary mask
sparsity = 1 - conv.weight_mask.sum().item() / conv.weight_mask.numel()
print(f"Sparsity: {100 * sparsity:.1f}%")

Global Unstructured Pruning

Prune across all layers simultaneously, letting the algorithm remove the globally least important weights:

# Every conv and fully-connected layer contributes its "weight" tensor.
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]

# A single L1 threshold is chosen across all collected tensors at once.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.8,  # Remove 80% of all weights globally
)

# Verify global sparsity over the pruned tensors only
total_zeros = sum((module.weight == 0).sum().item() for module, _ in parameters_to_prune)
total_params = sum(module.weight.numel() for module, _ in parameters_to_prune)

print(f"Global sparsity: {100 * total_zeros / total_params:.1f}%")

Making Pruning Permanent

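By default, PyTorch implements pruning as a reparametrization:

- The original values are kept in `weight_orig`.
- The mask is kept in the `weight_mask` buffer.
- `weight` is recomputed from the two before every forward pass.

`prune.remove` makes the pruning permanent. It folds the mask into a plain `weight` parameter and drops the extra tensors. The result is still a dense tensor with explicit zeros, so storage returns to the size of the unpruned model rather than shrinking below it: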

# prune.remove folds each mask into its tensor: `weight` becomes a regular
# parameter again and the `weight_orig` / `weight_mask` entries are dropped.
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

# Now module.weight is a regular parameter with zeros baked in
# Save the pruned model. The zeros are stored explicitly in dense tensors, so
# the checkpoint is about the size of the unpruned model; use a sparse format
# or a general-purpose compressor to actually shrink it on disk.
torch.save(model.state_dict(), "pruned_model.pth")

Iterative Pruning with Fine-Tuning

The most effective approach — prune gradually while retraining:

import torch
import torch.nn.utils.prune as prune
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

def get_sparsity(model, module_types=(torch.nn.Conv2d, torch.nn.Linear)):
    """Return the fraction of zero entries in the prunable layers' weights.

    Reads ``module.weight`` for every module of ``module_types``. For a layer
    pruned with ``torch.nn.utils.prune`` that attribute is the *masked* tensor
    (``weight_orig * weight_mask``). The trainable parameter is renamed to
    ``weight_orig`` and still holds the unmasked values. That is why scanning
    ``named_parameters()`` would report ~0% sparsity while pruning is active.

    Args:
        model: network to inspect.
        module_types: layer types whose ``weight`` counts towards sparsity.
            Defaults to the Conv2d/Linear layers pruned elsewhere in this file.

    Returns:
        zeros / total as a float in [0, 1], or 0.0 if no layer matches.
    """
    zeros = 0
    total = 0
    for module in model.modules():
        if isinstance(module, module_types):
            weight = module.weight  # masked tensor when the layer is pruned
            zeros += (weight == 0).sum().item()
            total += weight.numel()
    return zeros / total if total else 0.0

def iterative_prune(model, train_loader, val_loader, target_sparsity=0.9,
                    num_rounds=10, finetune_epochs=5, lr=0.001, device=None):
    """Iterative global magnitude pruning with fine-tuning between rounds.

    Each round does three things:

    - Prunes the same fraction of the *still-unpruned* Conv2d/Linear weights.
      Re-applying ``global_unstructured`` to already-pruned tensors only ranks
      the surviving entries.
    - Fine-tunes with SGD plus cosine annealing.
    - Reports validation accuracy.

    After the last round the masks are folded into the weights with
    ``prune.remove``.

    Args:
        model: network to prune; modified in place.
        train_loader: iterable of ``(inputs, targets)`` batches for fine-tuning.
        val_loader: passed to ``evaluate`` (a helper defined elsewhere).
        target_sparsity: fraction of Conv2d/Linear weights zeroed at the end.
        num_rounds: number of prune/fine-tune rounds; must be >= 1.
        finetune_epochs: fine-tuning epochs after each pruning step.
        lr: SGD learning rate used for fine-tuning.
        device: device batches are moved to. Defaults to the device of the
            model's parameters.

    Returns:
        The same ``model``, with pruning made permanent.

    Raises:
        ValueError: if ``num_rounds < 1``, if ``target_sparsity`` is outside
            [0, 1], or if the model has no Conv2d/Linear layers.
    """
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be >= 1, got {num_rounds}")
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError(f"target_sparsity must be in [0, 1], got {target_sparsity}")

    # Keeping a fraction p of the remaining weights per round leaves p**num_rounds
    # overall, so p = (1 - target)**(1 / num_rounds). To reach 90% in 10 rounds,
    # each round removes 1 - (1 - 0.9)^(1/10) ≈ 20.6% of the remaining weights.
    per_round_amount = 1 - (1 - target_sparsity) ** (1 / num_rounds)

    parameters_to_prune = [
        (m, "weight") for m in model.modules()
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))
    ]
    if not parameters_to_prune:
        raise ValueError("model has no Conv2d/Linear layers to prune")

    if device is None:
        device = next(model.parameters()).device

    for round_idx in range(num_rounds):
        # Prune
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=per_round_amount,
        )

        # Measure on the masked `weight` attributes; the trainable
        # `weight_orig` parameters still contain the unmasked values.
        zeros = sum((getattr(m, n) == 0).sum().item() for m, n in parameters_to_prune)
        total = sum(getattr(m, n).numel() for m, n in parameters_to_prune)
        print(f"Round {round_idx + 1}: sparsity = {zeros / total:.1%}")

        # Fine-tune. A fresh optimizer/scheduler per round restarts the cosine
        # LR schedule for each fine-tuning phase.
        optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
        scheduler = CosineAnnealingLR(optimizer, T_max=finetune_epochs)

        for epoch in range(finetune_epochs):
            model.train()
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = torch.nn.functional.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()
            scheduler.step()

        # Evaluate (`evaluate` is assumed to return accuracy as a fraction)
        accuracy = evaluate(model, val_loader)
        print(f"  Accuracy after fine-tune: {accuracy:.2%}")

    # Make pruning permanent
    for module, name in parameters_to_prune:
        prune.remove(module, name)

    return model

Structured Pruning in PyTorch

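Removing entire channels/filters is what enables real speedups on standard hardware. Note that `ln_structured` only masks the filters; the speedup materializes once the model is rebuilt with physically smaller layers: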

def structured_channel_pruning(model, amount=0.3):
    """Mask whole output filters of every Conv2d, ranked by their L2 norm.

    Masks are attached in place via ``prune.ln_structured`` (``n=2, dim=0``),
    so each conv gains ``weight_orig`` / ``weight_mask``. Tensor shapes are
    unchanged. Returns the same ``model`` object that was passed in.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)]
    for conv in conv_layers:
        # dim=0 selects whole output channels (filters); n=2 ranks them by L2 norm
        prune.ln_structured(conv, name="weight", amount=amount, n=2, dim=0)
    return model

# After structured pruning, rebuild the model with smaller layers
def rebuild_pruned_model(model):
    """Create a new model with actually smaller layers.

    NOTE(review): this is an unfinished sketch. As written, it computes
    ``new_weight`` for each Conv2d and then discards it. It returns the freshly
    built ``new_model`` without copying anything into it.

    Assumptions it relies on:

    - ``create_model_architecture()`` is defined elsewhere. It must build a
      model whose ``named_modules()`` line up one-to-one with ``model``'s,
      because the ``zip`` pairs modules purely by position.
    - Every Conv2d in ``model`` has already been pruned. Otherwise
      ``weight_mask`` does not exist and accessing it raises AttributeError.

    TODO: a complete version must also slice the *input* channels of the
    following layer(s), and any BatchNorm parameters, to match the kept filters.
    """
    new_model = create_model_architecture()  # Your model factory

    # Positional pairing of old/new modules (see assumptions above).
    for (name, old_module), (_, new_module) in zip(
        model.named_modules(), new_model.named_modules()
    ):
        if isinstance(old_module, torch.nn.Conv2d):
            # Find which output channels survived
            mask = old_module.weight_mask
            # A filter survives if any entry of its (in, kh, kw) slice is unmasked.
            surviving_channels = mask.sum(dim=(1, 2, 3)) > 0

            # Copy surviving weights
            new_weight = old_module.weight_orig[surviving_channels]
            # ... rebuild layer with fewer channels
            # NOTE(review): new_weight is never written into new_module.

    return new_model

Using torch.nn.utils.parametrize for Custom Criteria

import torch.nn.utils.parametrize as parametrize

class TaylorPruningMask(torch.nn.Module):
    """Prune by Taylor expansion importance (|weight × gradient|).

    Intended to be registered as a parametrization, e.g.
    ``parametrize.register_parametrization(layer, "weight", TaylorPruningMask(layer.weight.shape))``.
    ``forward`` then returns the masked weight.

    Per pruning step: after ``loss.backward()``, call ``compute_importance``
    with the tensor that carries the gradient, then call ``update_mask``.
    """

    def __init__(self, weight_shape, amount=0.5):
        super().__init__()
        # 1 = keep, 0 = pruned. A buffer, so it follows .to() and is saved.
        self.register_buffer("mask", torch.ones(weight_shape))
        self.amount = amount  # fraction of entries zeroed by update_mask
        self.importance_scores = None

    def compute_importance(self, weight):
        """Store |weight * weight.grad|; no-op while ``weight.grad`` is None.

        Call after the backward pass. With parametrize, the gradient lives on
        ``layer.parametrizations.weight.original``.
        """
        if weight.grad is not None:
            # detach: the scores are plain statistics and must not keep the
            # autograd graph of `weight` alive.
            self.importance_scores = (weight.detach() * weight.grad).abs()

    @torch.no_grad()
    def update_mask(self):
        """Zero the mask at the ``amount`` fraction of least-important entries."""
        if self.importance_scores is None:
            return
        flat = self.importance_scores.flatten()
        k = int(flat.numel() * self.amount)
        new_mask = torch.ones(flat.numel(), device=flat.device)
        if k > 0:
            # topk (rather than a threshold) prunes exactly k entries, even with ties.
            prune_idx = torch.topk(flat, k, largest=False).indices
            new_mask[prune_idx] = 0.0
        self.mask = new_mask.view_as(self.importance_scores)

    def forward(self, weight):
        return weight * self.mask

TensorFlow Model Optimization Toolkit

TensorFlow provides pruning through the tensorflow-model-optimization package:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Load a trained Keras model
base_model = tf.keras.models.load_model("trained_model.h5")

# Configure pruning schedule
# Target sparsity ramps from initial_sparsity at begin_step to final_sparsity
# at end_step, with steps counted in training batches.
# NOTE(review): end_step=10000 is hard-coded. If 10 epochs over train_dataset
# run fewer than 10000 batches, the model never reaches final_sparsity.
# Confirm against the dataset size (epochs × batches per epoch).
pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.30,
        final_sparsity=0.90,
        begin_step=0,
        end_step=10000,
        frequency=100  # Update pruning mask every 100 steps
    )
}

# Apply pruning to the model
# Wraps the model's layers so low-magnitude weights are masked during training.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    base_model,
    **pruning_params
)

# Compile and train (fine-tune)
pruned_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# UpdatePruningStep advances the schedule's step counter (required when
# fitting a pruned model); PruningSummaries writes sparsity logs to log_dir.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir="logs/pruning")
]

# NOTE(review): train_dataset / val_dataset are not defined in this snippet.
pruned_model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset,
    callbacks=callbacks
)

# Strip pruning wrappers for deployment
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
final_model.save("pruned_final_model.h5")

Combining Pruning with Quantization in TF

import gzip

# Prune → Quantize → Deploy
pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Convert to TFLite with quantization. Optimize.DEFAULT without a
# representative dataset applies dynamic-range (8-bit weight) quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Pruned weights are still stored densely (explicit zeros) in both the .h5 and
# the .tflite file. Sparsity only shrinks the artifact once a general-purpose
# compressor is applied, so compare gzipped sizes to credit pruning fairly.
# The actual ratio depends on the model; measure rather than assume.
with open("trained_model.h5", "rb") as f:
    original_size = len(gzip.compress(f.read()))
pruned_quant_size = len(gzip.compress(tflite_model))
print(f"Compression (gzipped): {original_size / pruned_quant_size:.1f}×")

Lottery Ticket Hypothesis Implementation

def find_winning_ticket(model_factory, train_loader, val_loader,
                        prune_rate=0.2, rounds=15, device="cuda"):
    """Iterative magnitude pruning with weight rewinding (lottery ticket).

    Trains a freshly built model. Then, for each of ``rounds`` rounds, it:

    1. Globally prunes the ``prune_rate`` fraction of smallest-magnitude
       *surviving* weights.
    2. Rewinds all weights to their values at initialization.
    3. Re-applies the mask and retrains.

    Every parameter whose name contains "weight" is eligible for pruning.

    Relies on ``train_model(model, loader, epochs=..., mask=...)`` and
    ``evaluate(model, loader)`` defined elsewhere (not shown).
    NOTE(review): presumably ``train_model`` re-applies ``mask`` after each
    optimizer step; otherwise pruned weights regrow during retraining. Confirm.

    Args:
        model_factory: zero-argument callable returning a new, untrained model.
        train_loader: data passed to ``train_model``.
        val_loader: data passed to ``evaluate``.
        prune_rate: fraction of surviving weights removed per round.
        rounds: number of prune/rewind/retrain rounds.
        device: where the model is placed (default ``"cuda"``).

    Returns:
        ``(model, mask)``: the retrained sparse model, and a dict mapping each
        pruned parameter name to its 0/1 mask tensor.

    Raises:
        ValueError: if the model has no parameter with "weight" in its name.
    """
    # Snapshot the initialization *before* training; every round rewinds here.
    model = model_factory().to(device)
    initial_state = {k: v.clone() for k, v in model.state_dict().items()}
    train_model(model, train_loader, epochs=20)

    # One 0/1 mask per prunable tensor, starting fully dense.
    mask = {name: torch.ones_like(param)
            for name, param in model.named_parameters()
            if "weight" in name}
    if not mask:
        raise ValueError("model has no parameters with 'weight' in their name")

    for round_idx in range(rounds):
        # Identify and mask the smallest surviving weights (global ranking)
        _prune_smallest_weights(model, mask, prune_rate)

        alive = sum(m.sum().item() for m in mask.values())
        total = sum(m.numel() for m in mask.values())
        sparsity = 1 - alive / total
        print(f"Round {round_idx + 1}: sparsity = {sparsity:.1%}")

        # Rewind to initial weights (key insight of lottery ticket)
        model.load_state_dict(initial_state)

        # Apply mask and retrain
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in mask:
                    param.mul_(mask[name])

        train_model(model, train_loader, epochs=20, mask=mask)
        accuracy = evaluate(model, val_loader)
        print(f"  Accuracy: {accuracy:.2%}")

    return model, mask


def _prune_smallest_weights(model, mask, prune_rate):
    """Zero ``mask`` for the ``prune_rate`` fraction of smallest surviving |weights|.

    The ranking is global across all masked tensors; already-masked entries
    stay masked. ``mask`` is updated in place.
    """
    with torch.no_grad():
        alive = torch.cat([
            param[mask[name].bool()].abs().flatten()
            for name, param in model.named_parameters()
            if name in mask
        ])
        k = int(alive.numel() * prune_rate)
        if k == 0:
            return  # rate too small, or nothing left to prune
        # kthvalue is 1-indexed. `>` removes the k-th smallest surviving
        # magnitude together with everything below it (exactly k if distinct).
        threshold = torch.kthvalue(alive, k).values
        for name, param in model.named_parameters():
            if name in mask:
                mask[name] *= (param.abs() > threshold).to(mask[name].dtype)

Sparse Inference Acceleration

NVIDIA Sparse Tensor Cores (A100+)

# 2:4 structured sparsity — NVIDIA's hardware-native format
# Every group of 4 elements has at most 2 non-zeros

from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    sparsity_level=1.0,          # sparsify *every* block; 0.5 would leave half the blocks dense
    sparse_block_shape=(1, 4),   # blocks of 4 consecutive elements in a row
    zeros_per_block=2,           # 2 zeros per block of 4 → the 2:4 pattern
)

sparsifier.prepare(model, config=[
    {"tensor_fqn": f"{name}.weight"}
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
])

# Compute the initial 2:4 masks so fine-tuning starts from the sparse pattern
sparsifier.step()

# Train with sparsity constraint
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)
    sparsifier.step()  # Recompute masks from the updated weights

# Fold the masks into the weights and remove the sparsifier's parametrizations
sparsifier.squash_mask()
# The weights are now dense tensors with a 2:4 zero pattern. To actually use
# sparse tensor cores, convert them with torch.sparse.to_sparse_semi_structured.

DeepSparse Engine (CPU Sparse Inference)

# Neural Magic's DeepSparse runs sparse models fast on CPUs
from deepsparse import Engine

# Load a pruned ONNX model
# Engine configured for a fixed batch size of 1, using 4 CPU cores.
engine = Engine(
    model="pruned_model.onnx",
    batch_size=1,
    num_cores=4
)

# Run inference
# NOTE(review): input_data is not defined in this snippet. Presumably it is a
# numpy array matching the model input (e.g. float32, shape (1, 3, 224, 224)),
# consistent with input_shapes below. Confirm.
output = engine.run([input_data])

# Benchmarking
# NOTE(review): the import location, the signature, and the result keys used
# below ('items_per_sec', 'latency_ms' -> 'p50') are assumed. Verify them
# against the installed deepsparse version.
from deepsparse import benchmark_model

results = benchmark_model(
    "pruned_model.onnx",
    input_shapes=[[1, 3, 224, 224]],
    num_iterations=1000
)
print(f"Throughput: {results['items_per_sec']:.0f} items/sec")
print(f"Latency P50: {results['latency_ms']['p50']:.2f}ms")

Pruning Decision Matrix

| Scenario | Method | Target Sparsity | Notes |
|---|---|---|---|
| Mobile deployment | Structured channel pruning | 30-50% | Real speedup on standard hardware |
| Server with A100 GPU | 2:4 structured | 50% | Hardware-native support |
| Microcontroller | Unstructured + quantization | 90%+ | Use sparse inference engine |
| Research/exploration | Lottery ticket | 80-95% | Best accuracy but expensive |
| Quick compression | One-shot magnitude | 50-70% | Fast, minimal engineering |

The one thing to remember: Effective pruning in Python combines the right strategy (structured for standard hardware, unstructured for sparse-aware runtimes), iterative prune-retrain cycles with gradual sparsity schedules, and validation that the resulting model actually runs faster on your target deployment hardware — sparsity alone doesn’t guarantee speedup without matching runtime support.

Tags: python, machine-learning, model-optimization

See Also