PyTorch Quantization — Deep Dive

PyTorch Quantization APIs: Legacy vs Modern

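PyTorch has three quantization workflows:

  1. Eager mode (torch.ao.quantization, formerly torch.quantization): the original API; for static quantization you place QuantStub/DeQuantStub and fuse modules by hand
  2. FX graph mode (torch.ao.quantization.quantize_fx): symbolically traces the model with torch.fx and inserts observers automatically
  3. PT2E (torch.ao.quantization.quantize_pt2e): the modern graph-based API, which captures the model with torch.export and can then be compiled with torch.compile

PT2E is the recommended path forward. Each target backend supplies its own Quantizer (for example, X86InductorQuantizer or XNNPACKQuantizer). Because PT2E works on a captured graph rather than on modules, it handles patterns such as functional ops and cross-module fusion that eager mode cannot quantize without rewriting the model. Newer quantization work, including the INT4 path shown later, lives in the separate torchao library.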

Dynamic Quantization (Eager Mode)

The quickest way to quantize — one function call, no calibration:

import torch
from torch.ao.quantization import quantize_dynamic

model = MyTransformerModel()  # placeholder: your own FP32 nn.Module
model.eval()

# Quantize Linear and LSTM layers to INT8
quantized_model = quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear, torch.nn.LSTM},
    dtype=torch.qint8,
)

# Compare sizes
import os
torch.save(model.state_dict(), "/tmp/fp32.pt")
torch.save(quantized_model.state_dict(), "/tmp/int8.pt")
print(f"FP32: {os.path.getsize('/tmp/fp32.pt') / 1e6:.1f} MB")
print(f"INT8: {os.path.getsize('/tmp/int8.pt') / 1e6:.1f} MB")

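Dynamic quantization converts the weights to INT8 ahead of time but leaves activations in floating point between operations. Each quantized Linear/LSTM computes a scale and zero point from its current input at runtime, quantizes that input on the fly, and returns a floating-point result. This avoids calibration entirely, but the per-call range computation and conversions typically make it slower than static quantization. These eager-mode quantized kernels run on CPU.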
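The per-batch parameter computation can be sketched in plain PyTorch. This is a simplified illustration, not the library's kernel: `dynamic_qparams` is a hypothetical helper that uses the full unsigned 8-bit range (0 to 255), whereas PyTorch's x86 kernels may use a reduced range.

```python
import torch


def dynamic_qparams(x: torch.Tensor, qmin: int = 0, qmax: int = 255):
    """Hypothetical helper: asymmetric per-tensor scale/zero point from one batch."""
    # Extend the range to include 0 so that zero is exactly representable
    min_val = min(x.min().item(), 0.0)
    max_val = max(x.max().item(), 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:  # all-zero input: any positive scale works
        scale = 1.0
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


# Each new batch gets its own quantization parameters at runtime
torch.manual_seed(0)
x = torch.randn(4, 16)
scale, zero_point = dynamic_qparams(x)
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
max_err = (xq.dequantize() - x).abs().max().item()
print(f"scale={scale:.5f} zero_point={zero_point} max_err={max_err:.5f}")
```

Because the range is recomputed for every input, no value is ever clipped and the rounding error stays within half a quantization step; the price is the extra min/max pass on each call.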

Static Quantization with Calibration

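Static quantization precomputes activation ranges from a calibration dataset, so no ranges are computed at inference time. The example uses FX graph mode (torch.ao.quantization.quantize_fx), which traces the model and inserts observers automatically:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = MyCNNModel()  # placeholder: your own FP32 nn.Module
model.eval()

# Step 1: Configure quantization
qconfig_mapping = get_default_qconfig_mapping("x86")  # or "qnnpack" for ARM

# Step 2: Prepare: trace the model and insert observer modules
# (sample_input is one representative input tensor)
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=(sample_input,))

# Step 3: Calibrate: run representative data through the model
# (assumes calibration_loader yields input tensors; unpack (input, target) pairs if needed)
with torch.no_grad():
    for batch in calibration_loader:
        prepared_model(batch)

# Step 4: Convert: replace the observed FP32 ops with INT8 ops
quantized_model = convert_fx(prepared_model)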
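Once calibration finishes, the scale and zero point are frozen, so any inference-time activation outside the calibrated range is clipped. The sketch below illustrates this with a standalone `MinMaxObserver`; the hand-built calibration batches are illustrative stand-ins for real data.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Calibration: the observer tracks the running min/max over all batches
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for low, high in [(-0.5, 0.5), (-1.0, 0.8), (-0.2, 1.0)]:
    observer(torch.linspace(low, high, steps=101))

scale, zero_point = observer.calculate_qparams()
scale, zero_point = scale.item(), int(zero_point.item())

# Inference: the parameters are fixed, so 3.0 saturates near the calibrated max (1.0)
activation = torch.tensor([0.5, 3.0])
q = torch.quantize_per_tensor(activation, scale, zero_point, torch.quint8)
print(q.dequantize())
```

This is why the calibration set should be representative: ranges that never appear during calibration cannot be represented afterwards.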

Choosing Calibration Strategy

The observer determines how activation ranges are computed:

from torch.ao.quantization.observer import (
    MinMaxObserver,           # Simple min/max of observed values
    MovingAverageMinMaxObserver,  # Smoothed min/max (better for varying ranges)
    HistogramObserver,        # Histogram-based, minimizes quantization error
    PerChannelMinMaxObserver, # Per-channel for weights (higher accuracy)
)

HistogramObserver typically gives the best accuracy because it optimizes the scale/zero-point to minimize overall quantization error rather than just covering the full range. MinMaxObserver is faster but vulnerable to outliers — a single extreme value stretches the range, reducing precision for normal values.
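To see the outlier effect, and the smoothing that `MovingAverageMinMaxObserver` provides, feed both observers the same synthetic stream: nine well-behaved batches followed by one batch containing a single extreme value. `averaging_constant=0.01` is the observer's default, written out here for clarity.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

minmax = MinMaxObserver(dtype=torch.quint8)
moving = MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8)

clean = torch.linspace(-1.0, 1.0, steps=1001)
with_outlier = torch.cat([clean, torch.tensor([100.0])])  # one extreme activation

for batch in [clean] * 9 + [with_outlier]:
    minmax(batch)
    moving(batch)

for name, obs in [("MinMax", minmax), ("MovingAverage", moving)]:
    scale, _ = obs.calculate_qparams()
    print(f"{name:>13}: range [{obs.min_val.item():.2f}, {obs.max_val.item():.2f}], "
          f"scale {scale.item():.5f}")
```

The single value at 100 stretches the MinMax range to [-1, 100], so the step size grows roughly 50x and the ordinary values in [-1, 1] share only a handful of quantization levels. The moving average moves its max only 1% of the way toward the outlier.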

Quantization-Aware Training (QAT)

QAT inserts fake quantization nodes during training that simulate INT8 rounding:

import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

model = MyCNNModel()  # placeholder: your own FP32 nn.Module
model.train()  # prepare_qat_fx expects a model in training mode

qconfig_mapping = get_default_qat_qconfig_mapping("x86")
prepared_model = prepare_qat_fx(model, qconfig_mapping,
                                example_inputs=(sample_input,))

# Fine-tune with fake quantization active
# (train_loader and criterion are your usual data loader and loss function)
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1e-4, momentum=0.9)

for epoch in range(5):  # 5-10 epochs of fine-tuning is often enough
    for data, target in train_loader:
        optimizer.zero_grad()
        output = prepared_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to an actual quantized model
prepared_model.eval()
quantized_model = convert_fx(prepared_model)

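The fake quantization nodes compute x_fq = (clamp(round(x / scale) + zero_point, q_min, q_max) - zero_point) * scale, so the forward pass sees the same rounded and clipped values that INT8 inference will produce while the tensors stay in floating point. During backward, gradients flow through via the Straight-Through Estimator (STE), which treats the rounding as the identity; PyTorch's fake-quantize ops additionally zero the gradient for values clipped by the clamp.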
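A minimal sketch of per-tensor affine fake quantization with a straight-through estimator, written as a custom `torch.autograd.Function`. `FakeQuantSTE` is a hypothetical name, not a PyTorch class; like the built-in `torch.fake_quantize_per_tensor_affine`, it passes gradients through unchanged inside the representable range and zeroes them for clipped values.

```python
import torch


class FakeQuantSTE(torch.autograd.Function):
    """Hypothetical helper: per-tensor affine fake quantization with a clipped STE."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = torch.round(x / scale) + zero_point
        # Remember which elements fall inside the representable integer range
        ctx.save_for_backward((q >= qmin) & (q <= qmax))
        q = torch.clamp(q, qmin, qmax)
        return (q - zero_point) * scale  # back to float, now on the INT8 grid

    @staticmethod
    def backward(ctx, grad_output):
        (in_range,) = ctx.saved_tensors
        # STE: treat round() as identity, but block gradients of clipped values
        return grad_output * in_range.to(grad_output.dtype), None, None, None, None


x = torch.tensor([-3.0, -0.26, 0.0, 0.74, 3.0], requires_grad=True)
y = FakeQuantSTE.apply(x, 0.01, 0, -128, 127)  # INT8 range with scale 0.01
y.sum().backward()
print(y.detach(), x.grad)
```

With scale 0.01 the representable range is [-1.28, 1.27], so -3.0 and 3.0 are clipped and receive zero gradient, while the in-range values get a gradient of exactly 1 despite the rounding.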

INT4 Weight Quantization for LLMs

For large language models, 4-bit weight-only quantization is common practice. The torchao library provides it:

import torch
from torchao.quantization import quantize_, int4_weight_only

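model = load_llm_model()  # placeholder: your LLM, typically in bfloat16 on a CUDA GPU
model.eval()

# Apply INT4 weight-only quantization (the default INT4 kernels target CUDA with bfloat16;
# newer torchao releases spell this config Int4WeightOnlyConfig(group_size=128))
quantize_(model, int4_weight_only(group_size=128))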

# Inference — activations stay in FP16/BF16
with torch.no_grad():
    output = model(input_ids)

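Group quantization (group_size=128) splits each row of a weight matrix into contiguous groups of 128 elements, and each group gets its own scale factor (plus a zero point in asymmetric schemes). This substantially improves accuracy over per-tensor or per-channel quantization: an outlier weight only inflates the scale of its own group instead of costing precision across the whole tensor or channel. Smaller groups track the weights more closely but store more scales, which is why group=32 is both more accurate and slightly larger than group=128 in the table below.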
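A sketch of symmetric group-wise quantization on a synthetic weight matrix whose first 128 columns have a 10x wider range. `quantize_groupwise` and `dequantize_groupwise` are hypothetical helpers; torchao's INT4 kernels use their own scheme and packed storage layout, so this only illustrates why per-group scales help.

```python
import torch


def quantize_groupwise(w: torch.Tensor, group_size: int, bits: int = 4):
    """Hypothetical helper: symmetric group-wise quantization of a 2-D weight."""
    rows, cols = w.shape
    assert cols % group_size == 0, "columns must divide evenly into groups"
    qmax = 2 ** (bits - 1) - 1  # 7 for signed 4-bit
    groups = w.reshape(rows, cols // group_size, group_size)
    # One scale per group, chosen so the group's largest |w| maps to qmax
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(groups / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale


def dequantize_groupwise(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (q.float() * scale).reshape(q.shape[0], -1)


torch.manual_seed(0)
w = torch.randn(64, 256)
w[:, :128] *= 10.0  # first half of every row has a 10x wider range

for group_size in (256, 128):  # 256 = one scale per row, 128 = two scales per row
    q, scale = quantize_groupwise(w, group_size)
    err = (dequantize_groupwise(q, scale) - w)[:, 128:].pow(2).mean().item()
    print(f"group_size={group_size}: MSE on the narrow-range half = {err:.4f}")
```

With one scale per row, the wide half dictates a step so coarse that most narrow-half weights round to zero; giving the narrow half its own scale shrinks its error by well over an order of magnitude.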

Configuration       Approx. model size (7B params)   Perplexity impact
FP16 (baseline)     14 GB                            0 (reference)
INT8 weight-only    7 GB                             +0.01 to +0.05
INT4, group=128     3.5 GB                           +0.1 to +0.3
INT4, group=32      4 GB                             +0.05 to +0.15
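The size column can be approximated by counting bits per weight plus per-group metadata. This is a rough sketch: `approx_size_gb` is a hypothetical helper that assumes 4 bytes of metadata per group (an FP16 scale and an FP16 zero point) and treats every parameter as quantized, whereas real checkpoints often keep embeddings and norms in higher precision.

```python
from typing import Optional


def approx_size_gb(n_params: float, weight_bits: int,
                   group_size: Optional[int] = None,
                   bytes_per_group: int = 4) -> float:
    """Hypothetical estimate: packed weights plus per-group metadata, in decimal GB."""
    total_bytes = n_params * weight_bits / 8
    if group_size is not None:
        # Assumed metadata: one FP16 scale + one FP16 zero point per group
        total_bytes += n_params / group_size * bytes_per_group
    return total_bytes / 1e9


n = 7e9
for label, bits, group in [("FP16", 16, None), ("INT8", 8, None),
                           ("INT4 g=128", 4, 128), ("INT4 g=32", 4, 32)]:
    print(f"{label:>10}: {approx_size_gb(n, bits, group):.2f} GB")
```

Under these assumptions the metadata adds about 0.22 GB at group=128 and about 0.88 GB at group=32, which is why the group=32 row is the larger of the two INT4 configurations.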

Mixed-Precision Quantization

Not all layers respond equally to quantization. Sensitive layers (first/last layers, attention projections) can stay in higher precision:

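from torch.ao.quantization import QConfigMapping, get_default_qconfig

# For FX graph mode (prepare_fx / prepare_qat_fx); module names are examples,
# use the names reported by model.named_modules()
qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("x86"))      # INT8 everywhere by default
    .set_module_name("encoder.layer.0", None)    # keep first layer FP32
    .set_module_name("classifier", None)         # keep classifier FP32
)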

Sensitivity analysis automates this choice: quantize one layer at a time, measure the accuracy impact of each, and keep the most sensitive layers in higher precision.
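A bare-bones version of that one-layer-at-a-time loop can be sketched with simulated weight quantization, which avoids depending on a particular quantized backend. Assumptions: `fake_quant_weight_int8` and `layer_sensitivity` are hypothetical helpers, sensitivity is scored as output MSE against the FP32 model rather than task accuracy, and the toy model with one injected outlier weight is purely illustrative.

```python
import copy

import torch
import torch.nn as nn


def fake_quant_weight_int8(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: simulate symmetric per-tensor INT8 weight quantization."""
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(w / scale), -127, 127) * scale


@torch.no_grad()
def layer_sensitivity(model: nn.Module, inputs: torch.Tensor) -> dict:
    """Quantize one Linear layer at a time; score = output MSE vs. the FP32 model."""
    model.eval()
    reference = model(inputs)
    scores = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            trial = copy.deepcopy(model)  # leave the original model untouched
            layer = trial.get_submodule(name)
            layer.weight.copy_(fake_quant_weight_int8(layer.weight))
            scores[name] = (trial(inputs) - reference).pow(2).mean().item()
    # Most sensitive layers first
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 4),
)
with torch.no_grad():
    model[2].weight[0, 0] = 50.0  # inject one outlier weight into the middle layer

scores = layer_sensitivity(model, torch.randn(64, 16))
for name, mse in scores.items():
    print(f"layer {name}: output MSE {mse:.2e}")
```

The layer holding the outlier ranks first: its per-tensor scale becomes so coarse that every other weight in it rounds to zero. The names at the top of such a ranking are the ones to exclude with `set_module_name(name, None)`.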

Benchmarking Quantized Models

Always benchmark on your target hardware:

import torch
import time

@torch.inference_mode()  # skip autograd bookkeeping while timing
def benchmark(model, input_tensor, warmup=10, iterations=100):
    # Warmup
    for _ in range(warmup):
        model(input_tensor)

    if input_tensor.is_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        model(input_tensor)

    if input_tensor.is_cuda:
        torch.cuda.synchronize()

    elapsed = time.perf_counter() - start
    return elapsed / iterations * 1000  # ms per inference

# fp32_model, int8_model, and sample are placeholders for your models and input
fp32_ms = benchmark(fp32_model, sample)
int8_ms = benchmark(int8_model, sample)
print(f"FP32: {fp32_ms:.2f} ms | INT8: {int8_ms:.2f} ms | "
      f"Speedup: {fp32_ms/int8_ms:.2f}×")
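PyTorch also ships `torch.utils.benchmark.Timer`, which chooses the number of runs automatically, pins the thread count (one thread by default), and reports statistics across runs. A sketch comparing a small FP32 MLP with its dynamically quantized version; the toy model and layer sizes are arbitrary assumptions, and the resulting numbers depend entirely on your hardware.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.utils import benchmark

torch.manual_seed(0)
# Hypothetical toy model; substitute your own FP32 model and input
fp32_model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).eval()
int8_model = quantize_dynamic(fp32_model, {nn.Linear}, dtype=torch.qint8)
sample = torch.randn(32, 512)

results = {}
with torch.inference_mode():
    for label, model in [("FP32", fp32_model), ("INT8 dynamic", int8_model)]:
        timer = benchmark.Timer(
            stmt="model(x)",
            globals={"model": model, "x": sample},
            num_threads=1,  # pin threads so both measurements are comparable
            label=label,
        )
        results[label] = timer.blocked_autorange(min_run_time=0.2)

for label, m in results.items():
    print(f"{label}: median {m.median * 1e3:.3f} ms")
print(f"Speedup: {results['FP32'].median / results['INT8 dynamic'].median:.2f}x")
```

Reporting the median rather than the mean makes the comparison less sensitive to occasional slow runs caused by other activity on the machine.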

Deployment Considerations

ONNX Runtime: Export the model to ONNX for cross-platform deployment. A common route is to export the FP32 model and then quantize the exported ONNX model with ONNX Runtime's own quantization tools (dynamic, or static with calibration data).

TensorRT: For NVIDIA GPUs, TensorRT typically delivers the best INT8 performance through kernel fusion and hardware-specific optimizations. PyTorch models can be converted with Torch-TensorRT or torch2trt, or exported to ONNX first.

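Mobile (iOS/Android): torch.utils.mobile_optimizer.optimize_for_mobile() operates on TorchScript modules, so script or trace the quantized model before calling it; newer projects target ExecuTorch, the successor to PyTorch Mobile. For iOS, Core ML Tools (coremltools) provides its own quantization utilities during Core ML conversion.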

Debugging Quantization Accuracy Loss

When accuracy drops unacceptably:

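# Compare layer-by-layer outputs between FP32 and quantized
import torch


@torch.no_grad()
def compare_activations(fp32_model, quant_model, test_input):
    """Print the mean absolute difference between matching submodule outputs."""
    acts_fp32, acts_quant = {}, {}
    handles = []

    def make_hook(storage, name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):  # skip tuple outputs (e.g. LSTM)
                # Quantized tensors must be dequantized before float arithmetic
                out = output.dequantize() if output.is_quantized else output
                storage[name] = out.detach().float()
        return hook

    for model, storage in ((fp32_model, acts_fp32), (quant_model, acts_quant)):
        for name, module in model.named_modules():
            if name:  # skip the root module (its name is "")
                handles.append(module.register_forward_hook(make_hook(storage, name)))

    try:
        fp32_model(test_input)
        quant_model(test_input)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()

    for name, ref in acts_fp32.items():
        out = acts_quant.get(name)
        if out is not None and out.shape == ref.shape:
            diff = (ref - out).abs().mean().item()
            print(f"{name}: mean abs diff = {diff:.6f}")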

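Layers with large differences are candidates for keeping in higher precision or applying QAT. Module names only line up where quantization preserved the module hierarchy; fused modules (for example, Conv+ReLU) can appear under a different name or disappear, and PyTorch's Numeric Suite (torch.ao.ns) automates this matching.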
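The mean absolute difference depends on each layer's activation scale, which makes layers hard to compare with one another. A scale-invariant alternative is the signal-to-quantization-noise ratio (SQNR). `sqnr_db` is a hypothetical helper that could replace the mean-abs-diff line in the function above; it expects float tensors, so dequantize quantized outputs first.

```python
import math

import torch


def sqnr_db(reference: torch.Tensor, approx: torch.Tensor) -> float:
    """Hypothetical helper: signal-to-quantization-noise ratio in dB (higher is better)."""
    ref = reference.detach().float()
    noise = (ref - approx.detach().float()).pow(2).sum().item()
    signal = ref.pow(2).sum().item()
    if noise == 0.0:
        return math.inf  # identical tensors
    if signal == 0.0:
        return -math.inf  # all-zero reference, nonzero error
    return 10.0 * math.log10(signal / noise)


torch.manual_seed(0)
x = torch.randn(1000)
x_int8 = torch.fake_quantize_per_tensor_affine(x, 0.05, 0, -128, 127)  # simulated INT8
print(f"SQNR: {sqnr_db(x, x_int8):.1f} dB")
```

Each 10 dB means the quantization noise power is 10x smaller relative to the signal, so a layer whose SQNR is far below its neighbors stands out regardless of its activation magnitude.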

The one thing to remember: Quantization is a spectrum from zero-effort dynamic quantization to full QAT — the key is measuring accuracy loss per layer, choosing the right precision for each, and always benchmarking on target hardware.
