Attention Mechanism — Deep Dive

Before the Math: Why RNNs Actually Failed

The narrative you hear, “RNNs couldn’t handle long sequences due to vanishing gradients”, is true but incomplete. LSTMs largely mitigated the vanishing gradients of vanilla RNNs. The deeper issue was sequential dependency.

With an RNN, computing the hidden state at timestep t requires the hidden state at timestep t-1. This is fundamentally serial. You can’t parallelize across the sequence dimension. On 2015-era hardware, training on long sequences was slow. On 2017-era hardware with increasingly powerful GPUs, it was a strategic dead-end — you couldn’t leverage the hardware efficiently.

The genius of attention wasn’t just “better context modeling.” It was making context modeling parallelizable.
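
A minimal sketch of that contrast, assuming a toy tanh RNN cell with random weights (the names `W_h` and `W_x` and all sizes are illustrative, not taken from any particular model). The recurrent loop has to run its n steps in order, while the attention scores for all n positions fall out of a single matrix multiply:

```python
import torch

torch.manual_seed(0)
n, d = 6, 4                       # toy sequence length and hidden size (illustrative)
x = torch.randn(n, d)             # one input vector per timestep

# RNN: step t needs h from step t-1, so the loop cannot be parallelized over t.
W_h = torch.randn(d, d) * 0.1     # hypothetical recurrent weights
W_x = torch.randn(d, d) * 0.1     # hypothetical input weights
h = torch.zeros(d)
hidden_states = []
for t in range(n):
    h = torch.tanh(h @ W_h + x[t] @ W_x)
    hidden_states.append(h)
rnn_out = torch.stack(hidden_states)            # (n, d)

# Attention: all n x n pairwise scores come from one matmul, with no loop over t.
scores = x @ x.T / d ** 0.5                     # (n, n)
attn_out = torch.softmax(scores, dim=-1) @ x    # (n, d)
```

Both paths produce one vector per position, but only the second maps onto hardware that wants large batched matrix multiplies.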

The Full Attention Computation

Let’s be precise. Given an input sequence of n tokens, each projected to dimension d_model:

  1. Linear projections: Project input X into Q, K, V matrices using learned weight matrices W_Q, W_K ∈ ℝ^(d_model × d_k) and W_V ∈ ℝ^(d_model × d_v)
  2. Score matrix: Compute S = QK^T ∈ ℝ^(n×n) — every token attends to every token
  3. Scale: Divide by √d_k to control variance
  4. Mask (optional): For causal/decoder attention, add -∞ to future positions before softmax (they become 0 after softmax)
  5. Softmax: Apply row-wise softmax to get A ∈ ℝ^(n×n), where each row sums to 1
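  6. Output: Compute AV ∈ ℝ^(n×d_v), the weighted sum of values

import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, weights); mask uses 1 = keep, 0 = block."""
    d_k = Q.size(-1)
    # Steps 2-3: pairwise scores, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 4: blocked positions get -inf so softmax assigns them weight 0
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Steps 5-6: row-wise softmax, then weighted sum of values
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights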
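
To make the optional mask step concrete, here is a small self-contained sketch of causal masking. It inlines the same computation rather than calling the function above, and the sizes are arbitrary. The mask follows the convention used here: 1 means "may attend", 0 means "blocked".

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k = 4, 8                                  # toy sizes (illustrative)
Q, K, V = (torch.randn(1, n, d_k) for _ in range(3))

# Lower-triangular mask: 1 = may attend, 0 = future position to hide.
causal_mask = torch.tril(torch.ones(n, n))

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)            # (1, n, n)
output = weights @ V                           # (1, n, d_k)

# Entries above the diagonal are exactly 0: no token sees its future.
future_weight = torch.triu(weights[0], diagonal=1).abs().sum().item()
```

The first token can only attend to itself, so its output row is just its own value vector.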

The resulting complexity is O(n²·d) in time and O(n²) in memory. For n=512 tokens this is fine. For n=100,000 tokens (a long document), it becomes the primary bottleneck.
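
A back-of-the-envelope check of those two cases, assuming one fp16 score matrix (2 bytes per entry) for a single head and a single sequence. The helper name is made up for illustration:

```python
def attention_matrix_bytes(n, bytes_per_element=2):
    """Bytes needed to store one n x n attention score matrix."""
    return n * n * bytes_per_element

small = attention_matrix_bytes(512)        # 524,288 bytes
large = attention_matrix_bytes(100_000)    # 20,000,000,000 bytes

print(f"n=512:     {small / 2**20:.1f} MiB")   # 0.5 MiB
print(f"n=100,000: {large / 1e9:.1f} GB")      # 20.0 GB
```

Multiply by the number of heads and layers and the long-document case is clearly out of reach if the matrix is materialized naively.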

Multi-Head Attention in Detail

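The single-head computation runs h times in parallel, each with its own Q/K/V projections. In practice the h per-head projections are fused into one d_model × d_model linear layer per role, and the result is reshaped into heads: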

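class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.h = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # One d_model x d_model projection per role; each is later split into h heads.
        self.W_Q = torch.nn.Linear(d_model, d_model)
        self.W_K = torch.nn.Linear(d_model, d_model)
        self.W_V = torch.nn.Linear(d_model, d_model)
        self.W_O = torch.nn.Linear(d_model, d_model)  # mixes the concatenated heads

    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        x = x.view(batch_size, -1, self.h, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        # mask must broadcast to (batch, heads, seq_q, seq_k); 0 marks blocked positions
        batch_size = Q.size(0)
        Q = self.split_heads(self.W_Q(Q), batch_size)
        K = self.split_heads(self.W_K(K), batch_size)
        V = self.split_heads(self.W_V(V), batch_size)

        # Attention runs independently per head over the leading (batch, heads) dims.
        x, _ = scaled_dot_product_attention(Q, K, V, mask)
        # (batch, heads, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.W_O(x)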

In GPT-3 (2020), this runs with 96 heads, d_model=12288, and 96 layers. Each of the four projection matrices has ~151M parameters (12288²), so each layer’s attention block has ~604M parameters in projections alone.
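
A quick check of the projection sizes for those GPT-3 dimensions, counting only the weight matrices W_Q, W_K, W_V and W_O and ignoring biases:

```python
d_model = 12288          # GPT-3 model width, from the text
num_heads = 96           # from the text

d_k = d_model // num_heads                 # per-head dimension
per_matrix = d_model * d_model             # one of W_Q, W_K, W_V, W_O
per_layer = 4 * per_matrix                 # all four projections, biases ignored

print(d_k)          # 128
print(per_matrix)   # 150994944
print(per_layer)    # 603979776
```

So a single projection matrix is about 151M parameters, and the four together come to about 604M per layer.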

Positional Encoding: The Part People Skip

Self-attention is permutation equivariant by design: shuffle the tokens and you get the same outputs, shuffled the same way. On its own, the model has no idea where in the sequence a token appears.

You have to inject position information explicitly.
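
This is easy to verify numerically. The sketch below uses a stripped-down single head with no learned projections and no positional signal (an assumption made to keep it short) and checks that shuffling the input rows simply shuffles the output rows:

```python
import torch

torch.manual_seed(0)
n, d = 5, 8
x = torch.randn(n, d)

def self_attention(x):
    # Single head, no learned projections and no positional information.
    scores = x @ x.T / x.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ x

perm = torch.randperm(n)
out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Shuffling the input rows shuffles the output rows in exactly the same way.
same_up_to_order = torch.allclose(out[perm], out_shuffled, atol=1e-6)
```

Nothing in the computation can tell "dog bites man" from "man bites dog" unless position is added to the inputs.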

Original transformers (2017): Sinusoidal encodings added to token embeddings before the first layer. Fixed, not learned.
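
A minimal sketch of the sinusoidal table, following the sin/cos formula from the 2017 paper and assuming an even d_model. The function name and the sizes are illustrative:

```python
import math
import torch

def sinusoidal_encoding(max_len, d_model):
    """Fixed position table: sin on even channels, cos on odd channels."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )                                                                     # (d_model/2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_encoding(max_len=128, d_model=64)
token_embeddings = torch.randn(128, 64)     # stand-in for real embeddings
x = token_embeddings + pe                   # added once, before the first layer
```

Because the table is a fixed function of position, there is nothing to train and it can be computed for any length on demand.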

BERT/GPT style: Learned absolute position embeddings. Work fine for sequences up to the max length seen in training. Fail for longer sequences.

Rotary Position Embedding (RoPE): The current dominant approach, used in LLaMA, Mistral, Gemma, and most 2023+ models. Instead of adding positional information, RoPE rotates the query and key vectors in a way that encodes relative position in the dot product itself.

The key property: position enters the attention score between a query at position m and a key at position n only through the offset (m-n), not through the absolute positions. This generalizes better to unseen sequence lengths and can be extended (with some degradation) far beyond training context via techniques like YaRN or LongRoPE.
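
The relative-position property can be checked directly. This is a simplified sketch that rotates consecutive channel pairs (real implementations differ in how they pair channels and cache the angles); the function name `rope` and the base of 10000 are assumptions for illustration:

```python
import torch

def rope(x, position, base=10000.0):
    """Rotate consecutive channel pairs of x by angles proportional to position."""
    d = x.size(-1)                                        # must be even
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = position * inv_freq                          # (d/2,)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x_even * cos - x_odd * sin
    rotated[..., 1::2] = x_even * sin + x_odd * cos
    return rotated

torch.manual_seed(0)
q = torch.randn(16, dtype=torch.float64)
k = torch.randn(16, dtype=torch.float64)

# Same offset (3) at two very different absolute positions.
score_near = torch.dot(rope(q, 5), rope(k, 2))
score_far = torch.dot(rope(q, 1005), rope(k, 1002))
```

The two scores agree up to floating-point error, because rotating q by one angle and k by another leaves only the difference of the angles in the dot product.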

FlashAttention: Making O(n²) Practical

The O(n²) arithmetic is not the practical bottleneck for most sequence lengths. The real bottleneck is memory bandwidth: standard attention repeatedly moves the n×n attention matrix between HBM (the GPU’s large but slower high-bandwidth memory) and SRAM (small, fast on-chip memory).

FlashAttention (Dao et al., Stanford, 2022) reorders the computation to keep intermediate results in SRAM as long as possible, computing attention in tiles. It produces the exact same result as standard attention, but:

  • 2-4x faster on A100 GPUs for typical context lengths
  • 10-20x less memory for the attention matrix (O(n) instead of O(n²))
  • Became standard in all major frameworks by 2023
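
The tiling idea can be sketched in plain PyTorch. This is not the FlashAttention kernel (which also tiles the queries and fuses everything into on-chip memory); it only illustrates the running-softmax trick that lets attention be computed one key/value tile at a time without ever holding the full n×n matrix. The function name and tile size are illustrative, and it assumes d_v equals d_k:

```python
import torch

def tiled_attention(Q, K, V, tile=4):
    """Exact attention computed one K/V tile at a time with a running softmax."""
    n, d = Q.shape
    scale = d ** -0.5
    out = torch.zeros_like(Q)
    running_max = torch.full((n, 1), float("-inf"), dtype=Q.dtype)
    running_sum = torch.zeros(n, 1, dtype=Q.dtype)

    for start in range(0, K.size(0), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s = (Q @ K_t.T) * scale                          # (n, tile) scores only
        new_max = torch.maximum(running_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)    # rescale earlier tiles
        p = torch.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V_t
        running_max = new_max

    return out / running_sum

torch.manual_seed(0)
Q, K, V = (torch.randn(10, 8, dtype=torch.float64) for _ in range(3))
reference = torch.softmax(Q @ K.T * 8 ** -0.5, dim=-1) @ V
tiled = tiled_attention(Q, K, V, tile=4)
```

The tiled result matches the reference up to floating-point error, which is the sense in which this style of computation is exact rather than approximate.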

FlashAttention-2 (2023) and FlashAttention-3 (2024) pushed this further with better parallelism and support for Hopper (H100) architecture features.

Sparse and Linear Attention: The Alternatives

Several approaches try to escape O(n²) entirely:

Sparse attention (Longformer, BigBird): Each token only attends to a local window + a few global tokens. O(n·k) complexity for window size k. Works well for long documents but can miss long-range relationships that fall outside the window and are not routed through a global token.
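
A sketch of the local-window-plus-global pattern as a dense mask, using the same 1 = allowed, 0 = blocked convention as the attention function earlier. Real implementations avoid materializing this n×n mask; it is built here only to show which pairs are kept. The helper name and sizes are hypothetical:

```python
import torch

def local_global_mask(n, window, global_positions=()):
    """1 where attention is allowed: |i - j| <= window, plus global rows/columns."""
    idx = torch.arange(n)
    mask = (idx[:, None] - idx[None, :]).abs() <= window
    for g in global_positions:
        mask[g, :] = True      # the global token attends everywhere
        mask[:, g] = True      # every token attends to the global token
    return mask.int()

mask = local_global_mask(n=8, window=1, global_positions=(0,))
allowed = int(mask.sum())      # 34 pairs kept, versus 8 * 8 = 64 for dense attention
```

The number of allowed pairs grows roughly linearly in n for a fixed window, which is where the O(n·k) figure comes from.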

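Linear attention (Performer, RWKV): Performer approximates the softmax kernel with feature maps so the n×n matrix never has to be formed, reducing cost to linear in n (roughly O(n·d²) instead of O(n²·d)). RWKV gets similar scaling from a recurrent formulation rather than a kernel approximation. That’s the theory. In practice, quality drops noticeably for in-context learning tasks, and the approximation is worst for the high-variance, high-information queries where attention matters most.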
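
The reordering behind kernelized linear attention can be shown in a few lines. This sketch uses the simple elu(x) + 1 feature map from the linear-transformer literature, not Performer's random features, and it is non-causal; it is meant only to show why the cost becomes linear in n:

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1.0          # positive feature map (assumed choice)

torch.manual_seed(0)
n, d = 64, 8
Q, K, V = (torch.randn(n, d, dtype=torch.float64) for _ in range(3))
phi_q, phi_k = feature_map(Q), feature_map(K)

# Quadratic order: build the n x n similarity matrix first.
sim = phi_q @ phi_k.T                                   # (n, n)
quadratic = (sim @ V) / sim.sum(dim=-1, keepdim=True)

# Linear order: summarize keys/values into d x d and d-sized statistics first.
kv = phi_k.T @ V                                        # (d, d)
k_sum = phi_k.sum(dim=0)                                # (d,)
linear = (phi_q @ kv) / (phi_q @ k_sum).unsqueeze(-1)
```

Both orders give the same result, but the second never creates an n×n object. What is lost is the softmax itself: the feature map is a different similarity function, which is where the quality gap comes from.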

State Space Models (Mamba): Not attention at all — uses structured state space models that are O(n) in both memory and compute. As of early 2025, competitive with transformers on many benchmarks but not yet established at GPT-4 scale. The jury is still out on whether Mamba or hybrid transformer-Mamba models will displace pure attention for long-context tasks.

The honest answer: for most frontier models in 2026, vanilla multi-head attention with FlashAttention is still the dominant choice.

Cross-Attention: Conditioning on External State

In encoder-decoder architectures (T5, original transformer for translation), the decoder uses cross-attention where:

  • Q comes from the decoder’s current state
  • K and V come from the encoder’s output

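This is what allows the decoder to “read” the source sequence selectively at each generation step. Some multimodal models (Flamingo, Llama 3.2 Vision) use cross-attention layers to let the language model attend to image embeddings; others feed image tokens directly into the self-attention stream. The architectures of GPT-4V and Gemini are not fully public.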

Diffusion models like Stable Diffusion use cross-attention to condition image generation on text embeddings — the denoising U-Net attends to CLIP or T5 encodings of the prompt.
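
In code, the only difference from self-attention is where Q, K and V come from. A minimal single-head sketch with random, untrained projections (the sizes and variable names are illustrative):

```python
import torch

torch.manual_seed(0)
d_model = 16
src_len, tgt_len = 7, 3                       # encoder and decoder lengths (illustrative)
encoder_out = torch.randn(1, src_len, d_model)
decoder_state = torch.randn(1, tgt_len, d_model)

# Separate learned projections; random here, standing in for trained weights.
W_Q = torch.nn.Linear(d_model, d_model, bias=False)
W_K = torch.nn.Linear(d_model, d_model, bias=False)
W_V = torch.nn.Linear(d_model, d_model, bias=False)

Q = W_Q(decoder_state)                        # (1, tgt_len, d_model)
K = W_K(encoder_out)                          # (1, src_len, d_model)
V = W_V(encoder_out)                          # (1, src_len, d_model)

weights = torch.softmax(Q @ K.transpose(-2, -1) / d_model ** 0.5, dim=-1)
context = weights @ V                         # (1, tgt_len, d_model)
```

The attention matrix is tgt_len × src_len rather than square: each decoder position gets its own distribution over the source positions.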

What Attention Heads Actually Learn

Early interpretability work (Vig 2019, Clark et al. 2019) found that BERT attention heads specialize:

  • Some heads consistently attend to the previous token
  • Some heads track syntactic dependencies (subject → verb)
  • Some heads resolve coreferent pronouns across long distances
  • Some heads attend mostly to punctuation (probable context boundaries)

But heads don’t cleanly decompose into “one job per head.” More recent mechanistic interpretability work (Anthropic, 2022-2024) using activation patching and circuit analysis shows that model behavior emerges from circuits spanning multiple heads and layers — not from individual heads acting independently.

The attention layer also isn’t the only information-processing site. The MLP sublayers in transformers (roughly 2/3 of parameters in a typical model) act as key-value memories, storing factual associations. Attention routes information; MLPs retrieve and transform it.

The KV Cache: Why Inference Is Different from Training

During training, you compute attention over the full sequence in parallel. During inference, you generate tokens autoregressively — one at a time. If you recomputed K and V for all previous tokens at each step, inference would be O(n²) per token.

The KV cache stores the K and V projections for all previously generated tokens. Generating each new token only requires computing new Q (for the new token) and new K/V (for the new token), then attending against the cached K/V. This reduces inference attention to O(n) per step.
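
A toy single-layer, single-head sketch of that equivalence, with random weights standing in for a trained model. It checks that decoding one token at a time against a growing cache reproduces full causal attention:

```python
import torch

torch.manual_seed(0)
n, d = 6, 8
x = torch.randn(n, d, dtype=torch.float64)               # toy token representations
W_Q, W_K, W_V = (torch.randn(d, d, dtype=torch.float64) for _ in range(3))

# Reference: full causal attention over the whole sequence at once.
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
scores = (Q @ K.T) / d ** 0.5
scores = scores.masked_fill(torch.tril(torch.ones(n, n)) == 0, float("-inf"))
full_out = torch.softmax(scores, dim=-1) @ V

# Incremental decoding: project only the new token, append its K/V to the cache.
k_cache, v_cache, step_outputs = [], [], []
for t in range(n):
    q_t = x[t] @ W_Q                                      # (d,)
    k_cache.append(x[t] @ W_K)
    v_cache.append(x[t] @ W_V)
    K_t, V_t = torch.stack(k_cache), torch.stack(v_cache) # (t+1, d)
    w = torch.softmax(K_t @ q_t / d ** 0.5, dim=-1)       # (t+1,) one row of scores
    step_outputs.append(w @ V_t)
cached_out = torch.stack(step_outputs)
```

Each step computes a single row of scores against the cache, so the per-step work grows with the number of cached tokens rather than with its square.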

The tradeoff: memory. The KV cache grows linearly with sequence length and batch size. A 70B model needs roughly 140 GB for fp16 weights alone, so even when it is quantized or sharded to fit on 80 GB A100s, a 128k context leaves room for only very small batches before the KV cache exhausts memory. Multi-Query Attention (MQA) and Grouped-Query Attention (GQA, used in LLaMA 3) address this by having multiple query heads share fewer K/V heads, dramatically reducing KV cache size with minimal quality loss.
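
Some rough arithmetic makes the scale visible. The shape below (80 layers, 64 query heads of size 128, fp16 cache) is an assumed 70B-class configuration chosen for illustration, and the helper is hypothetical:

```python
def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, batch=1, bytes_per_element=2):
    """Bytes for cached K and V (hence the factor 2) across all layers."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_element

# Hypothetical 70B-class shape: 80 layers, 64 query heads of size 128, fp16 cache.
seq_len = 128 * 1024
mha = kv_cache_bytes(layers=80, kv_heads=64, head_dim=128, seq_len=seq_len)
gqa = kv_cache_bytes(layers=80, kv_heads=8, head_dim=128, seq_len=seq_len)

print(f"MHA (64 KV heads): {mha / 2**30:.0f} GiB per sequence")   # 320 GiB
print(f"GQA (8 KV heads):  {gqa / 2**30:.0f} GiB per sequence")   # 40 GiB
```

Under these assumptions, sharing each K/V head across 8 query heads cuts the cache by 8x, which is the difference between a context that cannot be served on one node and one that can.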

One Thing to Remember

Attention’s O(n²) cost is a real ceiling — it’s why 1M-token context windows are engineering challenges, not just hyperparameter choices. Every advance in long-context modeling (FlashAttention, RoPE extensions, GQA, sparse attention) is fundamentally an attack on that quadratic scaling. The problem isn’t solved; it’s managed.
