PyTorch TorchScript — Deep Dive

Tracing in Detail

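Tracing runs the model once on an example input and records the tensor operations that execute. Only the path actually taken is recorded; Python control flow itself is not: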

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(3, 32),
            nn.MaxPool2d(2),
            ConvBlock(32, 64),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)

model = SimpleCNN()
model.eval()

# Trace with example input
example = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, example)

# Save as standalone artifact
traced.save("simple_cnn.pt")

# Verify output matches
with torch.no_grad():
    orig_out = model(example)
    traced_out = traced(example)
    assert torch.allclose(orig_out, traced_out, atol=1e-6)
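A quick way to confirm that the saved artifact is self-contained is to reload it and compare it against the eager module on a different input shape. The sketch below uses a small hypothetical stand-in for SimpleCNN (`net`) and an in-memory buffer instead of a file. It also passes trace's `check_inputs` argument, which re-runs the trace on extra inputs and compares the results with the eager module.

```python
import io

import torch
import torch.nn as nn

# Hypothetical, scaled-down stand-in for SimpleCNN so the snippet runs on its own
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

example = torch.randn(1, 3, 32, 32)
# check_inputs: extra input tuples the tracer re-runs and compares against eager
traced = torch.jit.trace(net, example, check_inputs=[(torch.randn(4, 3, 64, 64),)])

# Round-trip through an in-memory buffer (a file path works the same way)
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)
reloaded = torch.jit.load(buf)

# Different batch size than the example: shape-agnostic ops generalize
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    same = torch.allclose(net(x), reloaded(x), atol=1e-6)
print(same)
```

This check only covers the shapes you actually try. A trace can still hard-code a size somewhere else in a model, so test the range of shapes you expect in production.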

The tracer emits a TracerWarning when it encounters something it cannot record faithfully. Pay attention to these warnings:

TracerWarning: Converting a tensor to a Python boolean might cause
the trace to be incorrect. We can't record the data flow of Python
values, so this value will be treated as a constant.

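This warning means a Python-level decision (an if, a loop bound, a .item() call) depends on tensor values. The tracer records only the branch taken for the example input and bakes it into the graph as a constant, so other inputs can silently take the wrong path. Script those modules instead.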
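The failure is easy to reproduce. The minimal sketch below uses a hypothetical function, `double_or_shift`, whose branch depends on the sign of its input's sum. The trace is recorded on a positive input, so it keeps the `x * 2` path even for negative inputs. The scripted version keeps both branches. The TracerWarning quoted above is suppressed here only to keep the output short.

```python
import warnings

import torch


def double_or_shift(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the values in x
    if x.sum() > 0:
        return x * 2
    return x + 10


pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", torch.jit.TracerWarning)
    traced = torch.jit.trace(double_or_shift, pos)  # records only the x * 2 path

scripted = torch.jit.script(double_or_shift)  # compiles both branches

print(double_or_shift(neg))  # eager:    tensor([9., 9., 9.])
print(traced(neg))           # traced:   tensor([-2., -2., -2.]), wrong branch
print(scripted(neg))         # scripted: tensor([9., 9., 9.])
```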

Scripting in Detail

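Scripting compiles the Python source of forward (and every method or function it calls) into TorchScript IR, so loops and conditionals survive as real control flow: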

class DynamicRouter(nn.Module):
    """Model with data-dependent control flow — must be scripted."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(num_experts)
        ])
        self.threshold = 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate scores determine which expert processes each sample
        scores = torch.softmax(self.gate(x), dim=-1)
        max_score, expert_idx = scores.max(dim=-1)

        outputs = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                outputs[mask] = expert(x[mask])

        return outputs

model = DynamicRouter(256, 4)
model.eval()

# Script: `if mask.any()` stays a runtime branch; the loop over the
# ModuleList is unrolled at compile time, one copy per expert
scripted = torch.jit.script(model)
scripted.save("dynamic_router.pt")
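Because a ModuleList has a fixed length, the loop above has a known trip count. Scripting matters most for loops whose trip count depends on tensor values, because a trace could only record one fixed count. A minimal sketch with a hypothetical helper, `halve_until_small`:

```python
from typing import Tuple

import torch


@torch.jit.script
def halve_until_small(x: torch.Tensor, limit: float) -> Tuple[torch.Tensor, int]:
    # The iteration count depends on the values in x, so scripting keeps the
    # while loop itself instead of a fixed number of unrolled steps
    steps = 0
    while bool(x.abs().max() > limit):
        x = x / 2
        steps += 1
    return x, steps


print(halve_until_small(torch.tensor([8.0]), 1.0))    # 3 halvings
print(halve_until_small(torch.tensor([100.0]), 1.0))  # 7 halvings
print(halve_until_small.code)  # the compiled TorchScript source
```

Printing `.code` is a cheap way to check what the compiler actually kept before you ship a scripted artifact.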

Type Annotations for TorchScript

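TorchScript is statically typed. Unannotated arguments are assumed to be Tensor, so every other argument type (int, list, Optional, and so on) needs an annotation. Common patterns: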

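from typing import List, Optional, Tuple
import torch.nn.functional as F

@torch.jit.script
def process_outputs(
    logits: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    top_k: int = 5,
) -> Tuple[torch.Tensor, List[int], Optional[torch.Tensor]]:
    probs = torch.softmax(logits, dim=-1)
    _, indices = probs.topk(top_k, dim=-1)

    # Build list of top predictions for first sample.
    # .item() returns a generic number in TorchScript, so convert it to int explicitly
    top_classes: List[int] = []
    for i in range(top_k):
        top_classes.append(int(indices[0, i].item()))

    # Inside the `is not None` branch, labels is refined from Optional[Tensor] to Tensor
    loss: Optional[torch.Tensor] = None
    if labels is not None:
        loss = F.cross_entropy(logits, labels)
    return probs, top_classes, loss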

Commonly used types that TorchScript supports: int, float, bool, str, Tensor, List[T], Dict[K, V], Tuple[T, ...], Optional[T], and classes decorated with @torch.jit.script. Containers are homogeneous: a List[int] cannot hold a float.
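Dict and Optional follow the same rules. The minimal sketch below uses a hypothetical `weighted_counts` function. It shows the annotation an empty container needs and how an `is not None` check refines an Optional argument to its inner type.

```python
from typing import Dict, List, Optional

import torch


@torch.jit.script
def weighted_counts(
    labels: List[str],
    weights: Optional[Dict[str, int]] = None,
) -> Dict[str, int]:
    # Empty containers need an annotation: TorchScript cannot infer their element type
    counts: Dict[str, int] = {}
    for label in labels:
        w = 1
        # `is not None` refines Optional[Dict[str, int]] to Dict[str, int]
        if weights is not None:
            if label in weights:
                w = weights[label]
        counts[label] = counts.get(label, 0) + w
    return counts


print(weighted_counts(["cat", "dog", "cat"]))
print(weighted_counts(["cat", "dog"], {"dog": 5}))
```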

Hybrid Tracing and Scripting

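Traced and scripted modules can be nested inside each other. For complex models, trace the branch-free parts and script the parts with control flow: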

class Encoder(nn.Module):
    """Simple encoder — no control flow, can be traced."""
    def __init__(self, d_model):
        super().__init__()
        self.layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, 8, batch_first=True),
            num_layers=6,
        )

    def forward(self, x):
        return self.layers(x)

class ConditionalDecoder(nn.Module):
    """Decoder with control flow — must be scripted."""
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        self.temperature: float = 1.0

    @torch.jit.export
    def set_temperature(self, t: float):
        self.temperature = t

    def forward(self, x: torch.Tensor, greedy: bool = True) -> torch.Tensor:
        logits = self.proj(x) / self.temperature
        if greedy:
            return logits.argmax(dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)
            return torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                probs.size(0), probs.size(1)
            )

class FullModel(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        # Trace the encoder (no control flow). Call .eval() first: tracing in
        # training mode would bake active dropout into the graph
        self.encoder = torch.jit.trace(
            Encoder(d_model).eval(),
            torch.randn(1, 16, d_model),
        )
        # Script the decoder (has control flow)
        self.decoder = torch.jit.script(ConditionalDecoder(d_model, vocab_size))

    def forward(self, x: torch.Tensor, greedy: bool = True) -> torch.Tensor:
        encoded = self.encoder(x)
        return self.decoder(encoded, greedy)

The @torch.jit.export decorator makes set_temperature callable on the scripted module. TorchScript compiles forward, any method marked @torch.jit.export, and whatever those call. Methods outside that set are not compiled and are not available after scripting.
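To check that the hybrid pattern survives serialization, the sketch below nests a traced submodule inside a scripted parent. It then saves and reloads the parent through an in-memory buffer and calls the exported method on the reloaded copy. `Head`, `Wrapper`, and `set_scale` are hypothetical, scaled-down stand-ins for Encoder, ConditionalDecoder, and set_temperature.

```python
import io

import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical branch-free projection: safe to trace."""

    def __init__(self, d: int, n: int):
        super().__init__()
        self.proj = nn.Linear(d, n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class Wrapper(nn.Module):
    """Hypothetical parent with real control flow: must be scripted."""

    def __init__(self, d: int, n: int):
        super().__init__()
        # Traced child: its recorded graph is embedded in the scripted parent
        self.head = torch.jit.trace(Head(d, n).eval(), torch.randn(2, d))
        self.scale: float = 1.0

    @torch.jit.export
    def set_scale(self, s: float) -> None:
        self.scale = s

    def forward(self, x: torch.Tensor, greedy: bool = True) -> torch.Tensor:
        logits = self.head(x) * self.scale
        if greedy:
            return logits.argmax(dim=-1)
        return torch.softmax(logits, dim=-1)


scripted = torch.jit.script(Wrapper(8, 4).eval())

buf = io.BytesIO()
torch.jit.save(scripted, buf)
buf.seek(0)
loaded = torch.jit.load(buf)

loaded.set_scale(2.0)         # the exported method survives save/load
x = torch.randn(3, 8)
greedy_out = loaded(x)        # class indices, shape (3,)
soft_out = loaded(x, False)   # probabilities, each row sums to 1
```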

Loading in C++ with LibTorch

The primary production use case — running models in C++ without Python:

#include <torch/script.h>
#include <torch/torch.h>  // for torch::NoGradGuard
#include <iostream>
#include <vector>

int main() {
    // Load the TorchScript model
    torch::jit::script::Module model;
    try {
        model = torch::jit::load("simple_cnn.pt");
    } catch (const c10::Error& e) {
        std::cerr << "Error loading model: " << e.what() << std::endl;
        return 1;
    }

    model.eval();
    // Inference only: skip autograd bookkeeping
    torch::NoGradGuard no_grad;

    // Create input tensor
    auto input = torch::randn({1, 3, 32, 32});

    // Run inference
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input);

    auto output = model.forward(inputs).toTensor();
    std::cout << "Output shape: " << output.sizes() << std::endl;

    // argmax returns an int64 tensor, so read it back as int64_t
    auto prediction = output.argmax(1).item<int64_t>();
    std::cout << "Predicted class: " << prediction << std::endl;

    return 0;
}

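Build with CMake. When configuring, point CMAKE_PREFIX_PATH at the unzipped LibTorch distribution (for example cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..):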

cmake_minimum_required(VERSION 3.18)
project(inference)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(inference main.cpp)
target_link_libraries(inference "${TORCH_LIBRARIES}")
set_property(TARGET inference PROPERTY CXX_STANDARD 17)

TorchScript Optimizations

The TorchScript compiler applies several optimizations:

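# View the graph the tracer recorded (submodule calls are not yet inlined)
print(traced.graph)

# Run individual passes by hand to see what each one does. These torch._C
# functions are private internals that can change between releases and may
# modify the module's graph in place; use them for inspection only.
torch._C._jit_pass_inline(traced.graph)
torch._C._jit_pass_constant_propagation(traced.graph)
torch._C._jit_pass_fuse_linear(traced.graph)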

Key optimizations:

  • Constant folding: Pre-computes operations on constant values
  • Dead code elimination: Removes unused computations
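  • Operator fusion: The JIT fuser can merge chains of pointwise ops (add, mul, activations) into a single kernel, mainly on GPU; Conv+BatchNorm folding happens when you freeze the model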
  • Inlining: Flattens module calls to enable cross-module optimization
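To see inlining at work, compare a traced module's top-level graph with its `inlined_graph`, which is a copy with every submodule call expanded. A minimal sketch on a hypothetical two-layer `nn.Sequential`:

```python
import torch
import torch.nn as nn

# Hypothetical nested module so there is a submodule call to inline
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()
traced = torch.jit.trace(net, torch.randn(1, 4))

top_kinds = [n.kind() for n in traced.graph.nodes()]
inlined_kinds = [n.kind() for n in traced.inlined_graph.nodes()]

print(top_kinds)      # submodule calls appear as prim::CallMethod
print(inlined_kinds)  # after inlining, the aten:: ops themselves are visible
```

Passes such as constant propagation and fusion can only look across module boundaries once those boundaries have been inlined away.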

For inference, freeze the model to fold parameters into constants:

frozen = torch.jit.freeze(traced)
# Parameters are now constants in the graph — enables more optimization
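A minimal sketch of freezing a hypothetical Conv+BN+ReLU stack and checking that the frozen module still matches eager output. torch.jit.freeze only accepts modules in eval mode. Folding BatchNorm into the conv weights changes rounding slightly, so the comparison uses a tolerance.

```python
import torch
import torch.nn as nn

# Hypothetical Conv+BN+ReLU stack; freeze requires eval mode
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()
x = torch.randn(2, 3, 16, 16)

traced = torch.jit.trace(net, x)
frozen = torch.jit.freeze(traced)

with torch.no_grad():
    ref = net(x)
    out = frozen(x)

print(torch.allclose(ref, out, atol=1e-5))
# Inspect which ops remain after freezing and its follow-up optimizations
print(sorted({n.kind() for n in frozen.graph.nodes()}))
```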

Profiling TorchScript vs Python

Measure the actual benefit:

import time
import torch

# `model` was reassigned to DynamicRouter above, so rebuild the CNN and its trace
model = SimpleCNN().eval()
x = torch.randn(16, 3, 224, 224)
traced = torch.jit.trace(model, x)

def bench_ms(fn, x, iters=100, warmup=20):
    """Average latency of fn(x) in milliseconds."""
    with torch.no_grad():
        # Warmup: TorchScript's profiling executor optimizes only after the first calls
        for _ in range(warmup):
            fn(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # CUDA launches are async; drain before timing
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if x.is_cuda:
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000

python_ms = bench_ms(model, x)
ts_ms = bench_ms(traced, x)

print(f"Python: {python_ms:.2f} ms | TorchScript: {ts_ms:.2f} ms | "
      f"Speedup: {python_ms/ts_ms:.2f}x")

Results vary with the model and hardware. As a rough guide, expect a 10-30% speedup on CPU, where removing Python overhead matters most, and 1-5% on GPU, where that overhead is small next to CUDA compute.
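Hand-rolled timing loops like the one above are sensitive to noise and thread settings. torch.utils.benchmark.Timer performs warmups, synchronizes CUDA when needed, and collects repeated measurements. Note that it runs with a single thread unless you pass num_threads. A sketch on a hypothetical small MLP:

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

# Hypothetical small MLP; substitute your own model and input
net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).eval()
x = torch.randn(32, 64)
traced = torch.jit.trace(net, x)

results = {}
for label, fn in [("eager", net), ("torchscript", traced)]:
    with torch.no_grad():
        for _ in range(20):  # let the TorchScript profiling executor specialize
            fn(x)
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): fn(x)",
        globals={"fn": fn, "x": x, "torch": torch},
    )
    # blocked_autorange picks the iteration count and gathers statistics
    measurement = timer.blocked_autorange(min_run_time=0.2)
    results[label] = measurement.median * 1e3  # seconds -> milliseconds
    print(f"{label}: {results[label]:.3f} ms (median)")
```

The median is less affected by occasional scheduler hiccups than the mean used in the loop above.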

Migration Path: TorchScript → torch.export

PyTorch 2.1 introduced torch.export as the successor to TorchScript for ahead-of-time graph capture:

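# Modern approach (PyTorch 2.1+). `model` now refers to DynamicRouter, whose
# data-dependent `if mask.any()` export cannot capture, so export the CNN instead
cnn = SimpleCNN().eval()
example_input = torch.randn(1, 3, 32, 32)
exported = torch.export.export(cnn, (example_input,))

# Produces ExportedProgram with full graph capture
print(exported.graph_module.graph)

# Save/load
torch.export.save(exported, "model_exported.pt2")
loaded = torch.export.load("model_exported.pt2")
out = loaded.module()(example_input)  # .module() (recent 2.x) wraps the graph as an nn.Module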

Key differences from TorchScript:

  • Uses TorchDynamo for graph capture (more complete than tracing)
  • Handles dynamic shapes via symbolic constraints
  • Better integration with torch.compile optimizations
  • No separate type system — uses standard Python types

For new projects, prefer torch.export when your deployment target supports it. TorchScript remains the established path for LibTorch C++ services and PyTorch Mobile. The export-based replacements (AOTInductor for C++, ExecuTorch for mobile) are newer.

The one thing to remember: TorchScript’s real value is Python-free execution — whether in C++ services via LibTorch or on mobile — and while torch.export is the future, TorchScript remains the production-proven path for non-Python deployment today.

