Model Pruning Techniques in Python — Deep Dive
PyTorch Unstructured Pruning
PyTorch provides pruning utilities in torch.nn.utils.prune:
Basic Magnitude Pruning
import torch
import torch.nn.utils.prune as prune
import torchvision

# torchvision >= 0.13 uses `weights=`; the older `pretrained=True` flag is deprecated
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
conv = model.layer1[0].conv1

# Prune 50% of weights in a specific layer (smallest L1 magnitude)
prune.l1_unstructured(conv, name="weight", amount=0.5)

# Check: `weight` is now recomputed as weight_orig * weight_mask by a forward pre-hook
print(conv.weight_orig.shape)  # Original (unpruned) weights, stored as a Parameter
print(conv.weight_mask.shape)  # Binary mask, stored as a buffer
sparsity = 1 - conv.weight_mask.sum().item() / conv.weight_mask.numel()
print(f"Sparsity: {100 * sparsity:.1f}%")
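To see the reparametrization without downloading ResNet-50, here is a minimal sketch on a small standalone Conv2d (its size is an arbitrary assumption). After `l1_unstructured`, `weight` equals `weight_orig * weight_mask`, exactly round(0.5 × 216) = 108 entries are masked, and the module's parameter list holds `weight_orig` instead of `weight`.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3)  # 8 * 3 * 3 * 3 = 216 weights (arbitrary size)

prune.l1_unstructured(conv, name="weight", amount=0.5)

# `weight` is derived: the original values multiplied by the binary mask buffer
reconstructed = conv.weight_orig * conv.weight_mask
print(torch.equal(conv.weight, reconstructed))

# The mask zeroes the 108 smallest-magnitude entries
num_pruned = int((conv.weight_mask == 0).sum())
print(num_pruned, conv.weight.numel())

# The trainable parameter is now weight_orig; weight is a plain attribute
param_names = [name for name, _ in conv.named_parameters()]
print(param_names)
```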
Global Unstructured Pruning
Prune across all layers simultaneously, letting the algorithm remove the globally least important weights:
parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        parameters_to_prune.append((module, "weight"))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.8  # Remove 80% of all weights globally
)

# Verify global sparsity
total_zeros = 0
total_params = 0
for module, _ in parameters_to_prune:
    total_zeros += (module.weight == 0).sum().item()
    total_params += module.weight.numel()
print(f"Global sparsity: {100 * total_zeros / total_params:.1f}%")
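A single global threshold does not prune every layer equally: layers whose weights are small in magnitude lose more. The sketch below uses a toy two-layer model (an assumption, with the last layer's weights artificially scaled up) to show the overall 80% target being met while per-layer sparsity differs sharply.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Toy model for illustration: both layers start with similar weight scales
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
with torch.no_grad():
    model[2].weight.mul_(10)  # make the last layer's weights ~10x larger

params = [(m, "weight") for m in model if isinstance(m, nn.Linear)]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.8)

# Per-layer sparsity differs even though the global target is met
per_layer = [float((m.weight == 0).float().mean()) for m, _ in params]
total_zeros = sum(int((m.weight == 0).sum()) for m, _ in params)
total = sum(m.weight.numel() for m, _ in params)
print("per-layer sparsity:", [f"{s:.1%}" for s in per_layer])
print(f"global sparsity: {total_zeros / total:.1%}")
```

Global pruning is convenient, but check per-layer sparsity: a layer pruned almost entirely can become a bottleneck.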
Making Pruning Permanent
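By default, PyTorch stores pruning as a reparametrization: the original tensor lives in weight_orig, the mask in a weight_mask buffer, and weight is recomputed from the two on every forward pass. prune.remove bakes the mask into weight and drops the extra tensors. The weight is still a dense tensor, so the zeros keep occupying memory; storage only shrinks if you convert to a sparse format or compress the saved file: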
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)
# Now module.weight is a regular parameter with the zeros baked in
# Save the model (still dense tensors that happen to contain zeros)
torch.save(model.state_dict(), "pruned_model.pth")
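A minimal sketch of what `prune.remove` changes in the state dict, and what a sparse format buys you, using a toy 512×512 Linear layer (an arbitrary assumption) pruned to 90%. PyTorch's COO format stores two int64 indices plus one value per non-zero, so at 90% sparsity it is only about half the dense size; at lower sparsity it can be larger than dense.

```python
import io

import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(512, 512)  # arbitrary size for illustration
prune.l1_unstructured(layer, name="weight", amount=0.9)
print(sorted(layer.state_dict().keys()))  # bias, weight_mask, weight_orig

prune.remove(layer, "weight")
print(sorted(layer.state_dict().keys()))  # bias, weight


def saved_bytes(obj):
    """Size of the torch.save output, measured in memory."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return len(buffer.getvalue())


dense = layer.weight.detach()
sparse = dense.to_sparse()  # COO: only non-zero values plus their indices
print(saved_bytes(dense), saved_bytes(sparse))
```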
Iterative Pruning with Fine-Tuning
Pruning gradually and fine-tuning between rounds typically preserves much more accuracy than removing the same fraction of weights in one shot:
import torch
import torch.nn.utils.prune as prune
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

PRUNABLE = (torch.nn.Conv2d, torch.nn.Linear)

def get_sparsity(model):
    """Fraction of zero entries across Conv2d/Linear weights.
    Reads module.weight (the masked tensor), not weight_orig."""
    zeros = 0
    total = 0
    for module in model.modules():
        if isinstance(module, PRUNABLE):
            zeros += (module.weight == 0).sum().item()
            total += module.weight.numel()
    return zeros / total

@torch.no_grad()
def evaluate(model, loader):
    """Top-1 accuracy on a classification loader."""
    device = next(model.parameters()).device
    model.eval()
    correct = total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        correct += (model(inputs).argmax(dim=1) == targets).sum().item()
        total += targets.numel()
    return correct / total

def iterative_prune(model, train_loader, val_loader, target_sparsity=0.9,
                    num_rounds=10, finetune_epochs=5):
    """Iterative magnitude pruning with fine-tuning."""
    device = next(model.parameters()).device
    # Per-round pruning rate: PyTorch applies `amount` to the weights that are
    # still unpruned, so to reach 90% in 10 rounds each round removes
    # 1 - (1 - 0.9)^(1/10) ≈ 20.6% of the remaining weights
    per_round_amount = 1 - (1 - target_sparsity) ** (1 / num_rounds)
    parameters_to_prune = [
        (m, "weight") for m in model.modules() if isinstance(m, PRUNABLE)
    ]
    for round_idx in range(num_rounds):
        # Prune
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=per_round_amount
        )
        current_sparsity = get_sparsity(model)
        print(f"Round {round_idx + 1}: sparsity = {current_sparsity:.1%}")
        # Fine-tune (the mask keeps pruned weights at zero in every forward pass)
        optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
        scheduler = CosineAnnealingLR(optimizer, T_max=finetune_epochs)
        for epoch in range(finetune_epochs):
            model.train()
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = torch.nn.functional.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()
            scheduler.step()  # Once per epoch, matching T_max=finetune_epochs
        # Evaluate
        accuracy = evaluate(model, val_loader)
        print(f"  Accuracy after fine-tune: {accuracy:.2%}")
    # Make pruning permanent
    for module, name in parameters_to_prune:
        prune.remove(module, name)
    return model
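The per-round rate in `iterative_prune` comes from compounding: each round keeps (1 - p) of the surviving weights, so after n rounds the surviving fraction is (1 - p)^n. This plain-Python sketch, using the same 90% / 10-round targets, prints the cumulative sparsity after each round; the early rounds remove the most weights in absolute terms and later rounds progressively fewer.

```python
target_sparsity = 0.9
num_rounds = 10
per_round_amount = 1 - (1 - target_sparsity) ** (1 / num_rounds)

remaining = 1.0  # fraction of weights still alive
schedule = []
for round_idx in range(num_rounds):
    remaining *= 1 - per_round_amount  # prune a fixed fraction of what is left
    schedule.append(1 - remaining)     # cumulative sparsity after this round

print(f"per-round amount: {per_round_amount:.1%}")
print(" ".join(f"{s:.1%}" for s in schedule))
```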
Structured Pruning in PyTorch
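Structured pruning removes entire channels or filters, which can speed up inference on standard hardware because what remains is simply a smaller dense layer. PyTorch's ln_structured only zeroes them, though, so a rebuild step is needed before any speedup appears: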
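import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def structured_channel_pruning(model, amount=0.3):
    """Zero out whole Conv2d output filters, ranked by the L2 norm of each filter."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(
                module,
                name="weight",
                amount=amount,
                n=2,  # L2 norm
                dim=0  # Prune along output channel dimension
            )
    return model

# After structured pruning, rebuild the layers with fewer channels.
# Sketch for two directly connected convs; assumes groups=1, no BatchNorm or
# residual connection in between, and bias=False on `conv` (a pruned channel
# with a bias would still emit a constant that this sketch drops).
def shrink_conv_pair(conv, next_conv):
    """Return new (conv, next_conv) with the pruned channels physically removed."""
    # Find which output channels survived
    keep = conv.weight_mask.sum(dim=(1, 2, 3)) > 0
    n_keep = int(keep.sum())
    new_conv = nn.Conv2d(conv.in_channels, n_keep, conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         dilation=conv.dilation, bias=conv.bias is not None)
    new_next = nn.Conv2d(n_keep, next_conv.out_channels, next_conv.kernel_size,
                         stride=next_conv.stride, padding=next_conv.padding,
                         dilation=next_conv.dilation, bias=next_conv.bias is not None)
    with torch.no_grad():
        # Copy surviving filters and the matching input slices of the next layer
        new_conv.weight.copy_(conv.weight[keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[keep])
        new_next.weight.copy_(next_conv.weight[:, keep])
        if next_conv.bias is not None:
            new_next.bias.copy_(next_conv.bias)
    return new_conv, new_next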
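Why the rebuild step matters: ln_structured zeroes whole filters but leaves tensor shapes unchanged. In this toy sketch (a bias-free Conv2d with arbitrary sizes, both assumptions), round(0.3 × 16) = 5 filters are zeroed, and the corresponding output channels come out exactly zero while still being computed at full size.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)  # toy layer
prune.ln_structured(conv, name="weight", amount=0.3, n=2, dim=0)

# Which output filters were zeroed entirely?
dead = conv.weight.detach().abs().sum(dim=(1, 2, 3)) == 0
print(f"{int(dead.sum())} of {conv.out_channels} filters pruned")

# The output still has all 16 channels; the pruned ones are all zeros
x = torch.randn(2, 8, 32, 32)
y = conv(x)
print(tuple(y.shape))
print(bool((y[:, dead] == 0).all()))
```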
Using torch.nn.utils.parametrize for Custom Criteria
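import torch
import torch.nn.utils.parametrize as parametrize

class TaylorPruningMask(torch.nn.Module):
    """Prune by first-order Taylor importance |weight × gradient|."""
    def __init__(self, weight_shape, amount=0.5):
        super().__init__()
        self.register_buffer("mask", torch.ones(weight_shape))
        self.amount = amount
        self.importance_scores = None

    @torch.no_grad()
    def compute_importance(self, weight):
        """Call after a backward pass, passing the unmasked tensor
        (module.parametrizations.weight.original)."""
        if weight.grad is not None:
            self.importance_scores = (weight * weight.grad).abs()

    @torch.no_grad()
    def update_mask(self):
        if self.importance_scores is None:
            return
        flat = self.importance_scores.flatten()
        k = int(flat.numel() * self.amount)
        if k < 1:
            return
        # torch.kthvalue returns the k-th smallest value (1-indexed)
        threshold = torch.kthvalue(flat, k).values
        # Keep scores strictly above the threshold (prunes the k smallest)
        self.mask.copy_((self.importance_scores > threshold).float())

    def forward(self, weight):
        return weight * self.mask

# Usage: register on a layer, backprop once, then recompute the mask
layer = torch.nn.Linear(128, 64)
taylor = TaylorPruningMask(layer.weight.shape, amount=0.5)
parametrize.register_parametrization(layer, "weight", taylor)
layer(torch.randn(32, 128)).pow(2).mean().backward()
taylor.compute_importance(layer.parametrizations.weight.original)
taylor.update_mask()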
TensorFlow Model Optimization Toolkit
TensorFlow provides pruning through the tensorflow-model-optimization package:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
# Note: tfmot works with Keras 2 (tf.keras) models. On TF 2.16+, where Keras 3 is
# the default, you will likely need the tf_keras package and TF_USE_LEGACY_KERAS=1.

# Load a trained Keras model
base_model = tf.keras.models.load_model("trained_model.h5")

# Configure pruning schedule
pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.30,
        final_sparsity=0.90,
        begin_step=0,
        end_step=10000,  # Keep within total training steps (epochs × steps per epoch)
        frequency=100  # Update pruning mask every 100 steps
    )
}

# Apply pruning to the model
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    base_model,
    **pruning_params
)

# Compile and train (fine-tune)
pruned_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),  # Required: advances the pruning step
    tfmot.sparsity.keras.PruningSummaries(log_dir="logs/pruning")
]
# train_dataset and val_dataset are your own tf.data pipelines (not defined here)
pruned_model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset,
    callbacks=callbacks
)

# Strip pruning wrappers for deployment
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
final_model.save("pruned_final_model.h5")
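To see how the schedule ramps up, here is a plain-Python re-implementation of the polynomial-decay formula as we understand TF-MOT's PolynomialDecay (power assumed to default to 3); treat it as a sketch and check it against the library source before relying on exact values. It ignores `frequency`, which only controls how often the mask is actually updated. With the parameters above, sparsity reaches about 65% a quarter of the way through, so most pruning happens early.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.30, final_sparsity=0.90,
                              begin_step=0, end_step=10000, power=3):
    """Target sparsity at `step` under a polynomial-decay schedule."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


for step in (0, 2500, 5000, 7500, 10000):
    print(step, f"{polynomial_decay_sparsity(step):.1%}")
```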
Combining Pruning with Quantization in TF
# Prune → Quantize → Deploy
import gzip

stripped_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Convert to TFLite with post-training dynamic-range quantization
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# By default TFLite stores weights densely, so pruning's zeros only shrink the file
# once it is compressed. Compare gzipped sizes; the ratio depends on model and sparsity.
with open("trained_model.h5", "rb") as f:
    original_size = len(gzip.compress(f.read()))
pruned_quant_size = len(gzip.compress(tflite_model))
print(f"Compression: {original_size / pruned_quant_size:.1f}×")
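Why compressed size is the fair comparison: a pruned float tensor occupies exactly as many raw bytes as the dense one, and the zeros only pay off once a general-purpose compressor squeezes them out. This NumPy sketch uses one million random synthetic weights (an assumption, not a real model) and magnitude-prunes about 90% of them.

```python
import gzip

import numpy as np

rng = np.random.default_rng(0)
dense = rng.standard_normal(1_000_000).astype(np.float32)

pruned = dense.copy()
threshold = np.quantile(np.abs(pruned), 0.9)
pruned[np.abs(pruned) < threshold] = 0.0  # magnitude-prune ~90% of entries

for name, weights in (("dense", dense), ("90% pruned", pruned)):
    raw = weights.tobytes()
    print(f"{name}: raw {len(raw)} bytes, gzipped {len(gzip.compress(raw))} bytes")
```

Both arrays are 4,000,000 raw bytes; only the gzipped size of the pruned array drops substantially.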
Lottery Ticket Hypothesis Implementation
import torch

def find_winning_ticket(model_factory, train_loader, val_loader,
                        prune_rate=0.2, rounds=15):
    """Iterative magnitude pruning with rewinding to the original initialization.

    model_factory, train_model and evaluate are your own helpers; train_model
    must re-apply `mask` after every optimizer step so pruned weights stay zero.
    """
    # Save the initialization, then train the full model
    model = model_factory().cuda()
    initial_state = {k: v.clone() for k, v in model.state_dict().items()}
    train_model(model, train_loader, epochs=20)
    # Prune only conv/linear weight tensors (dim > 1), not BatchNorm scales or biases
    mask = {name: torch.ones_like(param)
            for name, param in model.named_parameters()
            if "weight" in name and param.dim() > 1}
    for round_idx in range(rounds):
        # Find the magnitude threshold among the weights that are still alive
        alive = torch.cat([
            param.detach()[mask[name].bool()].abs()
            for name, param in model.named_parameters() if name in mask
        ])
        k = max(1, int(alive.numel() * prune_rate))
        threshold = torch.kthvalue(alive, k).values
        # Update mask: drop alive weights at or below the threshold
        for name, param in model.named_parameters():
            if name in mask:
                mask[name] *= (param.detach().abs() > threshold).float()
        remaining = sum(m.sum().item() for m in mask.values())
        total = sum(m.numel() for m in mask.values())
        print(f"Round {round_idx + 1}: sparsity = {1 - remaining / total:.1%}")
        # Rewind to initial weights (key insight of lottery ticket)
        model.load_state_dict(initial_state)
        # Apply mask and retrain
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in mask:
                    param.mul_(mask[name])
        train_model(model, train_loader, epochs=20, mask=mask)
        accuracy = evaluate(model, val_loader)
        print(f"  Accuracy: {accuracy:.2%}")
    return model, mask
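`find_winning_ticket` relies on a `train_model` helper that honors `mask`. One possible way to do that, sketched here as a hypothetical `masked_train_step` on a toy model with random masks (both assumptions), is to multiply the mask back in after every optimizer step. Pruned weights are zero, but their gradients generally are not, so without this step the optimizer would revive them.

```python
import torch
import torch.nn as nn


def masked_train_step(model, mask, optimizer, inputs, targets):
    """One training step that keeps pruned weights at exactly zero.

    `mask` maps parameter names to 0/1 tensors, as built in find_winning_ticket.
    """
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # The update moves zeroed weights (their gradients are nonzero); re-zero them
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in mask:
                param.mul_(mask[name])
    return loss.item()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
# Random ~80% sparse masks on the weight matrices only (illustrative)
mask = {name: (torch.rand_like(p) > 0.8).float()
        for name, p in model.named_parameters() if p.dim() > 1}
with torch.no_grad():
    for name, p in model.named_parameters():
        if name in mask:
            p.mul_(mask[name])

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for _ in range(5):
    x, y = torch.randn(16, 20), torch.randint(0, 3, (16,))
    masked_train_step(model, mask, optimizer, x, y)
```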
Sparse Inference Acceleration
NVIDIA Sparse Tensor Cores (A100+)
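# 2:4 semi-structured sparsity: NVIDIA's hardware-native format (Ampere and newer)
# Every group of 4 consecutive elements has at most 2 non-zeros
import torch
from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    sparsity_level=1.0,  # Sparsify every block (0.5 would leave half the blocks dense)
    sparse_block_shape=(1, 4),  # Blocks of 4 consecutive elements...
    zeros_per_block=2  # ...with 2 zeros each, i.e. the 2:4 pattern
)
sparsifier.prepare(model, config=[
    {"tensor_fqn": f"{name}.weight"}
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
])

# Train with sparsity constraint (num_epochs and train_one_epoch are your own loop)
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)
    sparsifier.step()  # Recompute the 2:4 masks from current weight norms
sparsifier.squash_mask()  # Bake masks into the weights and remove parametrizations
# squash_mask leaves dense tensors containing zeros; GPU speedups additionally need
# a sparse kernel such as torch.sparse.to_sparse_semi_structured (PyTorch 2.1+,
# fp16/bf16 CUDA weights, subject to shape constraints)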
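For intuition about the pattern the sparsifier enforces, here is a plain-PyTorch sketch (a standalone helper of our own, not part of torch.ao.pruning) that builds a one-shot magnitude 2:4 mask for a 2-D weight: within each group of 4 along the input dimension, only the 2 largest-magnitude entries survive. WeightNormSparsifier produces the same kind of pattern, but refines it during training.

```python
import torch


def two_four_mask(weight: torch.Tensor) -> torch.Tensor:
    """Magnitude-based 2:4 mask: in every group of 4 consecutive elements along
    the last dimension, keep the 2 largest-magnitude entries."""
    out_features, in_features = weight.shape
    assert in_features % 4 == 0, "last dimension must be divisible by 4"
    groups = weight.abs().reshape(out_features, in_features // 4, 4)
    keep = groups.topk(2, dim=-1).indices  # positions of the 2 largest per group
    mask = torch.zeros_like(groups)
    mask.scatter_(-1, keep, 1.0)
    return mask.reshape(out_features, in_features)


torch.manual_seed(0)
w = torch.randn(8, 16)
mask = two_four_mask(w)
sparse_w = w * mask
print(mask.reshape(8, 4, 4).sum(dim=-1))  # kept entries per group of 4
```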
DeepSparse Engine (CPU Sparse Inference)
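# Neural Magic's DeepSparse runs sparse models fast on CPUs
import time

import numpy as np
from deepsparse import Engine

# Load a pruned ONNX model
engine = Engine(
    model="pruned_model.onnx",
    batch_size=1,
    num_cores=4
)

# Run inference (random input here; use real preprocessed data in practice)
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)
output = engine.run([input_data])

# Benchmarking: simple wall-clock timing of engine.run
# (DeepSparse also ships a `deepsparse.benchmark` command-line tool)
for _ in range(10):
    engine.run([input_data])  # Warm-up
latencies_ms = []
for _ in range(1000):
    start = time.perf_counter()
    engine.run([input_data])
    latencies_ms.append((time.perf_counter() - start) * 1000)
print(f"Throughput: {1000 * len(latencies_ms) / sum(latencies_ms):.0f} items/sec")
print(f"Latency P50: {np.percentile(latencies_ms, 50):.2f}ms")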
Pruning Decision Matrix
| Scenario | Method | Target Sparsity | Notes |
|---|---|---|---|
| Mobile deployment | Structured channel pruning | 30-50% | Real speedup on standard hardware |
| Server with A100 GPU | 2:4 structured | 50% | Hardware-native support |
| Microcontroller | Unstructured + quantization | 90%+ | Use sparse inference engine |
| Research/exploration | Lottery ticket | 80-95% | High accuracy at high sparsity, but expensive (repeated full training runs) |
| Quick compression | One-shot magnitude | 50-70% | Fast, minimal engineering |
The one thing to remember: Effective pruning in Python combines the right strategy (structured for standard hardware, unstructured for sparse-aware runtimes), iterative prune-retrain cycles with gradual sparsity schedules, and validation that the resulting model actually runs faster on your target deployment hardware — sparsity alone doesn’t guarantee speedup without matching runtime support.
See Also
- Python Hyperparameter Tuning: Learn why adjusting the dials on a computer's learning recipe makes predictions way better.
- Python Knowledge Distillation: How a big expert AI teaches a tiny student AI to be almost as smart, like a professor writing a cheat sheet for an exam.
- Python Model Compression Methods: All the ways Python developers shrink massive AI models to fit on phones and tiny devices, like packing for a trip with a carry-on bag.
- Python Neural Architecture Search: How AI designs its own brain structure, like a robot architect building the perfect house by trying thousands of floor plans.
- Python Pytorch Quantization: How shrinking numbers inside an AI model makes it run faster on phones and cheaper servers without losing much accuracy.