PyTorch Quantization — Deep Dive
PyTorch Quantization APIs: Legacy vs Modern
PyTorch has two quantization systems:
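- Eager mode (torch.ao.quantization, also importable under the older torch.quantization alias): the original API, where you prepare and wrap modules manually
- PT2E (torch.ao.quantization.quantize_pt2e): the modern graph-based API, which captures the model with torch.export and can lower the result through torch.compile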
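PT2E is the recommended path forward. It integrates with the compiler stack, lets each backend supply its own quantization rules through a Quantizer object, and operates on the captured graph, so it avoids the manual QuantStub/DeQuantStub insertion and module fusion that eager mode requires.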
Dynamic Quantization (Eager Mode)
The quickest way to quantize — one function call, no calibration:
```python
import os

import torch
from torch.ao.quantization import quantize_dynamic

model = MyTransformerModel()  # placeholder for your FP32 model
model.eval()

# Quantize Linear and LSTM layers to INT8
quantized_model = quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear, torch.nn.LSTM},
    dtype=torch.qint8,
)

# Compare sizes
torch.save(model.state_dict(), "/tmp/fp32.pt")
torch.save(quantized_model.state_dict(), "/tmp/int8.pt")
print(f"FP32: {os.path.getsize('/tmp/fp32.pt') / 1e6:.1f} MB")
print(f"INT8: {os.path.getsize('/tmp/int8.pt') / 1e6:.1f} MB")
```
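Dynamic quantization converts weights to INT8 ahead of time but keeps activations in FP32 between operations: each quantized op measures its input's range at runtime, quantizes the input on the fly, runs the integer computation, and returns a floating-point result. This avoids calibration entirely, but the runtime range computation and repeated quantize/dequantize steps leave performance on the table compared to static quantization.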
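To make that data flow concrete, here is a minimal pure-Python sketch of one dynamically quantized linear layer. It assumes symmetric per-tensor INT8 for both weights and activations, which is a simplification: PyTorch's actual dynamic kernels choose their own activation dtype and zero point. The helper names (symmetric_scale, quantize, dynamic_quant_linear) are hypothetical.

```python
def symmetric_scale(values, qmax=127):
    """Scale that maps the largest |value| onto qmax (symmetric INT8)."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def quantize(values, scale, qmin=-128, qmax=127):
    """Round to the nearest integer step and clamp to the INT8 range."""
    return [max(qmin, min(qmax, round(v / scale))) for v in values]


def dynamic_quant_linear(x, w_q, w_scale):
    """Compute W @ x with W already quantized; x is quantized per call."""
    x_scale = symmetric_scale(x)  # activation range measured at runtime
    x_q = quantize(x, x_scale)
    out = []
    for row in w_q:
        acc = sum(wi * xi for wi, xi in zip(row, x_q))  # integer accumulate
        out.append(acc * w_scale * x_scale)  # dequantize result to float
    return out


# Weights are quantized once, ahead of time
W = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
w_scale = symmetric_scale([v for row in W for v in row])
W_q = [quantize(row, w_scale) for row in W]

x = [1.0, 2.0, -3.0]
y_ref = [sum(wi * xi for wi, xi in zip(row, x)) for row in W]  # FP32 result
y_q = dynamic_quant_linear(x, W_q, w_scale)                     # close to y_ref
```

The weight scale is fixed once, while x_scale is recomputed on every call; that per-call range computation is the overhead static quantization removes.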
Static Quantization with Calibration
Static quantization pre-computes activation ranges using a calibration dataset:
```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

# MyCNNModel, sample_input and calibration_loader are placeholders
model = MyCNNModel()
model.eval()

# Step 1: Configure quantization
# ("qnnpack" for ARM; also set torch.backends.quantized.engine = "qnnpack")
qconfig_mapping = get_default_qconfig_mapping("x86")

# Step 2: Prepare: trace the model (FX graph mode) and insert observers
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=(sample_input,))

# Step 3: Calibrate: run representative data through the model
with torch.no_grad():
    for batch in calibration_loader:
        prepared_model(batch)

# Step 4: Convert: replace observed FP32 ops with INT8 ops
quantized_model = convert_fx(prepared_model)
```
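Conceptually, calibration only has to track each activation's range across batches and then turn that range into a scale and zero point. The sketch below mimics what a min/max observer does, under stated assumptions: a hypothetical MinMaxCalibrator class, plain Python lists instead of tensors, and an unsigned 8-bit target range of 0 to 255.

```python
class MinMaxCalibrator:
    """Hypothetical, simplified observer: tracks the running min/max of
    every batch it sees, then derives affine quantization parameters."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self, qmin=0, qmax=255):
        # Widen the range to include 0 so that zero is exactly representable
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = qmin - round(lo / scale)
        return scale, max(qmin, min(qmax, zero_point))


# Stand-in for running calibration batches through one activation
calib = MinMaxCalibrator()
for batch in [[0.1, 0.5, -0.2], [1.2, -0.8], [0.3, 2.0]]:
    calib.observe(batch)

scale, zp = calib.qparams()  # range [-0.8, 2.0]: scale = 2.8 / 255, zp = 73
```

Forcing the range to include 0 guarantees that a real zero (for example, from ReLU or padding) maps to an exact integer.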
Choosing Calibration Strategy
The observer determines how activation ranges are computed:
```python
from torch.ao.quantization.observer import (
    MinMaxObserver,               # Simple min/max of observed values
    MovingAverageMinMaxObserver,  # Smoothed min/max (better for varying ranges)
    HistogramObserver,            # Histogram-based, minimizes quantization error
    PerChannelMinMaxObserver,     # Per-channel for weights (higher accuracy)
)
```
HistogramObserver typically gives the best accuracy because it optimizes the scale/zero-point to minimize overall quantization error rather than just covering the full range. MinMaxObserver is faster but vulnerable to outliers — a single extreme value stretches the range, reducing precision for normal values.
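The difference between MinMaxObserver and MovingAverageMinMaxObserver comes down to their update rules. The following simplified sketch, operating on plain lists with an assumed smoothing constant c=0.01 and hypothetical function names, feeds both rules 100 calibration batches, one of which contains an outlier.

```python
def minmax_update(state, batch):
    """MinMax rule: keep the most extreme values ever seen."""
    lo, hi = min(batch), max(batch)
    if state is None:
        return lo, hi
    return min(state[0], lo), max(state[1], hi)


def moving_average_update(state, batch, c=0.01):
    """Moving-average rule: nudge the range a fraction c toward each batch."""
    lo, hi = min(batch), max(batch)
    if state is None:
        return lo, hi
    return state[0] + c * (lo - state[0]), state[1] + c * (hi - state[1])


# 99 ordinary batches in [-1, 1]; batch 50 contains an outlier of 50.0
batches = [[-1.0, 0.0, 1.0]] * 99
batches.insert(50, [-1.0, 0.0, 50.0])

mm, ma = None, None
for b in batches:
    mm = minmax_update(mm, b)
    ma = moving_average_update(ma, b)

# Step size (scale) each range implies for 256 quantization levels
mm_scale = (mm[1] - mm[0]) / 255
ma_scale = (ma[1] - ma[0]) / 255
```

The min/max range stays stretched to 50 after a single outlier batch, giving a step size more than 20 times coarser, while the moving average absorbs only a fraction of the spike and decays back toward the typical range.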
Quantization-Aware Training (QAT)
QAT inserts fake quantization nodes during training that simulate INT8 rounding:
```python
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

# MyCNNModel, sample_input and train_loader are placeholders;
# QAT usually starts from an already-trained FP32 model
model = MyCNNModel()
model.train()  # Must be in training mode for QAT

qconfig_mapping = get_default_qat_qconfig_mapping("x86")
prepared_model = prepare_qat_fx(model, qconfig_mapping,
                                example_inputs=(sample_input,))

# Fine-tune with fake quantization active
criterion = torch.nn.CrossEntropyLoss()  # assumes a classification task
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1e-4, momentum=0.9)
for epoch in range(5):  # Typically 5-10 epochs is enough
    for data, target in train_loader:
        optimizer.zero_grad()
        output = prepared_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to actual quantized model
prepared_model.eval()
quantized_model = convert_fx(prepared_model)
```
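Each fake quantization node computes x_fq = (clamp(round(x / scale) + zero_point, q_min, q_max) - zero_point) * scale. In the forward pass, values are rounded and clamped to the INT8 grid but remain floating point. In the backward pass, gradients flow through the Straight-Through Estimator (STE), which treats the rounding as the identity; values that were clamped outside [q_min, q_max] receive zero gradient.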
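The forward and backward rules above can be sketched for a single scalar in plain Python. This is a simplified model of a fake-quantize node, not PyTorch's implementation; the function names and the INT8 range of -128 to 127 are assumptions for illustration.

```python
def fake_quant_forward(x, scale, zero_point, qmin=-128, qmax=127):
    """Snap x onto the INT8 grid (round + clamp), then map back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def fake_quant_backward(x, grad_out, scale, zero_point, qmin=-128, qmax=127):
    """Straight-through estimator: treat rounding as the identity, but
    block the gradient where the forward pass clamped the value."""
    q = round(x / scale) + zero_point
    return grad_out if qmin <= q <= qmax else 0.0


scale, zp = 0.1, 0
xs = [0.234, -1.07, 20.0]
fq = [fake_quant_forward(x, scale, zp) for x in xs]           # ~[0.2, -1.1, 12.7]
grads = [fake_quant_backward(x, 1.0, scale, zp) for x in xs]  # [1.0, 1.0, 0.0]
```

The third input lands outside the representable range, so it is clamped to 12.7 in the forward pass and receives zero gradient in the backward pass; the other two pass their gradient through unchanged even though rounding moved their values.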
INT4 Weight Quantization for LLMs
For large language models, 4-bit weight quantization is standard practice:
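```python
import torch
from torchao.quantization import int4_weight_only, quantize_

# load_llm_model() and input_ids are placeholders; the default INT4
# kernels expect a bfloat16 model on a CUDA device
model = load_llm_model().to(device="cuda", dtype=torch.bfloat16)
model.eval()

# Apply INT4 weight-only quantization in place
# (newer torchao releases spell this config Int4WeightOnlyConfig(group_size=128))
quantize_(model, int4_weight_only(group_size=128))

# Inference: activations stay in BF16
with torch.no_grad():
    output = model(input_ids)
```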
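Group quantization (group_size=128) splits each row of a weight matrix into groups of 128 consecutive elements, each with its own scale factor (and, in asymmetric schemes, its own zero point). This substantially improves accuracy over per-tensor or per-channel quantization because an outlier weight only stretches the range of its own group instead of the whole tensor or row. The cost is extra metadata: smaller groups store more scales.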
| Configuration | Approx. Model Size (7B params) | Perplexity Increase vs FP16 |
|---|---|---|
| FP16 (baseline) | 14 GB | 0 (reference) |
| INT8 weight-only | ~7 GB | +0.01 to +0.05 |
| INT4 group=128 | ~3.5 GB | +0.1 to +0.3 |
| INT4 group=32 | ~4 GB | +0.05 to +0.15 |
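A small experiment shows why group size matters. This pure-Python sketch uses hypothetical helper names and a simplified asymmetric 4-bit scheme that stores a min offset and a scale per group (torchao's exact format may differ). It quantizes one synthetic 512-element weight row containing a single outlier, first as one group and then in groups of 128 and 32. The bits_per_weight helper assumes 16 bits each for the scale and offset of every group.

```python
def fake_quant_int4_groups(weights, group_size):
    """Simplified asymmetric 4-bit quantization: each group of `group_size`
    consecutive weights gets its own offset (min) and scale. Returns the
    dequantized weights so the error can be measured."""
    out = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # 16 levels: 0..15
        for w in group:
            q = max(0, min(15, round((w - lo) / scale)))
            out.append(lo + q * scale)
    return out


def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def bits_per_weight(group_size, metadata_bits=32):
    """4-bit codes plus one scale and one offset per group (assumed 16 bits each)."""
    return 4 + metadata_bits / group_size


# Synthetic weight row: 512 small values in [-0.05, 0.05] plus one outlier
row = [((i * 37) % 101 - 50) / 1000 for i in range(512)]
row[7] = 1.0

errors = {
    g: mean_abs_error(row, fake_quant_int4_groups(row, g))
    for g in (512, 128, 32)  # 512 = a single group for the whole row
}
```

Only the group that contains the outlier pays for its wide range, so the error falls as groups shrink, while metadata overhead rises from 4.25 to 5 bits per weight in this accounting.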
Mixed-Precision Quantization
Not all layers respond equally to quantization. Sensitive layers (first/last layers, attention projections) can stay in higher precision:
```python
from torch.ao.quantization import QConfigMapping, get_default_qconfig

# A qconfig of None leaves that module unquantized (FP32)
qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("x86"))
    .set_module_name("encoder.layer.0", None)  # Keep first layer FP32
    .set_module_name("classifier", None)       # Keep classifier FP32
)
```
Sensitivity analysis can automate this choice: quantize one layer at a time, measure the accuracy impact on a validation set, and keep the layers with the largest drops in higher precision.
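A per-layer sensitivity scan can be sketched without any framework. The toy model below is an assumption made for illustration: each hypothetical layer multiplies its input elementwise by its weights, layers are fake-quantized to 4 bits one at a time, and output MSE against the FP32 model stands in for the accuracy drop you would measure on a real validation set.

```python
def fake_quant(weights, bits=4):
    """Symmetric per-tensor fake quantization of a weight list."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(w) for w in weights) / qmax or 1.0
    return [max(-qmax - 1, min(qmax, round(w / scale))) * scale for w in weights]


def forward(layers, x):
    """Toy model: every layer scales its input elementwise by its weights."""
    for weights in layers.values():
        x = [w * xi for w, xi in zip(weights, x)]
    return x


def sensitivity_scan(layers, inputs, bits=4):
    """Quantize one layer at a time; score it by output MSE versus FP32."""
    reference = [forward(layers, x) for x in inputs]
    scores = {}
    for name in layers:
        trial = dict(layers)  # every other layer stays in full precision
        trial[name] = fake_quant(layers[name], bits)
        err = 0.0
        for x, ref in zip(inputs, reference):
            out = forward(trial, x)
            err += sum((a - b) ** 2 for a, b in zip(out, ref)) / len(ref)
        scores[name] = err / len(inputs)
    # Most sensitive first: candidates to keep in higher precision
    return sorted(scores, key=scores.get, reverse=True)


# Hypothetical layers; encoder.layer.0 has one outlier weight
layers = {
    "embed": [0.9, 1.1, 1.0, 0.95],
    "encoder.layer.0": [8.0, 0.3, 0.35, 0.28],
    "classifier": [1.0, 1.0, 1.0, 1.0],
}
inputs = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, -0.5]]
ranking = sensitivity_scan(layers, inputs)
```

Layers at the front of the ranking, here the one whose outlier weight crushes its small weights to zero, are the natural candidates for set_module_name(name, None) in a QConfigMapping.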
Benchmarking Quantized Models
Always benchmark on your target hardware, because INT8 speedups depend on which integer kernels that CPU or GPU provides:
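```python
import time

import torch


@torch.inference_mode()
def benchmark(model, input_tensor, warmup=10, iterations=100):
    """Return the average latency in milliseconds per forward pass."""
    model.eval()
    # Warmup: let caches, allocators and lazy initialization settle
    for _ in range(warmup):
        model(input_tensor)
    if input_tensor.is_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        model(input_tensor)
    if input_tensor.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU kernels to finish
    elapsed = time.perf_counter() - start
    return elapsed / iterations * 1000  # ms per inference


# fp32_model, int8_model and sample are placeholders for your models and input
fp32_ms = benchmark(fp32_model, sample)
int8_ms = benchmark(int8_model, sample)
print(f"FP32: {fp32_ms:.2f} ms | INT8: {int8_ms:.2f} ms | "
      f"Speedup: {fp32_ms / int8_ms:.2f}×")
```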
Deployment Considerations
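ONNX Runtime: Export the model to ONNX for cross-platform deployment. ONNX Runtime also ships its own quantization tools (onnxruntime.quantization), which quantize an already-exported FP32 ONNX model.
TensorRT: On NVIDIA GPUs, TensorRT typically delivers the best INT8 performance through kernel fusion and hardware-specific optimizations. PyTorch models can be converted with Torch-TensorRT or torch2trt, or exported to ONNX first.
Mobile (iOS/Android): On the TorchScript-based PyTorch Mobile path, script or trace the quantized model and pass it to torch.utils.mobile_optimizer.optimize_for_mobile(); for new projects, ExecuTorch is PyTorch's current on-device runtime. For iOS, coremltools provides its own quantization and compression utilities during Core ML conversion.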
Debugging Quantization Accuracy Loss
When accuracy drops unacceptably:
```python
import torch


# Compare layer-by-layer outputs between FP32 and quantized
def compare_activations(fp32_model, quant_model, test_input):
    acts_fp32, acts_quant = {}, {}
    handles = []

    def make_hook(storage, name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):  # skip tuple/dict outputs
                if output.is_quantized:
                    output = output.dequantize()  # .float() fails on quantized tensors
                storage[name] = output.detach().float()
        return hook

    for model, storage in ((fp32_model, acts_fp32), (quant_model, acts_quant)):
        for name, module in model.named_modules():
            handles.append(module.register_forward_hook(make_hook(storage, name)))

    with torch.no_grad():
        fp32_model(test_input)
        quant_model(test_input)

    for handle in handles:
        handle.remove()  # don't leave hooks attached after the comparison

    for name, ref in acts_fp32.items():
        if name in acts_quant and acts_quant[name].shape == ref.shape:
            diff = (ref - acts_quant[name]).abs().mean().item()
            print(f"{name}: mean abs diff = {diff:.6f}")
```
Layers with large differences are candidates for keeping in higher precision or applying QAT.
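Mean absolute difference grows with each layer's activation magnitude, so it can overstate errors in layers with large outputs. A scale-free alternative is the signal-to-quantization-noise ratio (SQNR) in decibels. The sketch below, in plain Python with hypothetical layer data, compares two layers with the same relative error.

```python
import math


def mean_abs_diff(ref, quant):
    return sum(abs(a - b) for a, b in zip(ref, quant)) / len(ref)


def sqnr_db(ref, quant):
    """Signal-to-quantization-noise ratio in dB; higher means closer to FP32."""
    signal = sum(a * a for a in ref)
    noise = sum((a - b) ** 2 for a, b in zip(ref, quant))
    return float("inf") if noise == 0 else 10 * math.log10(signal / noise)


# Hypothetical FP32 vs quantized outputs: same relative error, 10x the scale
layer_a = ([1.0, 2.0, 3.0], [1.1, 2.0, 2.9])
layer_b = ([10.0, 20.0, 30.0], [11.0, 20.0, 29.0])

for name, (ref, quant) in {"layer_a": layer_a, "layer_b": layer_b}.items():
    print(f"{name}: mean abs diff = {mean_abs_diff(ref, quant):.4f}, "
          f"SQNR = {sqnr_db(ref, quant):.2f} dB")
```

The mean absolute differences differ by a factor of 10, while both layers report an SQNR of about 28.45 dB, which makes SQNR easier to compare across layers.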
The one thing to remember: Quantization is a spectrum from zero-effort dynamic quantization to full QAT — the key is measuring accuracy loss per layer, choosing the right precision for each, and always benchmarking on target hardware.
See Also
- Python Hyperparameter Tuning: Learn why adjusting the dials on a computer's learning recipe makes predictions way better.
- Python Knowledge Distillation: How a big expert AI teaches a tiny student AI to be almost as smart, like a professor writing a cheat sheet for an exam.
- Python Model Compression Methods: All the ways Python developers shrink massive AI models to fit on phones and tiny devices, like packing for a trip with a carry-on bag.
- Python Model Pruning Techniques: Why cutting away parts of an AI's brain can make it faster without making it dumber.
- Python Neural Architecture Search: How AI designs its own brain structure, like a robot architect building the perfect house by trying thousands of floor plans.