Speculative Decoding — Deep Dive

Formal Proof: Exact Target Distribution

Theorem (Leviathan et al., 2022): The speculative decoding procedure generates tokens with exactly the distribution $p(x \mid \text{context})$ of the target model.

Full proof:

Let $T$ be the position of the first rejection (or $T = k+1$ if all $k$ drafts accepted). The output distribution at position $n+T$:

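For $T = k+1$ (all accepted): $P(\text{output} = x) = P(\text{all } k \text{ accepted}) \cdot p(x \mid \text{context}_{n+k})$

For $T = j \leq k$ (rejected at position $j$): $P(\text{output} = x) = P(\text{accept } 1..j-1) \cdot P(\text{reject at } j) \cdot p'(x \mid \text{context}_{n+j-1})$

Where the resampled distribution is: $$p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_{x'} \max(0, p(x') - q(x'))} = \frac{\max(0, p(x) - q(x))}{1 - \alpha}$$

Here $q$ is the draft model's distribution at the same position, and $\alpha = \sum_x \min(p(x), q(x))$ is the probability that a draft token sampled from $q$ is accepted.

Summing over all $T$: $$P(\text{output} = x) = \sum_{j=1}^{k+1} [\text{probability of reaching step } j] \cdot [\text{probability the output is } x \text{ at step } j]$$

The key step is the single-position identity. A draft token $x \sim q$ is accepted with probability $\min(1, p(x)/q(x))$, so the token emitted at that position is $x$ with probability $$q(x) \min\left(1, \frac{p(x)}{q(x)}\right) + (1 - \alpha)\, p'(x) = \min(p(x), q(x)) + \max(0, p(x) - q(x)) = p(x).$$ Applying this identity at each position in turn (using the fact that probabilities sum to 1 and $\sum_x \max(0, p(x) - q(x)) = 1 - \alpha$), the sum simplifies to $p(x)$.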

This is a non-trivial result: despite accepting/rejecting draft tokens based on their probability ratio, the marginal distribution of any accepted or resampled token is exactly $p(x)$.
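The single-position version of this result can be checked numerically. The sketch below is illustrative only: the function name and the three-token toy distributions are assumptions, not part of the original proof. It adds the probability mass from accepted draft tokens to the mass from resampling after a rejection, then compares the total with the target distribution.

```python
def speculative_step_distribution(p, q):
    """Output distribution of one accept/reject step (illustrative sketch).

    p, q: target and draft probabilities over the same small vocabulary.
    Returns (alpha, out) where out[x] is the probability of emitting token x.
    """
    # alpha: probability that a token drawn from q is accepted.
    alpha = sum(min(pi, qi) for pi, qi in zip(p, q))
    reject_mass = 1.0 - alpha
    out = []
    for pi, qi in zip(p, q):
        # Accepted path: q(x) * min(1, p(x)/q(x)) == min(p(x), q(x)).
        accepted = min(pi, qi)
        # Rejected path: (1 - alpha) * p'(x), with p'(x) the normalized residual.
        residual = max(0.0, pi - qi)
        resampled = reject_mass * (residual / reject_mass) if reject_mass > 0 else 0.0
        out.append(accepted + resampled)
    return alpha, out


# Toy distributions (assumed values, chosen so that p and q disagree).
p = [0.5, 0.3, 0.2]
q = [0.2, 0.6, 0.2]
alpha, out = speculative_step_distribution(p, q)
matches_target = all(abs(o - pi) < 1e-9 for o, pi in zip(out, p))
```

With these values the acceptance probability is 0.7, and the emitted distribution matches `p` up to floating-point error even though `q` puts twice as much mass on the second token as `p` does.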

Multi-Candidate and Tree Speculative Decoding

Standard speculative decoding generates one draft sequence (a linear chain). Tree speculative decoding generates multiple branching sequences and verifies all simultaneously.

Speculative Tree Construction: The draft model generates a tree with $b$ children per node at each of $k$ depths. Total nodes, counting the root (the current context): $(b^{k+1} - 1)/(b - 1)$ per tree, of which all but the root are draft tokens. Beam search with beam size $b$ for depth $k$ instead keeps only $b$ candidate complete sequences.
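The node count is easy to get wrong by one level, so here is a small arithmetic sketch. The function name and the `include_root` flag are assumptions made for illustration; the count sums $b^d$ over depths $d = 1..k$ and optionally adds the root.

```python
def tree_node_count(b, k, include_root=True):
    """Number of nodes in a full tree with b children per node and k draft depths."""
    # Depth d holds b**d draft tokens, for d = 1..k.
    draft_nodes = sum(b ** d for d in range(1, k + 1))
    return draft_nodes + 1 if include_root else draft_nodes


nodes_depth_3 = tree_node_count(3, 3)                      # 1 + 3 + 9 + 27
nodes_depth_4 = tree_node_count(3, 4)                      # adds 81 more
drafts_depth_4 = tree_node_count(3, 4, include_root=False)
closed_form_depth_4 = (3 ** (4 + 1) - 1) // (3 - 1)        # (b^(k+1) - 1) / (b - 1)
```

For $b = 3$ this gives 40 nodes at depth 3 and 121 nodes at depth 4, which agrees with the closed form.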

Tree Attention (Miao et al., 2023): When the target model verifies a tree of draft tokens, standard attention would require running the model once per root-to-leaf path (exponentially expensive in the depth). Tree attention processes all nodes simultaneously by masking attention so that each node attends only to itself and its own ancestors:

$$\text{mask}[i,j] = \begin{cases} 0 & \text{if } j = i \text{ or node } j \text{ is an ancestor of node } i \\ -\infty & \text{otherwise} \end{cases}$$

All nodes in the tree are processed in a single forward pass with this custom attention mask. The key-value cache is shared for common prefixes (ancestors), and per-path branches don’t need to recompute the shared prefix.
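A minimal sketch of building such a mask, assuming the tree is given as a hypothetical `parents` list in which `parents[i]` is the index of node `i`'s parent and `-1` marks a node whose parent is the already-cached context. The sketch lets each node attend to itself as well as its ancestors (an assumption, since a token normally attends to its own position) and leaves out the shared prompt prefix, which every node would also see.

```python
NEG_INF = float("-inf")


def build_tree_attention_mask(parents):
    """Additive attention mask over draft-tree nodes (illustrative sketch).

    mask[i][j] is 0.0 when node j is node i itself or one of its ancestors,
    and -inf otherwise, so that softmax ignores other branches.
    """
    n = len(parents)
    mask = [[NEG_INF] * n for _ in range(n)]
    for i in range(n):
        j = i
        # Walk from the node up to the top of the tree, unmasking each step.
        while j != -1:
            mask[i][j] = 0.0
            j = parents[j]
    return mask


# Toy tree: node 0 has children 1 and 2; node 1 has child 3.
parents = [-1, 0, 0, 1]
mask = build_tree_attention_mask(parents)
```

In this toy tree, node 3 can see nodes 0, 1, and itself but not node 2, which lies on a sibling branch. A real implementation would typically build the mask as a tensor and concatenate it with the mask over cached prefix tokens.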

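Speedup analysis: For a tree of depth $k=4$ and breadth $b=3$, the formula above gives 121 nodes: 120 draft tokens plus the root, all verified in one target model forward pass. Generating 120 tokens sequentially would take 120 forward passes, but the potential speedup is not 120x: only one root-to-leaf path can be accepted, so a single pass emits at most $k+1 = 5$ tokens (up to $k$ accepted drafts plus one token from the target model). The extra branches raise the chance that some candidate survives at each depth; they do not add output tokens. Per-node acceptance rates in trees are also typically lower than in linear chains (early branches constrain later ones), and the draft model's own cost counts against the gain, so realistic speedups are quoted at around 4–6x, with the high end requiring deeper trees than this example.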

EAGLE-2: Dynamic Drafting Length

EAGLE (2024) uses feature-level speculative drafting. EAGLE-2 extended this with dynamic draft length.

Key insight: Not all tokens are equally predictable. Common words (“the”, “and”, punctuation) have high acceptance rates — it’s worth drafting more. Rare or contextually complex tokens have lower acceptance rates — drafting fewer is more efficient.

EAGLE-2 does not need a separately trained predictor: it treats the draft model's own confidence score for a token as an estimate of the probability that the token will be accepted. After each draft step, if that confidence is below a threshold, stop drafting and call the target model.

This adaptive stopping achieves:

  • Higher acceptance rates (stop before hard tokens)
  • Better compute allocation (more draft steps on easy tokens, fewer on hard ones)
  • 3.5–5x speedup on diverse tasks (vs. 2.5–3.5x for fixed-length drafting)

Interaction With Continuous Batching

Modern LLM serving uses continuous batching (Orca, 2022) — as requests finish generating, new requests are immediately added to the batch. This maximizes GPU utilization.

Speculative decoding complicates continuous batching because:

  1. Different requests have different acceptance rates — some extend quickly, others don’t
  2. When a request accepts $k=5$ tokens, it might finish mid-batch, leaving other batch members waiting
  3. Speculative tree evaluation requires consistent batch structure during verification

Chunked speculative decoding: Process the draft phase in fixed-size chunks synchronized across the batch. This reduces GPU utilization relative to unconstrained speculative decoding but enables clean interaction with continuous batching.
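One way to picture the synchronization is a toy scheduler in which every request drafts the same fixed chunk, verification happens once for the whole batch, and finished requests are swapped out only at chunk boundaries. This is a simplified sketch under those assumptions; the function names, the `remaining` bookkeeping, and the acceptance counts are all invented for illustration and do not describe any particular serving system.

```python
def chunked_speculative_step(remaining, chunk_size, accepted_counts):
    """Advance every request by one synchronized draft/verify chunk (sketch).

    remaining: dict request_id -> tokens still to generate (mutated in place).
    accepted_counts: dict request_id -> draft tokens accepted this chunk.
    Returns the ids of requests that finished at this chunk boundary.
    """
    finished = []
    for rid in list(remaining):
        accepted = min(accepted_counts[rid], chunk_size)
        # Each verification also yields one token from the target model.
        remaining[rid] -= accepted + 1
        if remaining[rid] <= 0:
            finished.append(rid)
            del remaining[rid]
    return finished


def refill(remaining, waiting, capacity):
    """Continuous batching: admit waiting requests into freed slots."""
    while waiting and len(remaining) < capacity:
        rid, length = waiting.pop(0)
        remaining[rid] = length


remaining = {"a": 3, "b": 10}
waiting = [("c", 5)]
finished = chunked_speculative_step(remaining, 4, {"a": 4, "b": 1})
refill(remaining, waiting, capacity=2)
```

Request `a` finishes at the chunk boundary and `c` takes its slot, while `b` advances by only two tokens. The cost of the approach is visible in the same step: `b` accepted one draft out of four, so most of its draft work in that chunk was wasted while it waited for the batch.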

Batch-level vs. request-level speculation: Some serving systems apply speculation at the batch level (all requests use the same draft model), others request-level (speculation only when beneficial per request). Request-level is more efficient but more complex to orchestrate.

Prefill-decode disaggregation: Modern large-scale LLM serving separates “prefill” (processing the input) from “decode” (generating tokens) on different hardware. Speculative decoding fits naturally in the decode phase — the draft model on less powerful hardware generates candidates, the target model on high-end hardware verifies.

PRM-Guided Speculative Decoding

An emerging combination: use a Process Reward Model (from chain-of-thought research) as the verification signal instead of simple probability comparison.

Standard speculative decoding accepts draft token $x$ with probability $\min(1, p(x)/q(x))$. PRM-guided speculative decoding accepts based on:

$$\text{accept} = \min\left(1, \frac{p(x) \cdot \text{PRM\_score}(\text{context} + x)}{q(x) \cdot \text{PRM\_score}(\text{context})}\right)$$

This biases acceptance toward tokens that lead to better reasoning chains according to the PRM. Useful for reasoning-heavy tasks where the quality of the reasoning path matters, not just the marginal token probability.
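The two acceptance rules can be compared directly in code. This is a sketch of the formulas as written above; the function names and the numeric probabilities and PRM scores are assumed for illustration, and a real system would obtain the scores from model calls.

```python
def standard_accept_prob(p_x, q_x):
    """Standard rule: min(1, p(x) / q(x))."""
    return min(1.0, p_x / q_x)


def prm_guided_accept_prob(p_x, q_x, prm_with_token, prm_context):
    """PRM-guided rule: the ratio is reweighted by the change in PRM score."""
    return min(1.0, (p_x * prm_with_token) / (q_x * prm_context))


# Toy numbers: the draft model over-proposes this token (q > p).
p_x, q_x = 0.2, 0.4
baseline = standard_accept_prob(p_x, q_x)
helpful = prm_guided_accept_prob(p_x, q_x, prm_with_token=0.9, prm_context=0.6)
harmful = prm_guided_accept_prob(p_x, q_x, prm_with_token=0.3, prm_context=0.6)
```

With these numbers the standard rule accepts with probability 0.5. A token that raises the PRM score from 0.6 to 0.9 is accepted with probability 0.75, and one that lowers it to 0.3 is accepted with probability 0.25.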

Experimental results (Xia et al., 2024): On the MATH benchmark, PRM-guided speculative decoding achieves a 5-10% accuracy improvement over standard speculative decoding while maintaining comparable speedup. The tradeoff: the output no longer follows the exact target distribution, because the PRM term biases acceptance. For reasoning tasks, this bias toward better reasoning is desirable.

One thing to remember: Speculative decoding is one of the most mature and widely deployed LLM inference optimizations, and its continued development (tree attention, dynamic drafting, PRM guidance) shows that the space of “same output, faster generation” optimizations has significant remaining room, with PRM guidance as the one variant here that deliberately gives up exactness.
