Chain-of-Thought Reasoning — Deep Dive

Why Chain-of-Thought Works: Mechanistic Hypotheses

The empirical effect is clear; the mechanism is more contested. Several hypotheses:

Hypothesis 1: Additional Computation Tokens

Each generated token is an additional transformer forward pass. For a complex multi-step calculation, each reasoning step provides computational “scratch space.” The model can compute partial results, store them in generated text, and use them for subsequent steps.

Without CoT: the model must perform the multi-step reasoning inside the fixed depth of a single forward pass, essentially computing $f(f(f(x)))$ without access to the intermediate $f(x)$ or $f(f(x))$.

With CoT: intermediate computations are explicitly represented in generated text, which the model can attend to in subsequent forward passes. The model’s effective “working memory” is extended by the length of the reasoning chain.
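The scratch-space idea can be illustrated with a toy analogy in ordinary Python, not a model. In this version of long multiplication, the only state carried from one step to the next is the text already written, so each step must read its input back from the previous line. The function name and trace format are invented for illustration. This mirrors attending to earlier generated tokens only loosely and does not model transformer internals.

```python
def scratchpad_multiply(a: int, b: int) -> list[str]:
    """Long multiplication of non-negative integers, written out as a step-by-step trace.

    Toy analogy only: the sole state passed between steps is the text in `lines`.
    Each step parses the running total back out of the previous line, loosely like
    a model attending to intermediate results it has already generated.
    """
    lines = ["running total = 0"]
    for place, digit in enumerate(reversed(str(b))):
        # "Attend" to the last written line to recover the partial result so far.
        previous_total = int(lines[-1].rsplit("= ", 1)[1])
        multiplier = int(digit) * 10**place
        partial = a * multiplier
        lines.append(f"{a} x {multiplier} = {partial}; running total = {previous_total + partial}")
    return lines


for line in scratchpad_multiply(47, 36):
    print(line)
# running total = 0
# 47 x 6 = 282; running total = 282
# 47 x 30 = 1410; running total = 1692
```

No single step does more than one small multiplication and one addition. The "hard" part of the computation lives in the accumulated text.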

This hypothesis predicts that CoT should help more with problems requiring sequential computation (arithmetic, logic) and less with problems requiring knowledge retrieval or pattern matching. That matches the broad empirical pattern: the largest CoT gains are reported on arithmetic and symbolic reasoning tasks.

Hypothesis 2: Distribution Shift to Reasoning-Dense Training Data

“Let’s think step by step” and related phrases may shift the output distribution toward reasoning-dense patterns in the training data. Mathematical textbooks, coding tutorials, and scientific papers contain detailed step-by-step derivations. Prompting in this style may lead the model to generate text resembling those sources, which tend to contain correct worked solutions.

Wei et al.’s analysis: for smaller models (below roughly 100B parameters), CoT doesn’t help and sometimes hurts, because these models lack the capability to benefit from the additional computation. For larger models, the benefit is consistent. This fits both Hypothesis 1 (larger models do more useful computation per step) and Hypothesis 2 (larger models have better representations of reasoning-dense training text).

Hypothesis 3: Anchor Effect

Generating explicit intermediate conclusions anchors subsequent generation to be consistent with those conclusions. Without CoT, the final answer is not constrained by any stated intermediate steps. With CoT, the answer is conditioned on the stated reasoning. This pushes it toward consistency with that reasoning and can reduce random errors in the final generation.

Self-Consistency: Formal Analysis

Self-consistency takes majority vote over $K$ samples. The formal justification: if each reasoning chain independently reaches the correct answer with probability $p$, and wrong answers are spread across many possibilities (no single dominant wrong answer), then majority vote converges to the correct answer as $K$ increases.

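For a question with $m$ possible answers, the cleanest guarantee applies to a strict majority. The probability that more than half of $K$ independent chains reach the correct answer is

$$P(\text{majority correct}) = \sum_{k > K/2}^{K} \binom{K}{k} p^k (1-p)^{K-k}$$

This probability approaches 1 as $K$ grows whenever $p > 1/2$. Plurality voting needs less: the correct answer only has to outpoll each wrong answer individually. So if the $m - 1$ wrong answers are equally likely, $p > 1/m$ is enough for plurality accuracy to approach 1.

For $p = 0.6$ and $K = 40$, the strict-majority formula gives $P(\text{majority correct}) \approx 0.87$, up from 60% for a single chain. Plurality voting over $m = 4$ answers with evenly split wrong answers does better still, because each wrong answer receives only about 5 of the 40 votes on average.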
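The strict-majority formula can be evaluated directly. Below is a minimal sketch using only Python's standard library. The function name is ours, and the code makes the same assumption as the formula: independent chains, each correct with probability `p`.

```python
from math import comb


def majority_accuracy(p: float, K: int) -> float:
    """P(more than K/2 of K independent chains are correct), each correct with probability p."""
    # k runs over K//2 + 1 .. K, i.e. every k > K/2; for even K a tie counts as a failure.
    return sum(comb(K, k) * p**k * (1 - p) ** (K - k) for k in range(K // 2 + 1, K + 1))


for K in (1, 5, 11, 21, 40):
    print(K, round(majority_accuracy(0.6, K), 3))
```

Odd values of $K$ avoid ties entirely. With even $K$, a perfect split between correct and incorrect chains does not count as a majority, which slightly penalizes even sample counts.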

Uniformity assumption: The analysis assumes independent samples, with wrong answers spread evenly across possibilities. In practice, some wrong answers are more common than others (systematic errors, near-miss mistakes), and a single dominant wrong answer can win the vote. Empirically, wrong answers tend to be more dispersed than correct ones: reasoning chains that reach wrong answers tend to take different wrong paths, while correct paths converge.
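To see how much the error distribution matters, the small Monte Carlo sketch below compares two questions with the same per-chain accuracy, $p = 0.4$. In one, the wrong answers are split evenly; in the other, a single wrong answer dominates. It uses only the standard library, and the function name, answer labels, and probabilities are made up for illustration.

```python
import random
from collections import Counter


def simulate_vote_accuracy(answer_probs: dict[str, float], correct: str, K: int,
                           trials: int = 5_000, seed: int = 0) -> float:
    """Estimate P(the correct answer is the strict plurality winner among K sampled answers)."""
    rng = random.Random(seed)
    answers = list(answer_probs)
    weights = list(answer_probs.values())
    wins = 0
    for _ in range(trials):
        top_two = Counter(rng.choices(answers, weights=weights, k=K)).most_common(2)
        # Ties for first place are counted as failures.
        if top_two[0][0] == correct and (len(top_two) == 1 or top_two[0][1] > top_two[1][1]):
            wins += 1
    return wins / trials


spread = {"A": 0.4, "B": 0.2, "C": 0.2, "D": 0.2}  # wrong answers evenly spread
dominant = {"A": 0.4, "B": 0.5, "C": 0.1}          # one systematic wrong answer
print(simulate_vote_accuracy(spread, "A", K=40))
print(simulate_vote_accuracy(dominant, "A", K=40))
```

In the first case, voting lifts accuracy well above the single-chain 0.4. In the second, it drops below 0.4, because the dominant wrong answer usually collects the most votes.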

Cost-accuracy tradeoff: Self-consistency at $K=40$ costs 40x the inference compute of a single chain. Adaptive sampling (stop once one answer is clearly ahead, otherwise continue up to $K$ samples) reduces expected cost with little or no loss in accuracy.
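One conservative form of adaptive sampling stops as soon as the leading answer can no longer be overtaken by the samples that remain. As a result, it always returns the answer that a full $K$-sample plurality vote over the same samples would. This is a minimal sketch. `sample_answer` is a hypothetical callable standing in for "generate one reasoning chain and extract its final answer".

```python
from collections import Counter
from collections.abc import Callable, Hashable


def adaptive_vote(sample_answer: Callable[[], Hashable], max_samples: int) -> tuple[Hashable, int]:
    """Plurality vote that stops once the leader cannot be overtaken.

    Returns (winning answer, number of samples actually drawn).
    """
    if max_samples < 1:
        raise ValueError("max_samples must be at least 1")
    votes: Counter = Counter()
    for drawn in range(1, max_samples + 1):
        votes[sample_answer()] += 1
        ranked = votes.most_common(2)
        leader, leader_count = ranked[0]
        runner_up_count = ranked[1][1] if len(ranked) > 1 else 0
        remaining = max_samples - drawn
        # Even if every remaining sample went to the runner-up, it could not
        # reach the leader, so further sampling cannot change the result.
        if leader_count > runner_up_count + remaining:
            return leader, drawn
    return votes.most_common(1)[0][0], max_samples


print(adaptive_vote(lambda: "A", 40))  # ('A', 21)
```

With a sampler that always returns the same answer and `max_samples=40`, the loop stops after 21 draws, since 21 votes already exceed the 19 that remain. Probabilistic stopping rules can stop earlier, at the cost of occasionally disagreeing with the full vote.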

Process Reward Models vs. Outcome Reward Models

Standard RLHF reward models score complete outputs. For math and reasoning tasks, the usual signal is simply whether the final answer is correct. This is the Outcome Reward Model (ORM).

Process Reward Models (PRMs) (Lightman et al., OpenAI, 2023) rate individual steps in the reasoning chain — not just the final answer, but each intermediate step’s correctness.

For math problems, human annotators labeled each reasoning step as correct, neutral, or incorrect, and a PRM is trained on this step-level signal. At inference, the PRM’s step scores act roughly like a value function over partial reasoning chains: an estimate of how likely the chain is to lead to a correct answer, given the steps so far.
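Given per-step PRM scores, ranking whole chains comes down to aggregating each chain's step scores into one number and taking the maximum. The minimal sketch below assumes the PRM's outputs are already available as lists of per-step probabilities; the PRM call itself is not shown, and the function names are ours. One common aggregation is the product of step probabilities, which can be read as the chance that every step is correct if steps are treated as independent. Another is the minimum step score.

```python
import math
from collections.abc import Sequence


def chain_score(step_probs: Sequence[float], how: str = "prod") -> float:
    """Aggregate per-step PRM probabilities into a single chain score."""
    if how == "prod":
        return math.prod(step_probs)  # every step correct, treating steps as independent
    if how == "min":
        return min(step_probs)  # a chain is only as good as its weakest step
    raise ValueError(f"unknown aggregation: {how}")


def best_of_n(chains: Sequence[tuple[str, Sequence[float]]], how: str = "prod") -> str:
    """Return the text of the chain with the highest aggregated PRM score."""
    best_text, _ = max(chains, key=lambda chain: chain_score(chain[1], how))
    return best_text


candidates = [
    ("chain A", [0.9, 0.9, 0.9]),    # consistently decent steps
    ("chain B", [0.99, 0.5, 0.99]),  # one doubtful step in the middle
]
print(best_of_n(candidates))  # chain A (0.729 vs about 0.49)
```

Both aggregations penalize a single weak step heavily. That is the point of step-level scoring: a chain whose final answer looks fine can still be rejected.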

PRMs enable:

  1. Better supervision signal: Guide model training toward correct reasoning, not just correct answers (which can be reached by incorrect reasoning that happens to get the right answer)
  2. Best-of-N selection: Generate N reasoning chains, then select the one with the highest PRM score (rather than taking a plurality vote over final answers, as self-consistency does)
  3. Beam search over reasoning: Use PRM as the heuristic in tree search, expanding only high-scoring branches
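Beam search with a PRM as the heuristic can be sketched in a few lines. The three callables are hypothetical placeholders:

- `propose_steps` stands in for sampling candidate next steps from the language model.
- `score_prefix` stands in for the PRM's score of a partial chain.
- `is_done` detects a final answer.

The toy usage replaces them with trivial functions so the search logic can be checked on its own.

```python
import heapq
from collections.abc import Callable

Chain = list[str]


def prm_beam_search(
    propose_steps: Callable[[Chain], list[str]],
    score_prefix: Callable[[Chain], float],
    is_done: Callable[[Chain], bool],
    beam_width: int = 4,
    max_depth: int = 10,
) -> Chain:
    beams: list[Chain] = [[]]
    for _ in range(max_depth):
        expansions: list[Chain] = []
        for prefix in beams:
            if is_done(prefix):
                expansions.append(prefix)  # finished chains stay in the pool
            else:
                expansions.extend(prefix + [step] for step in propose_steps(prefix))
        if not expansions:
            break  # nothing left to expand
        # Keep only the beam_width highest-scoring chains; the PRM acts as the heuristic.
        beams = heapq.nlargest(beam_width, expansions, key=score_prefix)
        if all(is_done(chain) for chain in beams):
            break
    return max(beams, key=score_prefix)


# Toy stand-ins: steps are digits, the "PRM" prefers larger digits, chains end at length 3.
best = prm_beam_search(
    propose_steps=lambda chain: ["1", "2", "3"],
    score_prefix=lambda chain: sum(int(step) for step in chain),
    is_done=lambda chain: len(chain) == 3,
)
print(best)  # ['3', '3', '3']
```

Pruning to `beam_width` chains at every depth keeps cost roughly linear in depth. The trade-off is that a chain whose early steps score poorly is discarded even if it would have recovered.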

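Let’s Verify Step by Step (Lightman et al., OpenAI, 2023): on problems from the MATH dataset, PRM-guided best-of-N selection significantly outperformed ORM-guided selection at the same number of sampled solutions, so fewer samples are needed to reach a given accuracy. Earlier, Uesato et al. (DeepMind, 2022) found that outcome and process supervision gave similar final-answer error rates. They also found that process-based feedback (or reward models trained to emulate it) was needed to reduce errors in the reasoning steps themselves.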

Test-Time Compute Scaling

Snell et al. (2024) “Scaling LLM Test-Time Compute Optimally Can Be More Effective Than Scaling Model Parameters” formalized the observation that inference-time compute can substitute for model parameters for certain problems.

The compute-optimal framing: for a fixed total compute budget $C$, it is sometimes better to use a smaller model with more sampling or reasoning than a larger model with less. With $N$ the model size and $R$ the inference-time reasoning budget:

$$\text{Acc}^*(C) = \max_{(N,\,R)\,:\;\text{cost}(N,\,R)\,\le\,C} \text{Acc}(N, R)$$

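Snell et al. found that the answer depends on difficulty. They ran FLOPs-matched comparisons on the MATH benchmark. On problems where the smaller model already has a non-trivial success rate, a smaller model that allocates test-time compute optimally can outperform a roughly 14x larger model. On the hardest problems, spending the compute on pretraining a larger model was more effective.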
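FLOPs matching is simple arithmetic under a common approximation: a decoder-only transformer with $N$ parameters spends about $2N$ FLOPs per generated token. This ignores attention and prompt-processing costs and assumes every sample has the same length. Under that approximation, one sample from a 70B model costs about as much as ten samples from a 7B model. This is a small sketch, and the function names are ours.

```python
def inference_flops(params: float, tokens_per_sample: float, samples: int) -> float:
    """Approximate generation cost: ~2 FLOPs per parameter per generated token."""
    return 2.0 * params * tokens_per_sample * samples


def flops_matched_samples(small_params: float, large_params: float, large_samples: int = 1) -> float:
    """Samples a small model can draw for the budget of `large_samples` from a large model."""
    return large_samples * large_params / small_params


print(flops_matched_samples(7e9, 70e9))  # 10.0
print(inference_flops(7e9, 500, 100) / inference_flops(70e9, 500, 1))  # 10.0
```

Sample length matters as much as sample count. If the small model writes chains twice as long, the matched number of samples halves.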

This has significant practical implications:

  • Inference-time compute scaling is an additional axis (beyond training compute and model size)
  • “Reasoning models” (o1, o3) explicitly leverage this: they scale inference compute dynamically based on problem difficulty
  • The cost-performance tradeoff can be optimized per problem: use a cheap model with extra sampling or reasoning where it already succeeds some of the time, and a larger model where extra test-time compute stops paying off

Compute-optimal inference: For a fixed accuracy target, compute-optimal inference uses:

  • Easy problems: a fast, direct answer with little or no extended reasoning
  • Hard problems: extended reasoning, possibly with tree search
  • Adaptive allocation based on detected problem difficulty
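At its simplest, adaptive allocation means splitting a fixed sample budget across problems according to an estimated difficulty score. The sketch below is a deliberately naive, hypothetical policy: every problem gets a minimum number of samples, and the rest is shared in proportion to difficulty. It is not the procedure from Snell et al. The difficulty estimates are assumed to come from elsewhere.

```python
def allocate_samples(difficulties: list[float], total: int, min_per_problem: int = 1) -> list[int]:
    """Split `total` samples across problems: a floor for each, the rest proportional to difficulty.

    Hypothetical policy for illustration; difficulties are non-negative scores (higher = harder).
    """
    n = len(difficulties)
    if n == 0:
        return []
    if total < n * min_per_problem:
        raise ValueError("budget is smaller than the per-problem minimum")
    spare = total - n * min_per_problem
    weight = sum(difficulties)
    shares = [spare * d / weight if weight > 0 else spare / n for d in difficulties]
    alloc = [min_per_problem + int(share) for share in shares]
    # Largest-remainder rounding: hand out leftover samples by largest fractional part.
    leftover = total - sum(alloc)
    by_remainder = sorted(range(n), key=lambda i: shares[i] - int(shares[i]), reverse=True)
    for i in by_remainder[:leftover]:
        alloc[i] += 1
    return alloc


print(allocate_samples([0.0, 0.5, 1.0], total=12))  # [1, 4, 7]
```

A compute-optimal policy would also account for how much extra samples actually help at each difficulty level, rather than assuming harder always means "more samples".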

This mirrors human cognition: System 1 (fast, automatic) vs. System 2 (slow, deliberate) thinking.

One thing to remember: Chain-of-thought wasn’t just a prompting trick — it revealed that intelligence scales differently at inference time than at training time, opening a new dimension of AI capability improvement that is only beginning to be exploited.

