Sparse Attention — Core Concepts
The Quadratic Bottleneck
Standard self-attention’s time and memory complexity is $O(n^2)$ in sequence length $n$. At a 1K-token context this is manageable. At 128K tokens the sequence is 128 times longer, so attention costs $128^2 = 16{,}384$ times as much as at 1K. At 1M tokens, each attention matrix has $10^{12}$ (1 trillion) entries, per head and per layer.
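The arithmetic above can be checked with a few lines of Python. This is only a back-of-the-envelope sketch: `attention_entries` is a hypothetical helper name, and it counts the scores in a single attention matrix (one head, one layer), ignoring constant factors.

```python
def attention_entries(n: int) -> int:
    """Number of query-key scores in one full n x n attention matrix."""
    return n * n

base = attention_entries(1_024)  # treat "1K" as 1,024 tokens
for label, n in [("1K", 1_024), ("128K", 128 * 1_024), ("1M", 1_000_000)]:
    entries = attention_entries(n)
    # Report the raw count and the cost relative to the 1K baseline.
    print(f"{label:>4}: {entries:.3e} scores, {entries / base:,.0f}x the 1K cost")
```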
This quadratic scaling is the fundamental barrier to long-context processing. Sparse attention patterns break this barrier by computing only a subset of attention connections.
The key insight (empirically validated by many papers): full attention between all tokens isn’t necessary for most tasks. Most relevant context comes from nearby tokens (local attention), a few global summary tokens, and occasional long-range connections.
Longformer: Local + Global Attention
Beltagy et al. (Allen AI, 2020) “Longformer: The Long-Document Transformer” combined two attention patterns:
Sliding window attention: Each token attends to $w$ tokens on each side (a window of $2w + 1$). Complexity: $O(n \cdot w)$ — linear in sequence length for fixed window size.
For $w = 512$ on a 4096-token document: roughly 4096 × 1025 ≈ 4.2M attention scores, versus 4096² ≈ 16.8M for full attention, about a 4× reduction.
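A minimal NumPy sketch of the sliding window pattern, assuming a dense boolean mask for illustration (an efficient implementation would never materialize it). The names `sliding_window_mask` and `window_pairs` are made up for this example.

```python
import numpy as np

def sliding_window_mask(n: int, w: int) -> np.ndarray:
    """Boolean (n, n) mask: entry [i, j] is True when |i - j| <= w."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def window_pairs(n: int, w: int) -> int:
    """Exact count of allowed (query, key) pairs, assuming w < n.

    Each row would hold 2w + 1 keys, minus the keys clipped at both ends.
    """
    return n * (2 * w + 1) - w * (w + 1)

# Small case: check the closed form against the explicit mask.
small = sliding_window_mask(n=8, w=2)
assert int(small.sum()) == window_pairs(8, 2)

# The example from the text, counted without building a 4096 x 4096 array.
n, w = 4096, 512
print(window_pairs(n, w), "windowed pairs vs", n * n, "full pairs")
```

The exact count comes out slightly below the 4096 × 1025 estimate because tokens near the start and end of the sequence have clipped windows.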
Global attention: Special tokens (e.g., [CLS] for classification, [SEP], or user-specified task tokens) attend to ALL tokens and are attended to by ALL tokens. These act as information aggregators for the whole sequence. Complexity: $O(g \cdot n)$ for $g$ global tokens.
Combined complexity: $O(n \cdot (w + g))$, which is linear in sequence length. With this approach, Longformer processed sequences of up to 4096 tokens, eight times the 512-token limit of the RoBERTa checkpoint it was initialized from.
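The combined local + global pattern can be sketched as a dense mask. This is an illustrative toy, not Longformer's implementation: `longformer_mask` and its parameters are assumed names, and the choice of token 0 as the only global token mimics a [CLS]-style setup.

```python
import numpy as np

def longformer_mask(n: int, w: int, global_idx: list[int]) -> np.ndarray:
    """Boolean (n, n) mask combining a sliding window with global tokens."""
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w  # local band
    mask[global_idx, :] = True  # global tokens attend to every token
    mask[:, global_idx] = True  # every token attends to global tokens
    return mask

mask = longformer_mask(n=16, w=2, global_idx=[0])
print(mask.astype(int))                        # band plus a full first row and column
print("allowed pairs:", int(mask.sum()), "of", mask.size)
```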
Implementation: Longformer required custom CUDA kernels for efficient sliding window attention — standard matrix multiplication doesn’t benefit from the sparse pattern.
BigBird: Random + Local + Global
Zaheer et al. (Google, 2020) “Big Bird: Transformers for Longer Sequences” added random attention to Longformer’s local + global pattern:
Random attention: Each token attends to $r$ randomly selected tokens across the sequence. This provides “approximate” global connectivity: random graphs tend to have short paths between any two nodes, so information can travel between distant tokens in a few layers. The global tokens additionally guarantee that any two tokens are connected within 2 hops.
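A token-level sketch of the three patterns combined, assuming a dense mask and per-token random choices. The real BigBird works on blocks of tokens for efficiency, so `bigbird_mask` here is a simplified, hypothetical illustration. The last lines check that a single global token already gives every pair of tokens a two-step path.

```python
import numpy as np

def bigbird_mask(n: int, w: int, r: int, global_idx: list[int], seed: int = 0) -> np.ndarray:
    """Boolean (n, n) mask: sliding window + r random keys per query + global tokens."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w  # local band
    for i in range(n):
        # Each query picks r random keys (may coincide with local ones).
        mask[i, rng.choice(n, size=r, replace=False)] = True
    mask[global_idx, :] = True  # global tokens attend everywhere
    mask[:, global_idx] = True  # and are attended to by everyone
    return mask

mask = bigbird_mask(n=64, w=3, r=2, global_idx=[0])
adj = mask.astype(np.int64)
two_hop = (adj @ adj) > 0  # [i, j] is True if some k links i -> k -> j
print("density:", mask.mean(), "all pairs within 2 hops:", bool(two_hop.all()))
```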
BigBird proved (Theorem 1) that sparse attention combining local (sliding window), global, and random connections retains the expressive power of full attention: it is a universal approximator of sequence-to-sequence functions, just as the full-attention Transformer is, although some functions may need more layers under the sparse pattern. This is the “universal approximation” property for sparse attention.
Context length in BERT-like models: BigBird extended the usable context from BERT’s 512 tokens to 4096. On genomics tasks (long DNA sequences, tokenized to fit the 4096-token context), BigBird demonstrated state-of-the-art performance by processing full sequences that previous models couldn’t handle.
Swin Transformer: Hierarchical Windows
For vision, Liu et al. (2021) “Swin Transformer” used a different sparse attention approach: non-overlapping windows.
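Each window contains $M \times M$ patches. Attention is computed only within each window: a window holds $M^2$ tokens, so it costs $O(M^4)$ per window, and with $N / M^2$ windows the total is $O(N \cdot M^2)$, which is linear in the number of patches $N$ for fixed $M$.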
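To enable cross-window communication, consecutive layers alternate between regular window partitioning and shifted window partitioning (offset by $\lfloor M/2 \rfloor$ patches along each axis). The shifted windows straddle the boundaries of the previous layer’s windows, so information propagates between neighboring windows and the receptive field grows with depth.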
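Window partitioning and the shifted variant can be sketched with NumPy reshapes. This is a simplified illustration with assumed function names: the shift is done as a cyclic roll, and the attention mask that a full implementation needs for patches that wrap around the image border is omitted.

```python
import numpy as np

def window_partition(x: np.ndarray, M: int) -> np.ndarray:
    """Split an (H, W, C) feature map into (num_windows, M*M, C) token groups.

    Assumes H and W are divisible by M.
    """
    H, W, C = x.shape
    x = x.reshape(H // M, M, W // M, M, C)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, M * M, C)

def shifted_window_partition(x: np.ndarray, M: int) -> np.ndarray:
    """Cyclically shift the map by M // 2 along both axes, then partition."""
    shifted = np.roll(x, shift=(-(M // 2), -(M // 2)), axis=(0, 1))
    return window_partition(shifted, M)

# Label each patch of an 8 x 8 map with the id of its regular 4 x 4 window.
H = W = 8
M = 4
rows, cols = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
window_ids = ((rows // M) * (W // M) + (cols // M))[..., None]  # shape (H, W, 1)

regular = window_partition(window_ids, M)
shifted = shifted_window_partition(window_ids, M)
print("regular window 0 holds ids:", np.unique(regular[0]))
print("shifted window 0 holds ids:", np.unique(shifted[0]))
```

In this toy setup the first shifted window mixes patches from all four regular windows, which is the cross-window communication the shifted layers provide.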
Swin Transformer set state-of-the-art results on benchmarks such as COCO object detection and ADE20K semantic segmentation in 2021 and became a standard backbone for many vision tasks, demonstrating that hierarchical local attention can match global attention performance on image tasks.
Linear Attention Approximations
A different approach: approximate full attention with $O(n)$ computation by eliminating softmax.
Standard attention: $\text{Attn}(Q, K, V) = \text{softmax}(QK^T / \sqrt{d}) V$
The softmax prevents factorization: $QK^T$ must be computed before applying softmax, requiring $O(n^2)$ intermediate storage.
Linear attention (Katharopoulos et al., 2020): Replace softmax with a kernel function $k(Q_i, K_j) = \phi(Q_i)^T \phi(K_j)$, where $\phi$ is a feature map:
$$\text{Attn}(Q_i, K, V) = \frac{\sum_j \phi(Q_i)^T \phi(K_j) V_j}{\sum_j \phi(Q_i)^T \phi(K_j)} = \frac{\phi(Q_i)^T (\sum_j \phi(K_j) V_j^T)}{\phi(Q_i)^T (\sum_j \phi(K_j))}$$
The key insight: compute $\sum_j \phi(K_j) V_j^T$ once (cumulative $K$-$V$ interaction), then reuse for each query. Total complexity: $O(n \cdot d^2)$ — linear in sequence length (for fixed $d$).
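A minimal NumPy sketch of the reordering, assuming the non-causal (bidirectional) case and the feature map $\phi(x) = \text{elu}(x) + 1$ from Katharopoulos et al. The function names are illustrative. The reference version builds the full $n \times n$ kernel matrix so the two results can be compared.

```python
import numpy as np

def phi(x: np.ndarray) -> np.ndarray:
    """Feature map elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive)."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """O(n * d^2) attention: summarize keys and values once, reuse for every query."""
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.T @ V        # (d, d_v): sum_j phi(K_j) V_j^T
    Z = Kf.sum(axis=0)   # (d,):     sum_j phi(K_j)
    return (Qf @ KV) / (Qf @ Z)[:, None]

def quadratic_reference(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Same kernel attention computed the O(n^2) way, for comparison only."""
    scores = phi(Q) @ phi(K).T  # (n, n) kernel matrix
    return (scores @ V) / scores.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d = 256, 16
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
print("max abs difference:",
      np.abs(linear_attention(Q, K, V) - quadratic_reference(Q, K, V)).max())
```

Both functions compute the same kernelized attention; neither reproduces softmax attention, which is where the quality gap discussed next comes from.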
Quality tradeoff: Linear attention approximations produce measurably lower quality than softmax attention at equivalent model sizes. Several variants (Performer, cosFormer, RetNet) have improved quality but haven’t closed the gap for language tasks.
FlashAttention + Long Context
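FlashAttention (see gpu-computing topic) enables memory-efficient exact attention up to 128K+ tokens by tiling the attention computation in on-chip SRAM, so the full $n \times n$ score matrix is never written to GPU main memory. This is different from sparse attention: it is still exact (full) attention with $O(n^2)$ compute, but its memory use grows linearly with sequence length.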
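The numerical idea behind the tiling is an online softmax: keys and values are processed block by block while a running row maximum and row sum are kept, so the result matches ordinary softmax attention. The NumPy sketch below shows only that math, under simplifying assumptions (no tiling over queries, no fused GPU kernel, no SRAM management), and the function names are illustrative.

```python
import numpy as np

def naive_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Reference softmax attention that materializes the full (n, n) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block: int = 32) -> np.ndarray:
    """Exact attention computed one key/value block at a time (online softmax)."""
    n, d = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full(n, -np.inf)  # running max of scores per query
    row_sum = np.zeros(n)          # running softmax denominator per query
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)            # only an (n, block) slice of scores
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)    # rescale what was accumulated so far
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * scale + P.sum(axis=1)
        out = out * scale[:, None] + P @ Vb
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(128, 16)) for _ in range(3))
print("matches naive attention:", np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V)))
```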
Combined approach in modern long-context models:
- FlashAttention for memory efficiency of exact attention
- Rotary positional encoding (RoPE), whose rotation frequencies can be re-scaled so the model handles sequences longer than those seen in training
- YaRN / LongRoPE: Fine-tuning techniques to extend context by re-scaling positional encodings
Llama 3.1 supports a 128K-token context, reached by continued pre-training on progressively longer sequences with RoPE. Gemini 1.5 Pro demonstrated a 1M-token context. Google has not published the full architectural details; the technical report describes the model as a sparse mixture-of-experts Transformer.
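The positional re-scaling item in the list above can be illustrated with RoPE plus linear position interpolation, which simply divides positions by a constant. This is a deliberately simpler relative of YaRN and LongRoPE, which rescale different frequency bands differently, and it is not a description of any particular model's recipe. The function names, the base of 10000, and the 4096 to 16384 lengths are assumptions for the example.

```python
import numpy as np

def rope_angles(positions, dim: int, base: float = 10000.0, scale: float = 1.0) -> np.ndarray:
    """Rotation angles of shape (len(positions), dim // 2).

    scale > 1 compresses positions (linear position interpolation).
    """
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)
    return np.outer(np.asarray(positions, dtype=float) / scale, inv_freq)

def apply_rope(x: np.ndarray, angles: np.ndarray) -> np.ndarray:
    """Rotate consecutive feature pairs of x (seq, dim) by the given angles."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

train_len, target_len = 4096, 16384   # hypothetical context lengths
scale = target_len / train_len        # 4.0

# With interpolation, the last target position maps inside the trained range.
plain = rope_angles([target_len - 1], dim=8)
interpolated = rope_angles([target_len - 1], dim=8, scale=scale)
print("largest angle without scaling:", plain.max())
print("largest angle with scaling:   ", interpolated.max())
```

After scaling, every position in the longer sequence produces angles the model already saw during training, which is why a short fine-tuning run is usually enough to adapt to the compressed spacing.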
One thing to remember: Sparse attention patterns (local, global, random, hierarchical) and efficient exact attention (FlashAttention) are complementary approaches to the long-context problem — the field has moved toward longer exact attention rather than approximating it, driven by FlashAttention’s memory savings.