NumPy Einsum — Deep Dive
Technical foundation
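np.einsum implements the Einstein summation convention: an index that appears in the inputs but not in the output is summed over (in implicit mode, without '->', every repeated index is summed). Under the hood, it parses the subscript string into a contraction plan, optionally optimizes the contraction order, and runs each step either in its own C loop or, when optimization is enabled and the step fits a BLAS pattern, through np.tensordot and BLAS.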
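To make the summation rule concrete, here is a minimal sketch that spells out 'ij,jk->ik' as explicit Python loops and checks it against np.einsum in both explicit and implicit mode. The helper matmul_by_loops is hypothetical and exists only for illustration; it is far too slow for real use.

```python
import numpy as np

def matmul_by_loops(A, B):
    """Hypothetical helper: 'ij,jk->ik' as loops. j is in both inputs but not the output, so it is summed."""
    n_i, n_j = A.shape
    n_j_b, n_k = B.shape
    if n_j != n_j_b:
        raise ValueError("index j must have the same length in both operands")
    out = np.zeros((n_i, n_k))
    for i in range(n_i):          # free index: kept in the output
        for k in range(n_k):      # free index: kept in the output
            for j in range(n_j):  # repeated index missing from the output: summed
                out[i, k] += A[i, j] * B[j, k]
    return out

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))

explicit = np.einsum('ij,jk->ik', A, B)
# Implicit mode (no '->'): the repeated j is summed and the remaining
# indices form the output in alphabetical order, here 'ik'
implicit = np.einsum('ij,jk', A, B)

print(np.allclose(matmul_by_loops(A, B), explicit))  # True
print(np.allclose(explicit, implicit))               # True
```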
Contraction path optimization
For multi-operand expressions, the contraction order determines both memory usage and compute cost. Consider three matrices A(100×200), B(200×300), C(300×50):
```python
import numpy as np

A = np.random.randn(100, 200)
B = np.random.randn(200, 300)
C = np.random.randn(300, 50)

# Order 1: (A @ B) @ C → 100×300 intermediate (30K elements);
#   100·200·300 + 100·300·50 = 7.5M multiply-adds
# Order 2: A @ (B @ C) → 200×50 intermediate (10K elements);
#   200·300·50 + 100·200·50 = 4.0M multiply-adds

# Let NumPy find the best path
path, info = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')
print(path)
# ['einsum_path', (1, 2), (0, 1)] → contracts B and C first, then A with the result
```
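np.einsum_path returns the contraction sequence it chose and a printable report with the naive and optimized FLOP counts. For three operands the difference is modest; here the better order needs about half the multiply-adds (4.0M vs 7.5M). For tensor networks with five or more operands, a poor order can be slower by orders of magnitude, because intermediates grow multiplicatively with the number of indices they carry.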
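When the same expression runs repeatedly on arrays of the same shapes, the path can be computed once and passed back through the optimize argument, which np.einsum accepts as an explicit contraction list. A sketch, assuming new values for C arrive on each iteration (the loop is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 200))
B = rng.standard_normal((200, 300))
C = rng.standard_normal((300, 50))

# Plan once: the path depends only on the subscripts and the operand shapes
path, info = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')

for _ in range(3):
    C_new = rng.standard_normal((300, 50))  # same shape, so the same path stays valid
    # Passing the path as `optimize` skips the path search on every call
    out = np.einsum('ij,jk,kl->il', A, B, C_new, optimize=path)
    assert np.allclose(out, A @ B @ C_new)

print(path)  # ['einsum_path', (1, 2), (0, 1)]
```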
BLAS dispatch
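When optimization is enabled, einsum checks each pairwise contraction and, if it maps onto a BLAS operation, routes it through np.tensordot (and from there to BLAS):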
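| Pattern | BLAS equivalent (float64) |
|---|---|
| 'i,i->' | ddot (dot product) |
| 'ij,jk->ik' | dgemm (matrix multiply) |
| 'ij,j->i' | dgemv (matrix-vector) |
| '...ij,...jk->...ik' | None: the batch axes are shared by both inputs and the output, so einsum keeps this in its own loop (np.matmul is the BLAS-backed option) |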
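When optimization is enabled (np.einsum defaults to optimize=False) and a contraction matches one of these patterns, einsum hands it to BLAS and runs at roughly the speed of np.matmul, plus a little Python overhead for planning. Without optimization, it falls back to its own C loop, which is much slower for large matrix products.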
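```python
# Let einsum route 'ij,jk->ik' to BLAS (via np.tensordot)
result = np.einsum('ij,jk->ik', A, B, optimize=True)
assert np.allclose(result, A @ B)
# %timeit is an IPython/Jupyter magic; use the timeit module in plain Python
%timeit A @ B
%timeit np.einsum('ij,jk->ik', A, B, optimize=True)   # close to A @ B
%timeit np.einsum('ij,jk->ik', A, B, optimize=False)  # generic C loop, typically much slower
```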
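%timeit only works in IPython or Jupyter. In a plain Python script, a rough equivalent uses the standard timeit module. This sketch uses 200×200 matrices and arbitrary repeat counts; the absolute timings, and the size of the gap, depend on your machine and BLAS build:

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 200))
B = rng.standard_normal((200, 200))

candidates = {
    "A @ B": lambda: A @ B,
    "einsum, optimize=True": lambda: np.einsum('ij,jk->ik', A, B, optimize=True),
    "einsum, optimize=False": lambda: np.einsum('ij,jk->ik', A, B, optimize=False),
}

# Check that all three agree before timing them
reference = A @ B
for fn in candidates.values():
    assert np.allclose(fn(), reference)

# Best of 5 repeats of 10 calls each
timings = {name: min(timeit.repeat(fn, number=10, repeat=5)) for name, fn in candidates.items()}
for name, seconds in timings.items():
    print(f"{name:>24}: {seconds * 1e3:8.2f} ms per 10 calls")
```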
Advanced patterns
Batch operations
Einsum handles batch dimensions naturally:
```python
# Batch of 64 matrix multiplications, each 32x32 @ 32x16
A = np.random.randn(64, 32, 32)
B = np.random.randn(64, 32, 16)
result = np.einsum('bij,bjk->bik', A, B, optimize=True)  # (64, 32, 16)
```
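A quick sanity check compares the batched einsum against np.matmul, which broadcasts over leading axes. The second part assumes one weight matrix shared across the whole batch, so its subscript simply drops b from that operand:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32, 32))
B = rng.standard_normal((64, 32, 16))

# Same batched product two ways; np.matmul broadcasts over the leading batch axis
via_einsum = np.einsum('bij,bjk->bik', A, B, optimize=True)
via_matmul = np.matmul(A, B)
print(via_einsum.shape, np.allclose(via_einsum, via_matmul))  # (64, 32, 16) True

# One weight matrix shared by every batch element: b disappears from the second operand
W = rng.standard_normal((32, 16))
shared = np.einsum('bij,jk->bik', A, W, optimize=True)
print(np.allclose(shared, A @ W))  # True
```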
Tensor contractions
In physics and machine learning, tensors with 4+ indices are common:
```python
# Attention scores: Q(batch, heads, seq, d) @ K^T(batch, heads, d, seq)
Q = np.random.randn(8, 12, 128, 64)
K = np.random.randn(8, 12, 128, 64)  # stored as (batch, heads, seq, d); einsum contracts d, so no transpose
attn_scores = np.einsum('bhsd,bhtd->bhst', Q, K, optimize=True)
# (8, 12, 128, 128): raw (unscaled, pre-softmax) scores per head per batch element
```
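The scores above are only the first step of attention. A minimal sketch of standard scaled dot-product attention, with the softmax written by hand and no masking or dropout, finishes the computation with a second einsum that contracts over the key axis t. The function name scaled_dot_product_attention is a hypothetical helper, not a NumPy API, and the small shapes are chosen only to keep the example fast:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Hypothetical helper. Q, K, V: (batch, heads, seq, d). No masking, no dropout."""
    d = Q.shape[-1]
    scores = np.einsum('bhsd,bhtd->bhst', Q, K, optimize=True) / np.sqrt(d)
    # Numerically stable softmax over the key axis t
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values: contract over t
    out = np.einsum('bhst,bhtd->bhsd', weights, V, optimize=True)
    return out, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 4, 16, 8))
K = rng.standard_normal((2, 4, 16, 8))
V = rng.standard_normal((2, 4, 16, 8))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (2, 4, 16, 8)
```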
Bilinear forms
```python
# x^T M y for batches of vectors and a shared matrix
x = np.random.randn(1000, 10)  # (batch, n)
M = np.random.randn(10, 10)    # (n, n)
y = np.random.randn(1000, 10)  # (batch, n)
result = np.einsum('bi,ij,bj->b', x, M, y, optimize=True)  # (1000,)
```
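The plain NumPy equivalent, ((x @ M) * y).sum(axis=1), allocates two (1000, 10) temporaries: one for x @ M and one for the product with y. With optimize=True, einsum also contracts M with one of the batched operands first, so it still builds one (1000, 10) intermediate; the main gain here is a single readable expression.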
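The formulations can be checked against each other, and np.einsum_path shows the plan einsum picks with optimize=True (which selects the 'greedy' strategy). A sketch with freshly generated data; the per-sample loop is only a slow reference:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 10))
M = rng.standard_normal((10, 10))
y = rng.standard_normal((1000, 10))

fused = np.einsum('bi,ij,bj->b', x, M, y, optimize=True)
chained = ((x @ M) * y).sum(axis=1)
per_sample = np.array([x[b] @ M @ y[b] for b in range(len(x))])  # slow reference

print(np.allclose(fused, chained), np.allclose(fused, per_sample))  # True True

# Inspect the plan: the first step contracts M with one of the batched operands
path, info = np.einsum_path('bi,ij,bj->b', x, M, y, optimize=True)
print(info)
```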
Kronecker product
```python
A = np.random.randn(3, 3)
B = np.random.randn(4, 4)
kron = np.einsum('ij,kl->ikjl', A, B).reshape(12, 12)
# Equivalent to np.kron(A, B) but shows the index structure
```
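To confirm the index structure, this sketch compares the result with np.kron and checks one entry against the rule kron[i*4 + k, j*4 + l] == A[i, j] * B[k, l], which is what reshaping the 'ikjl' axes produces:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))

kron = np.einsum('ij,kl->ikjl', A, B).reshape(12, 12)
print(np.allclose(kron, np.kron(A, B)))  # True

# The reshape merges (i, k) into row i*4 + k and (j, l) into column j*4 + l
i, j, k, l = 2, 1, 3, 0
print(np.isclose(kron[i * 4 + k, j * 4 + l], A[i, j] * B[k, l]))  # True
```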
Memory and performance tradeoffs
Avoiding temporaries
Einsum’s primary performance advantage over chained NumPy calls is eliminating intermediate allocations:
```python
# A ** 2 allocates a full-size temporary before the row sums are taken
norms = np.sqrt((A ** 2).sum(axis=1))
# einsum multiplies and accumulates in one pass; only the (n,) row sums are allocated
norms = np.sqrt(np.einsum('ij,ij->i', A, A))
```
For large arrays, the memory savings can matter more than the compute savings: skipping a full-size temporary means less memory traffic and better cache utilization.
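One way to observe the difference is Python's tracemalloc module, to which NumPy reports its array buffers. The helper peak_bytes is hypothetical, and the exact numbers will vary by NumPy version; the point is that the chained version peaks at roughly one extra full-size array (8 MB for a 1000×1000 float64 input) while the einsum version stays far below that:

```python
import tracemalloc
import numpy as np

def peak_bytes(fn):
    """Hypothetical helper: run fn() and return (result, peak traced bytes during the call)."""
    tracemalloc.start()
    result = fn()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return result, peak

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 1000))  # 8 MB of float64, allocated before tracing starts

chained, chained_peak = peak_bytes(lambda: np.sqrt((A ** 2).sum(axis=1)))
fused, fused_peak = peak_bytes(lambda: np.sqrt(np.einsum('ij,ij->i', A, A)))

print(np.allclose(chained, fused))  # True
print(f"chained peak: {chained_peak / 1e6:.1f} MB, einsum peak: {fused_peak / 1e6:.3f} MB")
```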
When einsum is slower
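For patterns that do not map to BLAS, einsum runs its own compiled loops, and so does an equivalent chain of ufuncs. Which one is faster depends on array shapes, memory layout, dtype and NumPy version, so a plain NumPy expression can sometimes win:
```python
A = np.random.randn(1000, 1000)
B = np.random.randn(1000, 1000)
# einsum: a single pass over A and B, no temporary
result1 = np.einsum('ij,ij->', A, B)
# ufunc chain: A * B allocates a full-size temporary, then sum reduces it
result2 = (A * B).sum()
# Both use compiled loops, so neither is guaranteed to win; time them on your data
```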
Profile before assuming einsum is faster.
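A small profiling sketch for the pattern above, adding np.vdot as a third candidate (it flattens both arrays and computes their dot product, so for real arrays it gives the same sum). Sizes and repeat counts are arbitrary; read the ranking off your own machine rather than trusting any fixed answer:

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 1000))
B = rng.standard_normal((1000, 1000))

# Three ways to compute sum(A * B), the Frobenius inner product
candidates = {
    "einsum('ij,ij->')": lambda: np.einsum('ij,ij->', A, B),
    "(A * B).sum()": lambda: (A * B).sum(),
    "np.vdot(A, B)": lambda: np.vdot(A, B),  # flattens both arrays first
}

values = {name: fn() for name, fn in candidates.items()}
reference = values["(A * B).sum()"]
assert all(np.isclose(v, reference) for v in values.values())

# Best of 5 repeats of 20 calls each, fastest first
timings = {name: min(timeit.repeat(fn, number=20, repeat=5)) for name, fn in candidates.items()}
for name, seconds in sorted(timings.items(), key=lambda item: item[1]):
    print(f"{name:>20}: {seconds * 1e3:8.2f} ms per 20 calls")
```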
The opt_einsum library
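For complex tensor contractions, the third-party opt_einsum library offers more path-finding strategies than NumPy's built-in optimizer (for example 'dp', 'branch-2' and 'random-greedy') and can plan a contraction once from shapes alone, then reuse it: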
```python
import numpy as np
import opt_einsum as oe

# Plan a four-operand contraction once, from shapes alone
shape = (4, 4, 4, 4)
expr = oe.contract_expression(
    'pqrs,tuqv,wurt,xyzs->pxwz',
    shape, shape, shape, shape,
    optimize='dp',  # dynamic programming optimizer
)
# Reuse the planned expression on any arrays with these shapes
T1, T2, T3, T4 = (np.random.randn(*shape) for _ in range(4))
result = expr(T1, T2, T3, T4)  # shape (4, 4, 4, 4)
```
opt_einsum also integrates with PyTorch and TensorFlow, using the same subscript notation across frameworks.
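If opt_einsum is not installed, NumPy can plan and run the same contraction on its own. The sketch below uses np.einsum_path with 'optimal' (exhaustive, so only practical for a handful of operands) and checks the planned result against the unoptimized single loop; it says nothing about which planner finds the cheaper path for this expression:

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (4, 4, 4, 4)
T1, T2, T3, T4 = (rng.standard_normal(shape) for _ in range(4))
subscripts = 'pqrs,tuqv,wurt,xyzs->pxwz'

# NumPy's own planner, then reuse of the resulting path
path, info = np.einsum_path(subscripts, T1, T2, T3, T4, optimize='optimal')
planned = np.einsum(subscripts, T1, T2, T3, T4, optimize=path)

# Unoptimized reference: one loop over all 11 indices (4**11 iterations)
naive = np.einsum(subscripts, T1, T2, T3, T4, optimize=False)
print(planned.shape, np.allclose(planned, naive))  # (4, 4, 4, 4) True
```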
Ellipsis notation
For variable-rank batch dimensions, use ...:
```python
# Works regardless of how many batch dimensions exist
def batched_trace(A):
    return np.einsum('...ii->...', A)

print(batched_trace(np.eye(3)).shape)                    # ()
print(batched_trace(np.random.randn(5, 3, 3)).shape)     # (5,)
print(batched_trace(np.random.randn(2, 5, 3, 3)).shape)  # (2, 5)
```
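Ellipsis dimensions also broadcast against each other like ordinary NumPy dimensions. A sketch of a batched matrix-vector product, where the hypothetical helper batched_matvec accepts any number of leading batch dimensions and the vector's (5,) batch shape broadcasts against the matrix's (2, 5):

```python
import numpy as np

def batched_matvec(M, v):
    """Hypothetical helper. M: (..., n, m), v: (..., m) -> (..., n); leading dims broadcast."""
    return np.einsum('...ij,...j->...i', M, v)

rng = np.random.default_rng(0)
M = rng.standard_normal((2, 5, 3, 4))
v = rng.standard_normal((5, 4))  # batch shape (5,) broadcasts against (2, 5)
out = batched_matvec(M, v)
print(out.shape)  # (2, 5, 3)
print(np.allclose(out, (M @ v[..., None])[..., 0]))  # True
```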
Debugging einsum expressions
When a subscript string gives unexpected results, decompose it:
```python
def einsum_shape(subscripts, *shapes):
    """Preview the output shape using zero-stride dummy inputs.

    np.broadcast_to avoids allocating full input arrays, but the contraction still runs.
    """
    dummy = [np.broadcast_to(0.0, s) for s in shapes]
    return np.einsum(subscripts, *dummy).shape

print(einsum_shape('ij,jk->ik', (3, 4), (4, 5)))           # (3, 5)
print(einsum_shape('bij,bjk->bik', (2, 3, 4), (2, 4, 5)))  # (2, 3, 5)
```
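einsum_shape still runs the contraction. If only the shape is wanted, it can be computed from the subscripts alone. infer_einsum_shape below is a hypothetical helper that handles only a subset of what einsum accepts: it assumes explicit mode ('->' present), no ellipsis, and exactly matching sizes for shared indices:

```python
import numpy as np

def infer_einsum_shape(subscripts, *shapes):
    """Hypothetical helper: output shape of an explicit-mode einsum, from shapes only."""
    inputs, output = subscripts.replace(' ', '').split('->')
    terms = inputs.split(',')
    if len(terms) != len(shapes):
        raise ValueError(f"{len(terms)} terms but {len(shapes)} shapes")
    sizes = {}
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} has {len(term)} indices but shape {shape} has {len(shape)} dims")
        for label, size in zip(term, shape):
            # Record the first size seen for each label and reject any later mismatch
            if sizes.setdefault(label, size) != size:
                raise ValueError(f"index {label!r} has sizes {sizes[label]} and {size}")
    missing = set(output) - set(sizes)
    if missing:
        raise ValueError(f"output indices {sorted(missing)} do not appear in any input")
    return tuple(sizes[label] for label in output)

print(infer_einsum_shape('ij,jk->ik', (3, 4), (4, 5)))           # (3, 5)
print(infer_einsum_shape('bij,bjk->bik', (2, 3, 4), (2, 4, 5)))  # (2, 3, 5)
```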
Also use np.einsum_path to understand what einsum is actually doing:
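```python
A, B, C = np.random.randn(8, 10, 12), np.random.randn(10, 12, 6), np.random.randn(6, 4)
_, info = np.einsum_path('ijk,jkl,lm->im', A, B, C, optimize='optimal')
print(info)
# Shows naive vs. optimized FLOP counts, the largest intermediate, and each step's scaling
```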
The one thing to remember: einsum is most powerful when it eliminates intermediate arrays in multi-step tensor computations. Pass optimize=True so eligible contractions can go to BLAS, and use np.einsum_path to check which path it actually chose.