PyTorch ONNX Export — Deep Dive

Basic Export with torch.onnx.export

The standard export API traces the model with sample input:

import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True),
            num_layers=4,
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.embed(x)
        x = self.encoder(x)
        x = x.mean(dim=1)  # Global average pooling
        return self.classifier(x)

model = TextClassifier(30000, 256, 10)
model.eval()

# Sample input matching expected shape
dummy_input = torch.randint(0, 30000, (1, 128))

torch.onnx.export(
    model,
    dummy_input,
    "text_classifier.onnx",
    opset_version=17,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_length"},
        "logits": {0: "batch_size"},
    },
)

Dynamic Axes: Handling Variable Shapes

Without dynamic_axes, the exported model only accepts the exact input shape used during tracing. Real deployments need flexibility. A model that also takes an attention mask might declare:

dynamic_axes = {
    "input_ids": {0: "batch", 1: "sequence"},
    "attention_mask": {0: "batch", 1: "sequence"},
    "logits": {0: "batch"},
}

Each entry maps a dimension index to a symbolic name. The ONNX graph stores these symbols in place of fixed sizes, so those dimensions can take any value at run time. On each call, ONNX Runtime resolves the symbols from the actual input shapes and sizes its intermediate buffers accordingly.

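For models with several inputs, give every dynamic input and output its own entry. Reusing a symbolic name across inputs (here "batch" and "seq") signals that those dimensions must match at run time:

# Assumes `model` is a BERT-style module whose forward() takes
# (input_ids, attention_mask, token_type_ids); sizes are illustrative
input_ids = torch.randint(0, 30000, (2, 64))
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

torch.onnx.export(
    model,
    (input_ids, attention_mask, token_type_ids),
    "bert.onnx",
    opset_version=17,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "token_type_ids": {0: "batch", 1: "seq"},
        "logits": {0: "batch"},
    },
)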
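Because the three inputs share the `batch` and `seq` names, a request whose attention mask has a different length from its input IDs is a caller bug. A small pre-flight check can catch that before ONNX Runtime is called. This is a hypothetical helper, not part of torch or onnxruntime. It reuses the same dynamic_axes dictionary that was passed to the exporter.

```python
import numpy as np


def check_shared_dims(feed, dynamic_axes):
    """Check that each symbolic dimension name resolves to a single size.

    feed maps input names to arrays; dynamic_axes uses the dict form accepted
    by torch.onnx.export ({name: {axis_index: symbol}}). Entries missing from
    feed, such as outputs, are skipped. Returns {symbol: size} and raises
    ValueError on the first conflicting size.
    """
    resolved = {}
    for name, axes in dynamic_axes.items():
        if name not in feed:
            continue
        shape = np.shape(feed[name])
        for axis, symbol in axes.items():
            size = shape[axis]
            if resolved.setdefault(symbol, size) != size:
                raise ValueError(
                    f"{name} axis {axis} ({symbol!r}) has size {size}, "
                    f"but {symbol!r} was already resolved to {resolved[symbol]}"
                )
    return resolved


dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},
    "token_type_ids": {0: "batch", 1: "seq"},
    "logits": {0: "batch"},
}
feed = {
    "input_ids": np.zeros((4, 32), dtype=np.int64),
    "attention_mask": np.ones((4, 32), dtype=np.int64),
    "token_type_ids": np.zeros((4, 32), dtype=np.int64),
}
print(check_shared_dims(feed, dynamic_axes))  # {'batch': 4, 'seq': 32}
```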

Validating Exported Models

Never deploy without validation. Compare PyTorch and ONNX outputs numerically:

import onnxruntime as ort
import numpy as np

# PyTorch inference
with torch.no_grad():
    pt_output = model(dummy_input).numpy()

# ONNX Runtime inference
session = ort.InferenceSession("text_classifier.onnx", providers=["CPUExecutionProvider"])
ort_output = session.run(
    None,
    {"input_ids": dummy_input.numpy()},
)[0]

# Compare
np.testing.assert_allclose(pt_output, ort_output, rtol=1e-3, atol=1e-5)
print("Validation passed: outputs match within tolerance")

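For FP16 models, use looser tolerances (for example rtol=1e-2, atol=1e-2), since half precision carries only about three significant decimal digits. Check several inputs as well, especially shapes that differ from the (1, 128) tracing input, such as other batch sizes and sequence lengths, plus edge cases like empty or single-token sequences and maximum-length inputs.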
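Checking several shapes is easy to automate. A minimal sketch, assuming both backends are wrapped as callables that take a NumPy array of token IDs and return NumPy logits; the names `compare_on_inputs`, `reference`, and `candidate` are hypothetical. It runs every input through both callables, applies the same assert_allclose criterion as above, and reports the worst absolute difference per input.

```python
import numpy as np


def compare_on_inputs(reference, candidate, inputs, rtol=1e-3, atol=1e-5):
    """Run each input through both callables and check that they agree.

    Returns the maximum absolute difference per input; raises AssertionError
    (from np.testing.assert_allclose) on the first mismatch.
    """
    worst = []
    for x in inputs:
        expected = np.asarray(reference(x))
        actual = np.asarray(candidate(x))
        np.testing.assert_allclose(
            actual, expected, rtol=rtol, atol=atol,
            err_msg=f"mismatch for input shape {x.shape}",
        )
        worst.append(float(np.max(np.abs(actual - expected))))
    return worst


# Shapes that differ from the (1, 128) tracing input exercise the dynamic axes
rng = np.random.default_rng(0)
test_inputs = [
    rng.integers(0, 30000, size=(batch, seq), dtype=np.int64)
    for batch in (1, 4)
    for seq in (1, 16, 128)
]
```

Wired to the objects in this section, `reference` could be `lambda x: model(torch.from_numpy(x)).numpy()` (called inside torch.no_grad()) and `candidate` could be `lambda x: session.run(None, {"input_ids": x})[0]`.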

ONNX Model Inspection and Optimization

Inspect the exported graph before deployment:

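import numpy as np
import onnx
from onnx import shape_inference

# Load and validate
model_proto = onnx.load("text_classifier.onnx")
onnx.checker.check_model(model_proto)

# Infer shapes for all intermediate tensors (stored in graph.value_info)
model_proto = shape_inference.infer_shapes(model_proto)

# Print graph summary; the default ONNX domain is "" (or "ai.onnx")
default_opset = next(
    o.version for o in model_proto.opset_import if o.domain in ("", "ai.onnx")
)
print(f"IR version: {model_proto.ir_version}")
print(f"Opset: {default_opset}")
print(f"Nodes: {len(model_proto.graph.node)}")
# Counts every initializer element: the weights plus any constants stored as initializers
print(f"Parameters: {sum(int(np.prod(i.dims)) for i in model_proto.graph.initializer):,}")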

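Then apply Transformer-specific fusions with the optimizer bundled in onnxruntime.transformers. The fusions match known graph patterns, so num_heads and hidden_size must agree with the model, and a graph exported from nn.TransformerEncoder may not match every pattern the "bert" model type targets. Compare node counts before and after: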

from onnxruntime.transformers import optimizer

optimized_model = optimizer.optimize_model(
    "text_classifier.onnx",
    model_type="bert",  # Enables Transformer-specific fusions
    num_heads=8,
    hidden_size=256,
)
optimized_model.save_model_to_file("text_classifier_optimized.onnx")

Custom Operators

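When a custom torch.autograd.Function should map to a specific ONNX operator, or has no direct ONNX equivalent, give it a static symbolic method that tells the exporter which ONNX nodes to emit: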

# Custom PyTorch operation
class MyGELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 0.5 * (1.0 + torch.erf(x / 1.41421356))

    @staticmethod
    def symbolic(g, x):
        # Map to ONNX's built-in Gelu operator, which only exists in opset 20+:
        # export with opset_version=20 or later (if your PyTorch version supports it)
        return g.op("Gelu", x)

    @staticmethod
    def backward(ctx, grad):
        # Not needed for export, but required for training
        x, = ctx.saved_tensors
        cdf = 0.5 * (1.0 + torch.erf(x / 1.41421356))
        pdf = torch.exp(-0.5 * x ** 2) / 2.5066282
        return grad * (cdf + x * pdf)
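A hand-written backward is easy to get subtly wrong, so it is worth checking numerically before training with it. A dependency-free sketch using math.erf compares the analytic derivative used in MyGELU.backward, cdf(x) + x * pdf(x), against a central finite difference of the forward formula. The function names below are illustrative.

```python
import math

SQRT_2 = 1.41421356    # same constant as MyGELU.forward
SQRT_2PI = 2.5066282   # same constant as MyGELU.backward


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, matching MyGELU.forward."""
    return x * 0.5 * (1.0 + math.erf(x / SQRT_2))


def gelu_grad(x: float) -> float:
    """Analytic derivative from MyGELU.backward: cdf + x * pdf."""
    cdf = 0.5 * (1.0 + math.erf(x / SQRT_2))
    pdf = math.exp(-0.5 * x ** 2) / SQRT_2PI
    return cdf + x * pdf


def max_grad_error(points, h=1e-5):
    """Largest gap between the analytic and central-difference derivatives."""
    return max(
        abs(gelu_grad(x) - (gelu(x + h) - gelu(x - h)) / (2 * h)) for x in points
    )


points = [i / 10 for i in range(-40, 41)]  # -4.0 to 4.0 in steps of 0.1
print(f"max |analytic - numeric| = {max_grad_error(points):.2e}")
```

With PyTorch available, `torch.autograd.gradcheck(MyGELU.apply, (torch.randn(8, dtype=torch.double, requires_grad=True),))` performs the same kind of check on the real Function, which is then called as `MyGELU.apply(x)` inside a module's forward().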

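For operations with no ONNX equivalent at all, emit a node in a custom domain and declare that domain's version through the custom_opsets argument of torch.onnx.export (for example custom_opsets={"custom_domain": 1}):

class MyCustomFunction(torch.autograd.Function):  # hypothetical; forward() omitted
    @staticmethod
    def symbolic(g, x, alpha):
        # The "_f" suffix stores alpha as a float attribute on the node
        return g.op("custom_domain::MyCustomOp", x, alpha_f=alpha)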

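ONNX Runtime can load custom operator libraries to handle these at inference time: register the compiled shared library with SessionOptions.register_custom_ops_library(path) before creating the session.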

Inference with ONNX Runtime

Configure ONNX Runtime for maximum performance:

import onnxruntime as ort

# CPU with optimization
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
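session_options.intra_op_num_threads = 4  # threads used inside each operator
# ORT_SEQUENTIAL (the default) suits mostly linear graphs like this Transformer;
# ORT_PARALLEL only helps graphs with many independent branches
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL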

session = ort.InferenceSession(
    "text_classifier_optimized.onnx",
    session_options,
    providers=["CPUExecutionProvider"],
)

# GPU with CUDA and TensorRT
gpu_session = ort.InferenceSession(
    "text_classifier_optimized.onnx",
    providers=[
        ("TensorrtExecutionProvider", {
            "trt_max_workspace_size": 2 * 1024 * 1024 * 1024,
            "trt_fp16_enable": True,
        }),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ],
)

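Provider order expresses priority. ONNX Runtime partitions the graph and assigns each node to the first provider in the list that supports it, so TensorRT takes the subgraphs it can handle, CUDA runs most of the rest, and CPU picks up any remaining operators. Listing all three creates a graceful fallback chain instead of a hard failure when a GPU provider cannot run part of the model.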
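ONNX Runtime reports which providers the installed build supports through ort.get_available_providers(). A sketch that filters a preference list against that report before creating the session, so the same code works with both CPU-only and GPU packages. `choose_providers` is a hypothetical helper. Entries may be plain names or (name, options) tuples, as in the session above.

```python
def choose_providers(preferred, available):
    """Keep the preferred providers that the installed build offers, in priority order.

    preferred: provider names or (name, options) tuples, highest priority first.
    available: names reported by onnxruntime.get_available_providers().
    """
    available = set(available)
    chosen = []
    for entry in preferred:
        name = entry[0] if isinstance(entry, tuple) else entry
        if name in available:
            chosen.append(entry)
    return chosen


PREFERRED = [
    ("TensorrtExecutionProvider", {"trt_fp16_enable": True}),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]

# On a CPU-only install, only the CPU provider survives
print(choose_providers(PREFERRED, ["CPUExecutionProvider"]))
```

In use, pass `providers=choose_providers(PREFERRED, ort.get_available_providers())` to ort.InferenceSession, then call `session.get_providers()` to confirm which providers the session actually registered.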

Batch Inference and Throughput

For high-throughput serving, batch requests together and use IO binding so that inputs and outputs live in device memory instead of being copied between host and device on every run:

# IO binding lets ORT read inputs from, and write outputs to, pre-allocated device memory
batch_size, seq_length, num_classes = 32, 128, 10
input_data = np.random.randint(0, 30000, size=(batch_size, seq_length), dtype=np.int64)

io_binding = gpu_session.io_binding()

# Copy the input to the GPU once and pre-allocate the output buffer there
input_tensor = ort.OrtValue.ortvalue_from_numpy(
    input_data, device_type="cuda", device_id=0
)
output_tensor = ort.OrtValue.ortvalue_from_shape_and_type(
    [batch_size, num_classes], np.float32, device_type="cuda", device_id=0
)

io_binding.bind_ortvalue_input("input_ids", input_tensor)
io_binding.bind_ortvalue_output("logits", output_tensor)

gpu_session.run_with_iobinding(io_binding)
result = output_tensor.numpy()  # Single device-to-host copy at the end
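Batching only pays off if requests can share one input tensor. TextClassifier has no attention mask, so padding shorter sequences would change the mean-pooled result. A simple option is to group requests of identical length. This is a minimal sketch with a hypothetical `bucket_by_length` helper. A real server would also need a policy for flushing partially filled buckets.

```python
from collections import defaultdict

import numpy as np


def bucket_by_length(requests, max_batch_size):
    """Group token-ID sequences of equal length into int64 batches.

    Returns a list of (request_indices, batch_array) pairs so results can be
    mapped back to the original requests. No padding is introduced.
    """
    by_length = defaultdict(list)
    for idx, tokens in enumerate(requests):
        by_length[len(tokens)].append(idx)

    batches = []
    for length in sorted(by_length):
        indices = by_length[length]
        for start in range(0, len(indices), max_batch_size):
            chunk = indices[start:start + max_batch_size]
            batch = np.array([requests[i] for i in chunk], dtype=np.int64)
            batches.append((chunk, batch))
    return batches


requests = [[5, 9, 2], [7, 1], [3, 3, 3], [8, 6], [4, 4, 4]]
for indices, batch in bucket_by_length(requests, max_batch_size=2):
    print(indices, batch.shape)
```

Each batch array can be fed as input_ids, either through session.run or IO binding, and row i of the resulting logits belongs to request indices[i].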

Benchmarking PyTorch vs ONNX Runtime

A systematic comparison for deployment decisions:

import time

def benchmark_session(session, input_dict, warmup=50, runs=200):
    for _ in range(warmup):
        session.run(None, input_dict)

    start = time.perf_counter()
    for _ in range(runs):
        session.run(None, input_dict)
    elapsed = time.perf_counter() - start

    return elapsed / runs * 1000  # ms per inference

input_dict = {"input_ids": dummy_input.numpy()}
ort_ms = benchmark_session(session, input_dict)

# Compare with PyTorch, pinned to the same thread count as the ORT session
torch.set_num_threads(4)
with torch.no_grad():
    for _ in range(50):
        model(dummy_input)
    start = time.perf_counter()
    for _ in range(200):
        model(dummy_input)
    pt_ms = (time.perf_counter() - start) / 200 * 1000

print(f"PyTorch: {pt_ms:.2f} ms | ORT: {ort_ms:.2f} ms | "
      f"Speedup: {pt_ms/ort_ms:.2f}×")
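A mean over 200 runs can hide outliers such as a garbage-collection pause. A variant of benchmark_session that times each call individually and reports the median and 95th percentile gives a steadier basis for comparison. `benchmark_latencies` is a hypothetical name. It accepts any zero-argument callable, so the same code can time both PyTorch and ONNX Runtime.

```python
import statistics
import time


def benchmark_latencies(fn, warmup=50, runs=200):
    """Time fn() individually after warmup; return latency stats in ms (runs >= 2)."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    return {
        "mean_ms": statistics.fmean(samples),
        "median_ms": statistics.median(samples),
        "p95_ms": statistics.quantiles(samples, n=20)[18],  # 95th percentile
    }


# Demo on a trivial workload; swap in the session or model calls for real use
stats = benchmark_latencies(lambda: sum(range(10_000)), warmup=5, runs=50)
print({k: round(v, 4) for k, v in stats.items()})
```

For the comparison above, time `lambda: session.run(None, input_dict)` for ONNX Runtime, and `lambda: model(dummy_input)` inside torch.no_grad() for PyTorch.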

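Results vary with model size, batch size, and hardware. Treat typical figures such as a 1.5-3× speedup on CPU and 1.2-2× on GPU as a starting point rather than a guarantee, and measure on your own deployment target. CPU gains tend to be larger because ONNX Runtime executes the whole graph without per-operator Python dispatch and applies fusions that cut memory traffic. On GPU, large kernels already dominate the runtime, leaving less overhead to remove.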

Export Troubleshooting

Unsupported operator errors: Check the opset version. Newer opsets support more operators. Use opset_version=17 or higher. If an operator isn’t supported, implement a custom symbolic or decompose it into supported ops.

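Shape inference failures: Declare every variable-size input and output in dynamic_axes, and read the TracerWarnings the exporter prints. Converting a tensor size to a Python number (for example with int() or .item()) records the traced value as a constant, which silently fixes a dimension you meant to keep dynamic. Where possible, avoid operations whose output shape depends on tensor values (like torch.nonzero). They export, but the resulting dimensions are unknown to shape inference.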

Numerical mismatches: Common with operations involving reductions (sum, mean) due to floating-point ordering differences. Use atol=1e-5 for FP32 and atol=1e-2 for FP16 comparisons.
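When a comparison fails, the size and location of the worst mismatch help separate reduction-order noise (small, scattered differences) from a genuine export bug (large or systematic ones). The minimal NumPy sketch below uses the hypothetical name `mismatch_report` and assumes finite values. It applies the same per-element test as np.testing.assert_allclose, |actual - expected| <= atol + rtol * |expected|.

```python
import numpy as np


def mismatch_report(actual, expected, rtol=1e-3, atol=1e-5):
    """Summarize where actual deviates from expected beyond tolerance."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    abs_diff = np.abs(actual - expected)
    allowed = atol + rtol * np.abs(expected)
    excess = abs_diff - allowed  # > 0 wherever the tolerance is violated
    worst = np.unravel_index(np.argmax(excess), excess.shape)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "violations": int(np.count_nonzero(excess > 0)),
        "fraction_violating": float(np.mean(excess > 0)),
        "worst_index": tuple(int(i) for i in worst),
    }


expected = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
actual = expected.copy()
actual[1, 2] += 0.01   # one clearly wrong element
actual[0, 0] += 1e-7   # noise of the size reduction reordering produces
print(mismatch_report(actual, expected))
```

A single large violation, as here, points at a specific operator or input. Many tiny violations spread across the tensor usually mean the tolerance is too strict for the precision in use.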

The one thing to remember: ONNX export is a deployment pipeline, not just a file format conversion — validation against PyTorch outputs, dynamic axis configuration, and runtime optimization are all essential steps between training and production.
