Sparse Attention — Deep Dive
Complexity Analysis: What Quadratic Scaling Really Means
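For standard self-attention in a single layer with sequence length $n$ and model width $d$ (all heads together, $d = n_{heads} \times d_{head}$), counting one multiply-add as 2 FLOPs:
FLOPs:
- Query/Key/Value projections: $3 \times 2nd^2$ FLOPs
- Attention scores $QK^\top$, summed over heads: $2n^2 d$ FLOPs
- Weighted sum of values $\text{softmax}(QK^\top/\sqrt{d_{head}})\,V$, summed over heads: $2n^2 d$ FLOPs
- Output projection: $2nd^2$ FLOPs
Total: $\approx 4n^2 d + 8nd^2$ FLOPs (softmax and scaling are lower-order and ignored). Since $4n^2 d > 8nd^2$ exactly when $n > 2d$, the attention terms dominate once the sequence is longer than twice the model width. For a sense of scale, a single head with $d_{head} = 128$ at $n = 1024$ spends $4 \times 1024^2 \times 128 \approx 537$M FLOPs on attention alone, in every layer.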
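The count above is easy to tabulate. This is a minimal sketch using the same 2-FLOPs-per-multiply-add convention; `attention_layer_flops` is a hypothetical helper, and the model width $d = 4096$ in the loop is an illustrative assumption, not a specific model.

```python
def attention_layer_flops(n: int, d: int) -> dict:
    """Approximate forward FLOPs for one self-attention layer.

    n: sequence length; d: model width (n_heads * d_head).
    A multiply-add counts as 2 FLOPs; softmax and scaling are ignored.
    """
    qkv_proj = 3 * 2 * n * d * d       # Q, K, V projections
    scores = 2 * n * n * d             # Q K^T, summed over heads
    weighted_values = 2 * n * n * d    # softmax(...) V, summed over heads
    out_proj = 2 * n * d * d           # output projection
    return {
        "projections": qkv_proj + out_proj,       # 8 n d^2
        "attention": scores + weighted_values,    # 4 n^2 d
        "total": qkv_proj + scores + weighted_values + out_proj,
    }


# Illustrative width (assumption): watch the attention share grow with n.
for n in (1024, 4096, 16384, 65536):
    f = attention_layer_flops(n, d=4096)
    share = f["attention"] / f["total"]
    print(f"n={n:>6}: attention share of FLOPs = {share:.0%}")
```

The ratio of attention to projection FLOPs is $n/(2d)$, so with $d = 4096$ the two are equal at $n = 8192$ and attention takes over beyond that.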
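Memory (KV cache at inference): For every token, each layer stores one key vector and one value vector per KV head, i.e. $2 \times n_{layers} \times n_{kv} \times d_{head} \times 2$ bytes in FP16, where $n_{kv}$ is the number of key/value heads. Llama-3.1-70B has 80 layers and uses grouped-query attention with 8 KV heads of dimension 128, so at 128K context: $2 \times 80 \times 8 \times 128 \times 2 \times 128{,}000 \approx 42$ GB per request.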
This KV cache, together with the memory bandwidth needed to read it at every decoding step, is usually what makes long-context inference expensive, more than the arithmetic itself.
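A minimal sketch of the KV-cache arithmetic. The shape numbers are assumptions taken from the published Llama-3.1-70B configuration (80 layers, 64 query heads sharing 8 KV heads, head dimension 128), and `kv_cache_bytes` is a hypothetical helper. The per-token cost counts every layer but only the KV heads, which is why grouped-query attention matters so much here.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2, batch=1):
    """Bytes needed to cache K and V for `seq_len` tokens.

    The leading 2 counts the two tensors (K and V); bytes_per_elem=2 is FP16/BF16.
    """
    per_token = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem
    return per_token * seq_len * batch


# Assumed Llama-3.1-70B attention shape: 80 layers, 8 KV heads (GQA), head dim 128.
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, d_head=128, seq_len=128_000)
# Same model if each of the 64 query heads kept its own K/V (plain multi-head attention).
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, seq_len=128_000)
print(f"GQA: {gqa / 1e9:.1f} GB, full MHA: {mha / 1e9:.1f} GB per request")
```

Sharing each K/V head across 8 query heads cuts the cache by a factor of 8, from roughly 335 GB to roughly 42 GB per 128K-token request.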
Ring Attention: Distributed Long-Context Training
Liu et al. (2023), “Ring Attention with Blockwise Transformers for Near-Infinite Context,” enables training on arbitrarily long sequences by distributing attention computation across devices.
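In ring attention, $N$ devices form a logical ring. Device $i$ holds the queries, keys, and values for positions $[i \cdot n/N, (i+1) \cdot n/N)$:
- Each device first computes attention of its own queries against its own keys/values
- It then sends its current K/V block to the next device in the ring and receives one from the previous device, overlapping the transfer with computation
- Each newly received K/V block is folded into the device's running result with a blockwise (online) softmax, which keeps only a running max, a running normalizer, and a running weighted sum instead of the full score matrix
- After $N$ steps (one full rotation, $N-1$ transfers), each device holds the exact attention output for its query chunk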
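The ring schedule can be simulated in a single process to check that the blockwise online softmax reproduces ordinary attention exactly. This is a minimal NumPy sketch under stated assumptions: non-causal attention, a single head, simulated devices as list entries, and no real communication or compute/transfer overlap. `ring_attention` and `full_attention` are hypothetical names, not the paper's code.

```python
import numpy as np


def full_attention(q, k, v):
    """Reference: softmax(Q K^T / sqrt(d)) V over the whole sequence."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def ring_attention(q, k, v, n_devices):
    """Single-process simulation of ring attention (non-causal, one head).

    Simulated device i owns chunk i. At step t it holds the K/V chunk that
    started on device (i - t) mod N and folds it into an online softmax, so
    no device ever materializes the full n x n score matrix.
    """
    qs, ks, vs = (np.array_split(x, n_devices) for x in (q, k, v))
    scale = 1.0 / np.sqrt(q.shape[-1])
    outputs = []
    for i in range(n_devices):
        qi = qs[i]
        m = np.full((qi.shape[0], 1), -np.inf)       # running row max
        denom = np.zeros((qi.shape[0], 1))           # running softmax normalizer
        acc = np.zeros((qi.shape[0], v.shape[-1]))   # running weighted sum of V
        for t in range(n_devices):
            j = (i - t) % n_devices                  # K/V chunk held at step t
            s = qi @ ks[j].T * scale
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            correction = np.exp(m - m_new)           # rescale earlier partial sums
            p = np.exp(s - m_new)
            denom = denom * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ vs[j]
            m = m_new
        outputs.append(acc / denom)
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
n, d = 64, 16
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
assert np.allclose(ring_attention(q, k, v, n_devices=4), full_attention(q, k, v))
```

Only one K/V chunk of size $n/N$ is live per device at any step, which is where the $O(n/N)$ memory per device comes from.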
Memory per device: $O(n/N)$ — scales with 1/N as you add devices. Can train on arbitrarily long sequences given enough devices.
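Context parallelism of this kind is used in production long-context training. Meta's Llama 3 report describes context parallelism for its 128K-token training stage, although its implementation all-gathers K and V rather than passing them around a ring.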
Mamba and Selective State Space Models vs. Sparse Attention
Gu & Dao (2023) “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” offers an alternative to attention that scales as $O(n)$ in sequence length.
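Key mechanism: an input-dependent (selective) state space model. Classic SSMs are linear recurrences with fixed parameters; in Mamba, the step size $\Delta$ and the matrices $B$ and $C$ are computed from the input, so the discretized transition $\bar{A}$ becomes input-dependent too (the continuous $A$ itself is a learned, input-independent diagonal matrix):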
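$$\bar{A}(x) = \exp(\Delta(x) A), \quad \bar{B}(x) = \Delta(x) B(x)$$
$$h_t = \bar{A}(x_t)\, h_{t-1} + \bar{B}(x_t)\, x_t, \qquad y_t = C(x_t)\, h_t$$
where $\Delta(x) = \text{softplus}(W_\Delta x)$ is a learned input-dependent step size. $\bar{A}$ uses the zero-order-hold discretization, while $\bar{B} = \Delta B$ is a first-order simplification of the exact zero-order-hold formula.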
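A sequential reference implementation makes the selectivity concrete. This is a minimal NumPy sketch, assuming a simplified layer with diagonal $A$, $\bar{B} = \Delta B$, and plain linear maps for $\Delta$, $B$, and $C$; it omits Mamba's convolution, gating, low-rank $\Delta$ projection, and fused GPU kernel, and all function and parameter names are hypothetical. The state `h` keeps a fixed shape (D, N) no matter how long the sequence is.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)   # log(1 + e^z), numerically stable


def selective_ssm(x, A, W_delta, b_delta, W_B, W_C):
    """Sequential selective SSM scan (simplified, Mamba-style).

    x: (L, D) inputs; A: (D, N) negative entries (diagonal per channel).
    Delta, B and C are recomputed from x at every step: the 'selective' part.
    """
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))                         # fixed-size state
    ys = np.empty((L, D))
    for t in range(L):
        delta = softplus(x[t] @ W_delta + b_delta)  # (D,) input-dependent step size
        B = x[t] @ W_B                              # (N,) input-dependent
        C = x[t] @ W_C                              # (N,) input-dependent
        A_bar = np.exp(delta[:, None] * A)          # (D, N) zero-order hold
        B_bar = delta[:, None] * B[None, :]         # (D, N) simplified Delta * B
        h = A_bar * h + B_bar * x[t][:, None]       # h_t = A_bar h_{t-1} + B_bar x_t
        ys[t] = h @ C                               # y_t = C h_t, shape (D,)
    return ys


rng = np.random.default_rng(0)
L, D, N = 32, 4, 8                          # sequence length, channels, state size
x = rng.standard_normal((L, D))
A = -np.exp(rng.standard_normal((D, N)))    # negative entries keep A_bar in (0, 1)
W_delta = 0.1 * rng.standard_normal((D, D))
b_delta = np.zeros(D)
W_B = rng.standard_normal((D, N))
W_C = rng.standard_normal((D, N))
y = selective_ssm(x, A, W_delta, b_delta, W_B, W_C)
```

Because $\Delta$ depends on the current input, a large $\Delta$ lets a token overwrite the state while a small one lets the state carry past it, which is the "focus or ignore" behavior described next.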
The “selectivity” allows the model to focus on or ignore specific inputs — analogous to attention’s ability to weight tokens by relevance.
Efficiency: At inference Mamba runs as a recurrence, costing O(1) time and memory per generated token. During training it uses a parallel scan, which does O(n) total work in O(log n) parallel steps. No KV cache is needed at inference, just a fixed-size state per layer.
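The recurrence $h_t = a_t h_{t-1} + b_t$ can be parallelized because composing two such affine steps yields another affine step, and that composition is associative. This is a minimal NumPy sketch of a recursive work-efficient scan checked against the sequential loop; it illustrates the principle only and is not Mamba's hardware-aware CUDA kernel (function names are hypothetical). In Mamba, $a_t$ and $b_t$ correspond elementwise to $\bar{A}(x_t)$ and $\bar{B}(x_t)\,x_t$.

```python
import numpy as np


def combine(first, second):
    """Compose two steps of h -> a*h + b (first applied, then second)."""
    a1, b1 = first
    a2, b2 = second
    return a1 * a2, a2 * b1 + b2


def linear_scan(a, b):
    """All prefixes of h_t = a_t * h_{t-1} + b_t with h_{-1} = 0.

    Combine adjacent pairs, scan the half-length sequence recursively, then
    fill in the even positions: O(n) work over O(log n) levels, and each
    level is a single vectorized (parallelizable) operation.
    """
    n = len(a)
    if n == 1:
        return a.copy(), b.copy()
    m = n // 2
    # 1. Combine pairs (0,1), (2,3), ... into a sequence of length m.
    pa, pb = combine((a[0:2 * m:2], b[0:2 * m:2]), (a[1:2 * m:2], b[1:2 * m:2]))
    # 2. Scanning that gives the prefixes ending at odd indices 1, 3, 5, ...
    odd_a, odd_b = linear_scan(pa, pb)
    out_a, out_b = np.empty_like(a), np.empty_like(b)
    out_a[1:2 * m:2], out_b[1:2 * m:2] = odd_a, odd_b
    out_a[0], out_b[0] = a[0], b[0]
    # 3. Even index i >= 2: prefix ending at i-1 combined with element i.
    k = (n - 1) // 2
    out_a[2::2], out_b[2::2] = combine((odd_a[:k], odd_b[:k]), (a[2::2], b[2::2]))
    return out_a, out_b


def sequential_scan(a, b):
    """Plain O(n) loop, as used at inference time."""
    h = np.zeros_like(b[0])
    out = np.empty_like(b)
    for t in range(len(a)):
        h = a[t] * h + b[t]
        out[t] = h
    return out


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=37)   # decay factors, like entries of A_bar
b = rng.standard_normal(37)          # inputs, like B_bar * x_t
_, h = linear_scan(a, b)
assert np.allclose(h, sequential_scan(a, b))
```

The simpler Hillis-Steele style scan is also common; it needs only O(log n) steps too but does O(n log n) total work.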
Comparison to transformers:
- At sequence length 1K: transformer is faster (high-throughput attention)
- At sequence length 16K+: Mamba is faster due to O(n) vs. O(n²)
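- Quality: the Mamba paper reports that Mamba-3B outperforms Transformers of the same size and matches Transformers about twice its size on language modeling and downstream evaluations; competitive but not clearly superior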
Hybrid models: Jamba (AI21) and Zamba (Zyphra) interleave Mamba layers with attention layers (roughly one attention layer per 6 to 8 layers), whereas Falcon Mamba is a pure Mamba model with no attention. Hybrids aim for near-transformer quality while keeping most of Mamba's efficiency at long contexts.
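A hybrid stack is largely a layer schedule. This is a hypothetical sketch assuming one attention layer at the end of each group of `attn_every` layers; real models choose their own placements, so treat it only as an illustration of the ratio.

```python
def hybrid_layer_pattern(n_layers: int, attn_every: int) -> list:
    """Hypothetical helper: one attention layer closes every `attn_every`-layer
    group; all other layers are Mamba layers."""
    if attn_every < 1:
        raise ValueError("attn_every must be >= 1")
    return ["attention" if (i + 1) % attn_every == 0 else "mamba"
            for i in range(n_layers)]


pattern = hybrid_layer_pattern(n_layers=32, attn_every=8)
print(pattern.count("attention"), "attention layers out of", len(pattern))
```

Since only the attention layers keep a KV cache, a 1-in-8 schedule needs roughly one eighth of the KV memory of an all-attention model with the same attention shape.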
Lost in the Middle: The Practical Long-Context Challenge
Liu et al. (2023) “Lost in the Middle: How Language Models Use Long Contexts” revealed a critical practical limitation: even models that technically support long contexts don’t use the information in the middle effectively.
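Experimental setup: Give a model a task whose answer lies in one specific passage (multi-document question answering, or synthetic key-value retrieval), move that passage to different positions in a long context, and measure accuracy as a function of position.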
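Finding: Accuracy follows a U-shaped curve. Models perform best when the relevant information is at the beginning or end of the context and worst when it sits in the middle. In the paper's multi-document QA experiments, accuracy for some models dropped by more than 20% when the answer moved to the middle, in the worst cases falling below the model's closed-book accuracy (answering with no documents at all).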
Mechanism hypothesis: The positional encoding and attention patterns that make transformers work create implicit position-dependent biases. “Primacy” and “recency” effects (common in human memory) emerge in transformer attention.
Practical implication: When constructing prompts with retrieved context, put the most important context at the beginning or end, not the middle. This is a real deployment concern for RAG systems using 32K+ context.
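One simple way to act on this is to rank retrieved passages by relevance and alternate them between the front and the back of the prompt, so the weakest passages land in the middle. This is a minimal sketch of that heuristic, assuming relevance scores are already available from the retriever; `order_for_long_context` is a hypothetical name.

```python
def order_for_long_context(docs, scores):
    """Put the most relevant docs at both ends and the least relevant in the middle.

    Rank 0 goes first, rank 1 goes last, rank 2 second, rank 3 second-to-last, ...
    """
    ranked = [doc for _, doc in sorted(zip(scores, docs),
                                       key=lambda pair: pair[0], reverse=True)]
    front, back = [], []
    for rank, doc in enumerate(ranked):
        (front if rank % 2 == 0 else back).append(doc)
    return front + back[::-1]


docs = ["A", "B", "C", "D", "E"]
scores = [0.9, 0.8, 0.5, 0.3, 0.1]   # A is most relevant, E least
print(order_for_long_context(docs, scores))
```

With these scores the order becomes A, C, E, D, B: the two best passages sit at the boundaries and the weakest one in the middle.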
Needle-in-a-Haystack Evaluation
Standard language modeling benchmarks don’t test long-context capability specifically. “Needle in a Haystack” (NIAH) is the standard stress test:
- Fill the context with irrelevant but coherent text (“haystacks” — Paul Graham essays, Wikipedia, etc.)
- Insert a specific factual “needle” at a specific position
- Ask the model a question that requires finding the needle
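The construction steps above can be sketched as follows. Assumptions: word counts stand in for token counts (a real harness would measure length with the model's tokenizer), the filler sentences are cycled to reach the target length, and the needle is inserted at a sentence boundary; `build_haystack` and the example strings are hypothetical.

```python
def build_haystack(filler_sentences, needle, depth, target_words):
    """Hypothetical helper: repeat filler sentences until about `target_words`
    words, then insert `needle` at fractional `depth` (0.0 = start, 1.0 = end)
    at a sentence boundary. Word counts stand in for token counts."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    if not any(s.split() for s in filler_sentences):
        raise ValueError("filler must contain at least one word")
    sentences, n_words, i = [], 0, 0
    while n_words < target_words:
        s = filler_sentences[i % len(filler_sentences)]
        sentences.append(s)
        n_words += len(s.split())
        i += 1
    sentences.insert(round(depth * len(sentences)), needle)
    return " ".join(sentences)


filler = ["The harbor was quiet that morning.", "Gulls circled over the fishing boats."]
needle = "The secret passphrase is blue-lantern-42."
context = build_haystack(filler, needle, depth=0.25, target_words=200)
prompt = context + "\n\nWhat is the secret passphrase?"
```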
Performance is measured across:
- Context depth: where the needle is (0% = beginning, 100% = end)
- Context length: total length of haystacks (1K to 1M tokens)
Visualization: a 2D heatmap of accuracy at each (depth, length) combination. Models show characteristic performance degradation patterns.
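Filling the heatmap is a double loop over lengths and depths. In this minimal sketch, `ask_model` and `make_context` are hypothetical caller-supplied callables and correctness is a simple substring match. The usage example plugs in a toy stand-in model that only reads the first 2,000 characters of its prompt, purely to show how the grid is laid out.

```python
import numpy as np


def niah_grid(ask_model, make_context, question, answer, depths, lengths, trials=1):
    """Accuracy per (length, depth) cell; rows = lengths, columns = depths.

    ask_model(prompt) -> str and make_context(depth, length, trial) -> str are
    hypothetical callables. A trial counts as correct if `answer` appears in
    the reply (a deliberately simple check).
    """
    grid = np.zeros((len(lengths), len(depths)))
    for r, length in enumerate(lengths):
        for c, depth in enumerate(depths):
            hits = 0
            for t in range(trials):
                prompt = make_context(depth, length, t) + "\n\n" + question
                hits += answer in ask_model(prompt)
            grid[r, c] = hits / trials
    return grid


def fake_model(prompt):
    # Toy stand-in "model" that only sees the first 2,000 characters.
    return "42" if "42" in prompt[:2000] else "not found"


def fake_context(depth, length, trial):
    filler = "x " * (length // 2)          # toy filler; length counted in characters
    cut = int(depth * len(filler))
    return filler[:cut] + " needle: 42 " + filler[cut:]


depths = [0.0, 0.25, 0.5, 0.75, 1.0]
lengths = [1_000, 8_000, 32_000]
grid = niah_grid(fake_model, fake_context, "What is the needle?", "42", depths, lengths)
print(grid)
```

The resulting array can be passed to any heatmap plotting routine, with needle depth on one axis and context length on the other.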
Findings across models (2024):
- GPT-4 Turbo (128K): Good performance across most positions but some degradation in middle at 128K
- Claude 3 Opus (200K): Strong performance across all positions up to ~100K; some degradation beyond
- Gemini 1.5 Pro (1M): Google's technical report claims near-perfect (>99%) needle recall across the full 1M-token window, among the strongest long-context results reported
- Open-source models: Typically show significant “lost in middle” effects beyond their training context length
One thing to remember: Having a long context window doesn’t mean a model uses it effectively — the “lost in the middle” phenomenon and needle-in-haystack testing reveal that position-dependent attention biases create real practical limitations that affect how RAG and long-context applications should be designed.