NumPy Einsum — Deep Dive

Technical foundation

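np.einsum implements the Einstein summation convention. In implicit mode (no -> in the subscripts), any index that appears more than once is summed and the remaining indices form the output in alphabetical order. In explicit mode, every index that does not appear after the -> is summed. Under the hood, np.einsum parses the subscript string into a contraction plan and optionally optimizes the contraction order (the optimize argument defaults to False). It then evaluates each step either with its own C loop or, for pairwise steps that reduce to a matrix product, with np.tensordot, which calls BLAS.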
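A small sketch of the two modes follows. The arrays are arbitrary examples, and the values in the comments follow from the arange contents:

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
S = np.arange(9.0).reshape(3, 3)

# Implicit mode: 'j' appears twice, so it is summed; output labels 'ik' are sorted
implicit = np.einsum('ij,jk', A, B)        # same as A @ B

# Explicit mode: the same contraction with the output spelled out
explicit = np.einsum('ij,jk->ik', A, B)    # same as A @ B

# Implicit mode sorts the output labels, so 'ji' returns the transpose
transposed = np.einsum('ji', A)            # same as A.T

# Explicit mode decides what a repeated index means
trace = np.einsum('ii->', S)               # sum of the diagonal: 0 + 4 + 8
diagonal = np.einsum('ii->i', S)           # the diagonal itself: [0, 4, 8]
```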

Contraction path optimization

For multi-operand expressions, the contraction order determines both memory usage and compute cost. Consider three matrices A(100×200), B(200×300), C(300×50):

import numpy as np

A = np.random.randn(100, 200)
B = np.random.randn(200, 300)
C = np.random.randn(300, 50)

# Order 1: (A @ B) @ C → 100*200*300 + 100*300*50 = 7.5M multiply-adds; intermediate 100×300 = 30K elements
# Order 2: A @ (B @ C) → 200*300*50 + 100*200*50 = 4.0M multiply-adds; intermediate 200×50 = 10K elements

# Let NumPy find the best path
path, info = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')
print(path)
# ['einsum_path', (1, 2), (0, 1)]  → contracts B@C first, then A@result

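np.einsum_path returns two things. The first is the contraction sequence chosen by the requested strategy ('greedy' or 'optimal'). The second is a printable report listing the naive and optimized FLOP counts, the theoretical speedup, and the largest intermediate. For these three operands, the better order saves roughly half the work (4.0M vs. 7.5M multiply-adds). The worst choice is to contract A with C first, since they share no index and would form a 100×200×300×50 outer product. With five or more operands, the gap between good and bad orders can grow by orders of magnitude.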
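np.einsum_path returns the path in the same form that np.einsum's optimize argument accepts. A path computed once can therefore be passed back in, which helps when the same expression runs many times on same-shaped inputs. A sketch with the shapes from the example above:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 200))
B = rng.standard_normal((200, 300))
C = rng.standard_normal((300, 50))

# Plan once; the path is a list starting with the marker 'einsum_path'
path, info = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')
print(info)  # naive vs. optimized FLOP counts, speedup, largest intermediate

# Reuse the precomputed path so einsum skips the search on later calls
result = np.einsum('ij,jk,kl->il', A, B, C, optimize=path)
```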

BLAS dispatch

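When optimize is enabled, np.einsum splits the expression into pairwise contractions. Those that reduce to a plain dot or matrix product go to np.tensordot, which calls BLAS for floating-point and complex arrays. Typical mappings (routine names for float64):

Pattern          Closest BLAS routine
'i,i->'          ddot (dot product)
'ij,jk->ik'      dgemm (matrix multiply)
'ij,j->i'        dgemv (matrix-vector)

Batched patterns such as '...ij,...jk->...ik' behave differently. An index that is shared by both operands and kept in the output cannot be expressed as a tensordot, so einsum evaluates that step with its C loop. np.matmul is the more direct tool for stacks of matrices.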

With optimize=True and a matching pattern, einsum hands the work to BLAS. It then runs at roughly np.matmul speed, plus a small planning overhead that is noticeable only for tiny arrays. Without optimization (the default), it uses its generic C loop, which is typically much slower than BLAS for large matrix products.

# optimize=True lets einsum route this contraction through tensordot (BLAS)
result = np.einsum('ij,jk->ik', A, B, optimize=True)

# Compare against np.matmul (%timeit is IPython/Jupyter magic)
%timeit A @ B
%timeit np.einsum('ij,jk->ik', A, B, optimize=True)   # close to matmul
%timeit np.einsum('ij,jk->ik', A, B, optimize=False)  # generic C loop, usually much slower
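%timeit only works in IPython or Jupyter. In a plain script, the standard-library timeit module gives a comparable measurement. The following is a minimal sketch using the same shapes; absolute numbers and ratios depend on the machine, the BLAS build, and the NumPy version:

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 200))
B = rng.standard_normal((200, 300))

candidates = {
    'matmul': lambda: A @ B,
    'einsum optimize=True': lambda: np.einsum('ij,jk->ik', A, B, optimize=True),
    'einsum optimize=False': lambda: np.einsum('ij,jk->ik', A, B, optimize=False),
}

timings = {}
for name, fn in candidates.items():
    # Best of 3 repeats of 10 calls each, reported as seconds per call
    timings[name] = min(timeit.repeat(fn, number=10, repeat=3)) / 10

for name, t in timings.items():
    print(f'{name:>24}: {t * 1e6:10.1f} us per call')
```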

Advanced patterns

Batch operations

Einsum handles batch dimensions naturally:

# Batch of 64 matrix multiplications, each 32x32 @ 32x16
A = np.random.randn(64, 32, 32)
B = np.random.randn(64, 32, 16)
result = np.einsum('bij,bjk->bik', A, B, optimize=True)  # (64, 32, 16)
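np.matmul (the @ operator) also treats leading axes as a batch, so it makes a convenient cross-check. For a plain batched matrix multiply it is also the more direct spelling and worth timing against einsum. A sketch using random inputs of the same shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32, 32))
B = rng.standard_normal((64, 32, 16))

via_einsum = np.einsum('bij,bjk->bik', A, B, optimize=True)
# np.matmul treats all leading dimensions as a stack of matrices
via_matmul = np.matmul(A, B)   # same as A @ B

print(via_einsum.shape)
print(np.allclose(via_einsum, via_matmul))
```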

Tensor contractions

In physics and machine learning, tensors with 4+ indices are common:

# Attention mechanism: Q(batch, heads, seq, d) @ K^T(batch, heads, d, seq)
Q = np.random.randn(8, 12, 128, 64)
K = np.random.randn(8, 12, 128, 64)
attn_scores = np.einsum('bhsd,bhtd->bhst', Q, K, optimize=True)
# (8, 12, 128, 128) — attention weight matrix per head per batch
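In attention, the scores are followed by a softmax over the key index t and a second contraction against a value tensor V. The sketch below assumes the usual scaled dot-product form softmax(Q K^T / sqrt(d)) V. V and the small sizes are illustrative choices, not taken from the example above:

```python
import numpy as np

rng = np.random.default_rng(0)
b, h, s, d = 2, 4, 16, 8   # small illustrative sizes
Q = rng.standard_normal((b, h, s, d))
K = rng.standard_normal((b, h, s, d))
V = rng.standard_normal((b, h, s, d))

# Scores: contract over the feature index d, scaled by sqrt(d)
scores = np.einsum('bhsd,bhtd->bhst', Q, K, optimize=True) / np.sqrt(d)

# Numerically stable softmax over the key index t (last axis)
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)

# Weighted sum of values: contract over t
out = np.einsum('bhst,bhtd->bhsd', weights, V, optimize=True)   # (b, h, s, d)
```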

Bilinear forms

# x^T M y for batches of vectors and a shared matrix
x = np.random.randn(1000, 10)    # (batch, n)
M = np.random.randn(10, 10)      # (n, n)
y = np.random.randn(1000, 10)    # (batch, n)

result = np.einsum('bi,ij,bj->b', x, M, y, optimize=True)  # (1000,)

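Without einsum, the same result needs ((x @ M) * y).sum(axis=1). That expression allocates two (1000, 10) temporaries: x @ M and the elementwise product. With optimize=True, einsum may itself materialize a (1000, 10) intermediate so that it can use BLAS; np.einsum_path shows the steps it picks.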
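A quick check that the two spellings agree follows, using the same made-up shapes as above. It also prints the pairwise steps that einsum_path picks for this expression:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 10))
M = rng.standard_normal((10, 10))
y = rng.standard_normal((1000, 10))

fused = np.einsum('bi,ij,bj->b', x, M, y, optimize=True)
# Plain NumPy: x @ M is (1000, 10), then an elementwise product and a row sum
plain = ((x @ M) * y).sum(axis=1)

print(np.allclose(fused, plain))

# Inspect which pairwise contractions the optimizer chose
path, info = np.einsum_path('bi,ij,bj->b', x, M, y, optimize='optimal')
print(path)
```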

Kronecker product

A = np.random.randn(3, 3)
B = np.random.randn(4, 4)
kron = np.einsum('ij,kl->ikjl', A, B).reshape(12, 12)
# Equivalent to np.kron(A, B) but shows the index structure
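The same index pattern extends to any pair of 2-D arrays by reshaping to (m*p, n*q). The sketch below uses a hypothetical helper name, kron2d, and checks it against np.kron:

```python
import numpy as np


def kron2d(A, B):
    """Hypothetical helper: Kronecker product of two 2-D arrays via einsum.

    out[i*p + k, j*q + l] = A[i, j] * B[k, l], where B has shape (p, q).
    """
    m, n = A.shape
    p, q = B.shape
    # 'ikjl' groups row indices (i, k) and column indices (j, l) before reshaping
    return np.einsum('ij,kl->ikjl', A, B).reshape(m * p, n * q)


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
print(np.allclose(kron2d(A, B), np.kron(A, B)))
```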

Memory and performance tradeoffs

Avoiding temporaries

A major performance advantage of einsum over chained NumPy calls is that it can skip full-size intermediate arrays:

# Materializes A ** 2, a temporary the same size as A, before reducing
norms = np.sqrt((A ** 2).sum(axis=1))

# einsum multiplies and sums each row in one pass, with no full-size temporary
norms = np.sqrt(np.einsum('ij,ij->i', A, A))

For large arrays, the memory savings can matter more than the arithmetic. Fewer large allocations mean less allocator work, lower peak memory, and less memory bandwidth spent writing and re-reading temporaries.
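Peak memory can be measured directly. NumPy reports its array buffers to Python's tracemalloc module (since NumPy 1.13), so the full-size temporary in the chained version shows up in the traced peak. The sketch below assumes Python 3.9+ for tracemalloc.reset_peak and uses an arbitrary 2000×500 input:

```python
import tracemalloc

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2000, 500))   # 8 MB of float64, allocated before tracing

tracemalloc.start()

norms_chained = np.sqrt((A ** 2).sum(axis=1))   # A ** 2 is a full (2000, 500) temporary
_, peak_chained = tracemalloc.get_traced_memory()

tracemalloc.reset_peak()                         # Python 3.9+
norms_einsum = np.sqrt(np.einsum('ij,ij->i', A, A))
_, peak_einsum = tracemalloc.get_traced_memory()

tracemalloc.stop()

norms_linalg = np.linalg.norm(A, axis=1)   # library routine, for reference

print(f'peak while chaining: {peak_chained / 1e6:.2f} MB')
print(f'peak with einsum:    {peak_einsum / 1e6:.2f} MB')
```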

When einsum is slower

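Einsum's generic C loop is not a tuned BLAS kernel. For simple elementwise patterns, which implementation wins depends on array size, dtype, memory layout, and NumPy version, and a plain NumPy expression can be competitive or faster: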

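A = np.random.randn(1000, 1000)
B = np.random.randn(1000, 1000)

# Elementwise multiply and full sum: einsum's C loop, no (1000, 1000) temporary
result1 = np.einsum('ij,ij->', A, B)

# Plain NumPy: builds A * B, then reduces it with vectorized ufunc loops
result2 = (A * B).sum()

# Either can win: einsum skips the temporary, while np.multiply and np.sum
# use highly tuned (SIMD) inner loops. Measure on your own shapes.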

Profile before assuming einsum is faster.
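One way to follow that advice is to time the variants on your own shapes. The sketch below also includes optimize=True, under which einsum may send 'ij,ij->' through np.tensordot (a BLAS dot product) instead of its C loop. The array size is an arbitrary assumption:

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 1000))
B = rng.standard_normal((1000, 1000))

variants = {
    "einsum, default C loop": lambda: np.einsum('ij,ij->', A, B),
    "einsum, optimize=True": lambda: np.einsum('ij,ij->', A, B, optimize=True),
    "(A * B).sum()": lambda: (A * B).sum(),
}

# All three should agree up to floating-point rounding
results = {name: fn() for name, fn in variants.items()}
values = list(results.values())
print("agree:", all(np.isclose(v, values[0]) for v in values))

for name, fn in variants.items():
    best = min(timeit.repeat(fn, number=10, repeat=3)) / 10
    print(f"{name:>24}: {best * 1e3:7.3f} ms per call")
```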

The opt_einsum library

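For large tensor contractions, the opt_einsum library offers more path optimizers than NumPy's built-in ones. Examples include dynamic programming ('dp'), random-greedy, and branch-and-bound variants. It also lets you precompile an expression for fixed shapes: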

import numpy as np
import opt_einsum as oe

# Precompute the contraction order for four (4, 4, 4, 4) operands
expr = oe.contract_expression(
    'pqrs,tuqv,wurt,xyzs->pxwz',
    (4, 4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4),
    optimize='dp',  # dynamic programming optimizer
)

# Reuse the optimized expression on any operands with those shapes
A, B, C, D = (np.random.randn(4, 4, 4, 4) for _ in range(4))
result = expr(A, B, C, D)   # shape (4, 4, 4, 4)

opt_einsum also integrates with PyTorch and TensorFlow, using the same subscript notation across frameworks.

Ellipsis notation

For variable-rank batch dimensions, use ...:

# Works regardless of how many batch dimensions exist
def batched_trace(A):
    return np.einsum('...ii->...', A)

print(batched_trace(np.eye(3)).shape)                      # ()
print(batched_trace(np.random.randn(5, 3, 3)).shape)       # (5,)
print(batched_trace(np.random.randn(2, 5, 3, 3)).shape)    # (2, 5)
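The sketch below cross-checks against np.trace, which accepts the two matrix axes explicitly. It also shows the ellipsis broadcasting a batch of matrices against a single matrix. P and Wm are arbitrary example arrays:

```python
import numpy as np


def batched_trace(A):
    return np.einsum('...ii->...', A)


rng = np.random.default_rng(0)
stack = rng.standard_normal((2, 5, 3, 3))

# np.trace over the last two axes gives the same per-matrix traces
reference = np.trace(stack, axis1=-2, axis2=-1)
print(np.allclose(batched_trace(stack), reference))

# The ellipsis also broadcasts: a (5, 3, 4) stack times one (4, 2) matrix -> (5, 3, 2)
P = rng.standard_normal((5, 3, 4))
Wm = rng.standard_normal((4, 2))
print(np.einsum('...ij,jk->...ik', P, Wm).shape)
```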

Debugging einsum expressions

When a subscript string gives unexpected results, start by checking the output shape it produces for given input shapes:

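def einsum_shape(subscripts, *shapes):
    """Preview the output shape using zero-stride dummy inputs.

    np.broadcast_to makes inputs that take no memory, but einsum still runs
    the contraction and allocates the output, so keep the shapes modest.
    """
    dummy = [np.broadcast_to(0.0, s) for s in shapes]
    return np.einsum(subscripts, *dummy).shape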

print(einsum_shape('ij,jk->ik', (3, 4), (4, 5)))  # (3, 5)
print(einsum_shape('bij,bjk->bik', (2, 3, 4), (2, 4, 5)))  # (2, 3, 5)
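A shape preview can also skip NumPy entirely by reading sizes off the subscripts. The sketch below uses a hypothetical helper, einsum_output_shape, that only supports explicit mode without ellipsis and treats label sizes as exact. NumPy's own parser handles more cases:

```python
def einsum_output_shape(subscripts, *shapes):
    """Hypothetical helper: output shape of an explicit-mode einsum, no arrays.

    Assumptions: '->' is present, labels are single letters, there is no '...'
    ellipsis, and sizes must match exactly (no size-1 broadcasting).
    """
    subscripts = subscripts.replace(' ', '')
    if '->' not in subscripts:
        raise ValueError("explicit mode only: subscripts must contain '->'")
    inputs, output = subscripts.split('->')
    terms = inputs.split(',')
    if len(terms) != len(shapes):
        raise ValueError(f'{len(terms)} operands in subscripts, {len(shapes)} shapes given')

    sizes = {}
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f'{term!r} has {len(term)} labels but shape {shape}')
        for label, dim in zip(term, shape):
            # The first occurrence records the size; later ones must agree
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f'label {label!r}: size {sizes[label]} vs {dim}')

    missing = set(output) - set(sizes)
    if missing:
        raise ValueError(f'output labels {sorted(missing)} not found in inputs')
    return tuple(sizes[label] for label in output)


print(einsum_output_shape('ij,jk->ik', (3, 4), (4, 5)))
print(einsum_output_shape('bij,bjk->bik', (2, 3, 4), (2, 4, 5)))
```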

Also use np.einsum_path to understand what einsum is actually doing:

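A = np.random.randn(10, 20, 30)
B = np.random.randn(20, 30, 40)
C = np.random.randn(40, 50)
_, info = np.einsum_path('ijk,jkl,lm->im', A, B, C, optimize='optimal')
print(info)
# Shows: naive vs. optimized FLOP counts, theoretical speedup, largest
# intermediate, then each pairwise step with its scaling and remaining terms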

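The one thing to remember: Einsum is most valuable for writing multi-operand tensor contractions as one readable expression, and it can avoid intermediate arrays in simple fused reductions. Pass optimize=True (or a precomputed path) so that eligible steps reach BLAS. Use np.einsum_path to check the contraction order, and profile against plain NumPy for simple patterns.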

