PyTorch TorchScript — Deep Dive
Tracing in Detail
Tracing runs the model once on an example input and records the tensor operations that execute:
```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(3, 32),
            nn.MaxPool2d(2),
            ConvBlock(32, 64),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)


model = SimpleCNN()
model.eval()

# Trace with example input
example = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, example)

# Save as standalone artifact
traced.save("simple_cnn.pt")

# Verify output matches
with torch.no_grad():
    orig_out = model(example)
    traced_out = traced(example)
assert torch.allclose(orig_out, traced_out, atol=1e-6)
```
Tracing warnings appear when the tracer detects potential issues. Pay attention to them:

```text
TracerWarning: Converting a tensor to a Python boolean might cause
the trace to be incorrect. We can't record the data flow of Python
values, so this value will be treated as a constant.
```

This warning means a Python-level decision that depends on tensor values (for example, if x.sum() > 0:) was evaluated once on the example input and baked into the graph as a constant. The trace will take that same branch for every future input. Switch to scripting for those modules.
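To see the failure mode concretely, here is a minimal sketch (clip_or_negate is a hypothetical function) that traces and scripts the same code with a value-dependent branch. The trace records only the branch taken for the example input and replays it for every later input; the scripted version evaluates the condition on each call.

```python
import warnings

import torch


def clip_or_negate(x: torch.Tensor) -> torch.Tensor:
    # Which branch runs depends on tensor *values*, not just shapes
    if x.sum() > 0:
        return x.clamp(max=1.0)
    return -x


pos = torch.tensor([2.0, 3.0])
neg = torch.tensor([-2.0, -3.0])

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced_fn = torch.jit.trace(clip_or_negate, pos)  # records only the clamp branch

scripted_fn = torch.jit.script(clip_or_negate)

print(traced_fn(neg))    # tensor([-2., -3.])  clamp branch replayed (wrong)
print(scripted_fn(neg))  # tensor([2., 3.])    negate branch taken (correct)
```

Both versions agree on the example input itself, which is why a single allclose check on the tracing input cannot catch this bug.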
Scripting in Detail
Scripting compiles Python source code to TorchScript IR:
```python
class DynamicRouter(nn.Module):
    """Model with data-dependent control flow; must be scripted."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(num_experts)
        ])
        self.threshold = 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate scores determine which expert processes each sample
        scores = torch.softmax(self.gate(x), dim=-1)
        max_score, expert_idx = scores.max(dim=-1)
        outputs = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            # Skip experts that received no samples (depends on tensor values)
            if mask.any():
                outputs[mask] = expert(x[mask])
        return outputs


router = DynamicRouter(256, 4)
router.eval()

# Script: captures the loop over experts and the value-dependent conditional
scripted = torch.jit.script(router)
scripted.save("dynamic_router.pt")
```
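Scripted modules deserve the same output check as traced ones. A sketch, using a hypothetical GatedStack module with a value-dependent branch inside a loop, that scripts it, round-trips it through an in-memory buffer with torch.jit.save and torch.jit.load, and compares the reloaded module against eager execution:

```python
import io

import torch
import torch.nn as nn


class GatedStack(nn.Module):
    """Hypothetical toy module: applies each layer only while the mean is positive."""

    def __init__(self, d: int, n: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d, d) for _ in range(n)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if x.mean() > 0:  # data-dependent branch, preserved by scripting
                x = torch.relu(layer(x))
        return x


torch.manual_seed(0)
gated_eager = GatedStack(8, 3).eval()
gated_scripted = torch.jit.script(gated_eager)

# Round-trip through a BytesIO buffer instead of a file on disk
buf = io.BytesIO()
torch.jit.save(gated_scripted, buf)
buf.seek(0)
gated_reloaded = torch.jit.load(buf)

gx = torch.randn(4, 8)
with torch.no_grad():
    assert torch.allclose(gated_eager(gx), gated_reloaded(gx))

print(gated_reloaded.code)  # compiled TorchScript source for forward
```

Printing .code is a quick way to confirm which control flow the compiler actually kept.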
Type Annotations for TorchScript
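TorchScript treats every unannotated function argument as a Tensor, so non-Tensor arguments, empty containers, and Optional values need explicit type annotations. Common patterns: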
```python
from typing import Dict, List, Optional, Tuple

import torch


@torch.jit.script
def process_outputs(
    logits: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    top_k: int = 5,
) -> Tuple[torch.Tensor, List[int]]:
    probs = torch.softmax(logits, dim=-1)
    values, indices = probs.topk(top_k, dim=-1)

    # Build list of top predictions for first sample.
    # .item() returns a generic number in TorchScript, so cast it to int explicitly.
    top_classes: List[int] = []
    for i in range(top_k):
        top_classes.append(int(indices[0, i].item()))

    if labels is not None:
        # Optional refinement: inside this branch, labels has type Tensor
        assert labels.size(0) == logits.size(0), "labels must match the batch size"
    return probs, top_classes
```
Types that TorchScript supports include int, float, bool, str, Tensor, List[T], Dict[K, V] (keys must be str, int, or float), fixed-length Tuple[T0, T1, ...], Optional[T], NamedTuple types, and classes decorated with @torch.jit.script.
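A short sketch exercising several of these types together (label_scores is a hypothetical helper): an Optional[List[str]] argument refined with an is not None check, a str-keyed Dict return value, and the explicit float() conversion from a 0-dim Tensor that TorchScript expects.

```python
from typing import Dict, List, Optional

import torch


@torch.jit.script
def label_scores(probs: torch.Tensor, names: Optional[List[str]] = None) -> Dict[str, float]:
    # Empty containers need an annotation; otherwise TorchScript assumes Tensor values
    out: Dict[str, float] = {}
    for i in range(probs.size(0)):
        if names is not None:
            # Inside this branch, names is refined from Optional[List[str]] to List[str]
            key = names[i]
        else:
            key = str(i)
        out[key] = float(probs[i])  # 0-dim Tensor -> Python float
    return out


probs_demo = torch.tensor([0.25, 0.75])
print(label_scores(probs_demo, ["cat", "dog"]))  # {'cat': 0.25, 'dog': 0.75}
print(label_scores(probs_demo))                  # {'0': 0.25, '1': 0.75}
```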
Hybrid Tracing and Scripting
For complex models, combine both approaches:
```python
class Encoder(nn.Module):
    """Simple encoder with no data-dependent control flow, so it can be traced."""

    def __init__(self, d_model):
        super().__init__()
        self.layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, 8, batch_first=True),
            num_layers=6,
        )

    def forward(self, x):
        return self.layers(x)


class ConditionalDecoder(nn.Module):
    """Decoder with control flow; must be scripted."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        self.temperature: float = 1.0

    @torch.jit.export
    def set_temperature(self, t: float) -> None:
        self.temperature = t

    def forward(self, x: torch.Tensor, greedy: bool = True) -> torch.Tensor:
        logits = self.proj(x) / self.temperature
        if greedy:
            return logits.argmax(dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)
            # multinomial needs 2-D input: flatten (batch, seq), sample, then restore
            return torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                probs.size(0), probs.size(1)
            )


class FullModel(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        # Trace the encoder (no control flow); eval() disables dropout so the trace is deterministic
        self.encoder = torch.jit.trace(
            Encoder(d_model).eval(),
            torch.randn(1, 16, d_model),
        )
        # Script the decoder (has control flow)
        self.decoder = torch.jit.script(ConditionalDecoder(d_model, vocab_size))

    def forward(self, x: torch.Tensor, greedy: bool = True) -> torch.Tensor:
        encoded = self.encoder(x)
        return self.decoder(encoded, greedy)


# Script the wrapper so the whole pipeline serializes as one artifact
# (d_model=512 and vocab_size=1000 are example values)
full = torch.jit.script(FullModel(d_model=512, vocab_size=1000).eval())
full.save("full_model.pt")
```
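The @torch.jit.export decorator compiles set_temperature and makes it callable on the scripted module. Scripting compiles forward, every method that forward calls, and any method decorated with @torch.jit.export; other methods are not compiled and are not accessible on the scripted module.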
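A minimal sketch of that rule, using a hypothetical Scaler module: the exported setter is compiled and can be called after scripting, while debug_only, which forward never calls and which carries no decorator, is simply left out of the compiled module.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.factor: float = 2.0

    @torch.jit.export
    def set_factor(self, f: float) -> None:
        # Compiled because of @torch.jit.export, even though forward never calls it
        self.factor = f

    def debug_only(self) -> str:
        # Not exported and not reachable from forward, so it is not compiled
        return "eager-only helper"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


scaler = torch.jit.script(Scaler())
scaler.set_factor(3.0)          # callable on the scripted module
print(scaler(torch.ones(2)))    # tensor([3., 3.])
```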
Loading in C++ with LibTorch
The primary production use case — running models in C++ without Python:
```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
    // Load the TorchScript model (torch::jit::script::Module is a legacy alias of torch::jit::Module)
    torch::jit::Module model;
    try {
        model = torch::jit::load("simple_cnn.pt");
    } catch (const c10::Error& e) {
        std::cerr << "Error loading model: " << e.what() << std::endl;
        return 1;
    }
    model.eval();

    // Disable autograd tracking for pure inference
    c10::InferenceMode guard;

    // Create input tensor (N, C, H, W)
    auto input = torch::randn({1, 3, 32, 32});

    // Run inference: forward takes a vector of IValues
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input);
    auto output = model.forward(inputs).toTensor();
    std::cout << "Output shape: " << output.sizes() << std::endl;

    // Get prediction (argmax returns an int64 tensor)
    auto prediction = output.argmax(1).item<int64_t>();
    std::cout << "Predicted class: " << prediction << std::endl;
    return 0;
}
```
Build with CMake, pointing CMAKE_PREFIX_PATH at the unpacked LibTorch directory (for example, cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..):

```cmake
cmake_minimum_required(VERSION 3.18)
project(inference)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(inference main.cpp)
target_link_libraries(inference "${TORCH_LIBRARIES}")
set_property(TARGET inference PROPERTY CXX_STANDARD 17)
```
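A C++ service often needs metadata such as class names alongside the weights. One option, sketched here with a hypothetical meta.json payload, is TorchScript's _extra_files mechanism: torch.jit.save stores extra named files inside the archive, and torch.jit.load fills them into a dict you pass in. LibTorch's torch::jit::load also has an overload that takes an ExtraFilesMap, though it is worth checking the exact signature for your LibTorch version.

```python
import io
import json

import torch
import torch.nn as nn

meta_model = nn.Sequential(nn.Linear(4, 3)).eval()
meta_traced = torch.jit.trace(meta_model, torch.randn(1, 4))

# Hypothetical metadata a C++ service might need next to the weights
meta = {"classes": ["cat", "dog", "bird"], "input_shape": [1, 4]}

buf = io.BytesIO()  # stands in for a file such as "model.pt"
torch.jit.save(meta_traced, buf, _extra_files={"meta.json": json.dumps(meta).encode()})

buf.seek(0)
extra = {"meta.json": ""}  # name the files to read; the values are replaced on load
meta_loaded = torch.jit.load(buf, _extra_files=extra)
restored = json.loads(extra["meta.json"])
print(restored["classes"])  # ['cat', 'dog', 'bird']
```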
TorchScript Optimizations
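TorchScript's graph executor runs optimization passes on the IR automatically when the model executes. You can inspect the graph and, for experimentation, invoke individual passes yourself:

```python
# The graph the tracer recorded; calls into submodules appear as prim::CallMethod
print(traced.graph)

# inlined_graph returns a copy with submodule calls inlined,
# so the passes below do not mutate the module itself
g = traced.inlined_graph

# Run optimization passes explicitly. torch._C._jit_pass_* are internal APIs
# and may change or disappear between releases.
torch._C._jit_pass_constant_propagation(g)
torch._C._jit_pass_fuse_linear(g)
print(g)
```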
Key optimizations:
- Constant folding: Pre-computes operations on constant values
- Dead code elimination: Removes unused computations
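- Operator fusion and folding: Merges patterns into fewer operations, for example folding a BatchNorm into the preceding convolution when the model is frozen, or fusing chains of elementwise ops into a single kernel via the JIT fuser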
- Inlining: Flattens module calls to enable cross-module optimization
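For inference, freeze the model: freezing inlines parameters and attributes into the graph as constants, which enables further optimizations such as Conv+BatchNorm folding. The module must be a ScriptModule in eval mode:

```python
frozen = torch.jit.freeze(traced)
# Parameters are now constants in the graph, enabling more optimization
```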
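A small sketch to confirm what freezing does, using a hypothetical Conv-BatchNorm-ReLU stack similar to ConvBlock above. After torch.jit.freeze the module exposes no parameters (they now live in the graph as constants), and its output still matches eager execution within a small tolerance, since folding BatchNorm into the convolution can change floating-point rounding.

```python
import torch
import torch.nn as nn

conv_net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()  # freeze requires eval mode

conv_scripted = torch.jit.script(conv_net)
frozen_net = torch.jit.freeze(conv_scripted)

# Parameters have been inlined into the graph as constants
print(len(list(frozen_net.named_parameters())))  # 0

conv_x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    # Folding may reorder floating-point math, so compare with a tolerance
    assert torch.allclose(conv_net(conv_x), frozen_net(conv_x), atol=1e-5)

print(frozen_net.code)  # inspect whether batch_norm still appears
```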
Profiling TorchScript vs Python
Measure the actual benefit:
```python
import time

import torch

# Fresh instances keep this comparison self-contained
cnn = SimpleCNN().eval()
x = torch.randn(16, 3, 224, 224)
traced_cnn = torch.jit.trace(cnn, x)


def bench_ms(fn, inp, iters=100, warmup=20):
    """Average milliseconds per call of fn(inp)."""
    with torch.no_grad():  # skip autograd bookkeeping, as in real inference
        # Warmup: gives TorchScript's executor runs to profile and optimize the graph
        for _ in range(warmup):
            fn(inp)
        if inp.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(inp)
        if inp.is_cuda:
            torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
        return (time.perf_counter() - start) / iters * 1000


python_ms = bench_ms(cnn, x)
ts_ms = bench_ms(traced_cnn, x)
print(f"Python: {python_ms:.2f} ms | TorchScript: {ts_ms:.2f} ms | "
      f"Speedup: {python_ms / ts_ms:.2f}x")
```
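Gains vary widely with model size, batch size, and hardware. As a rough guide, expect something like a 10-30% speedup on CPU, where removing Python overhead matters, and 1-5% on GPU, where Python overhead is negligible next to CUDA compute. Small models at small batch sizes benefit most, so measure your own model before relying on these numbers.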
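The hand-rolled loop above works, but torch.utils.benchmark.Timer performs warmup runs, synchronizes CUDA when needed, and picks the iteration count for you. A sketch comparing eager and scripted execution on a small hypothetical MLP, where interpreter overhead is a large share of the runtime:

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# Hypothetical small MLP: tiny models are where Python overhead matters most
mlp = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).eval()
mlp_x = torch.randn(8, 64)
scripted_mlp = torch.jit.script(mlp)

results = []
for name, fn in [("eager", mlp), ("torchscript", scripted_mlp)]:
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): fn(x)",
        globals={"fn": fn, "x": mlp_x, "torch": torch},
        description=name,
    )
    # Runs blocks of iterations until at least min_run_time seconds have elapsed
    m = timer.blocked_autorange(min_run_time=0.5)
    results.append(m)
    print(f"{name}: median {m.median * 1e6:.1f} us per call")
```

The printed medians will differ between runs and machines; the point is a fair comparison, not a guaranteed speedup.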
Migration Path: TorchScript → torch.export
PyTorch 2.1 introduced torch.export as the successor to TorchScript for ahead-of-time graph capture:
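```python
# Modern approach (PyTorch 2.1+), exporting the SimpleCNN from the tracing section
cnn = SimpleCNN().eval()
exported = torch.export.export(cnn, (example,))

# Produces ExportedProgram with full graph capture
print(exported.graph_module.graph)

# Save/load
torch.export.save(exported, "model_exported.pt2")
loaded = torch.export.load("model_exported.pt2")

# .module() returns a runnable nn.Module
out = loaded.module()(example)
```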
Key differences from TorchScript:
- Uses TorchDynamo for graph capture (more complete than tracing); data-dependent control flow raises an error instead of being silently baked in as a constant
- Handles dynamic shapes via symbolic constraints
- Better integration with torch.compile optimizations
- No separate type system: uses standard Python types
For new projects, prefer torch.export when your target deployment supports it. TorchScript remains necessary for C++ LibTorch deployment and mobile until the export path fully replaces it.
The one thing to remember: TorchScript’s real value is Python-free execution — whether in C++ services via LibTorch or on mobile — and while torch.export is the future, TorchScript remains the production-proven path for non-Python deployment today.