Few-Shot Learning — Deep Dive

Episodic Training: Aligning Train and Test Distributions

The critical insight behind metric-learning approaches: the distribution of training and test tasks must be aligned. If you train on “classify between cats and dogs” and test on “classify between 5 unseen bird species from 1 example each,” the model may not generalize.

Episodic training (Vinyals et al., 2016, Matching Networks) reframes the training problem to match the test problem:

During training, each “episode” samples:

  1. N classes from the training class set
  2. K support examples per class (the few-shot examples)
  3. Q query examples drawn from the same N classes (the evaluation examples)
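The sampling steps above can be sketched in Python. This is a minimal illustration, not code from Matching Networks. It assumes the dataset is a list of `(example, label)` pairs and takes Q query examples per class (a common convention; some setups fix the total query count instead). The helper name `sample_episode` is hypothetical.

```python
import random
from collections import defaultdict


def sample_episode(dataset, n_way, k_shot, q_queries, rng=random):
    """Sample one N-way, K-shot episode from a list of (example, label) pairs.

    Returns (support, query) as lists of (example, episode_label). Episode labels
    are re-indexed to 0..n_way-1, so the model cannot memorize global class IDs.
    """
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)

    # Only classes with enough examples for both the support and query sets qualify.
    eligible = [c for c, xs in by_class.items() if len(xs) >= k_shot + q_queries]
    classes = rng.sample(eligible, n_way)

    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + q_queries)  # no overlap between sets
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return support, query


if __name__ == "__main__":
    # Toy "training class set": 10 classes with 20 placeholder examples each.
    data = [(f"img_{c}_{i}", c) for c in range(10) for i in range(20)]
    support, query = sample_episode(data, n_way=5, k_shot=1, q_queries=3, rng=random.Random(0))
    print(len(support), len(query))  # 5 support examples, 15 queries
```

Re-indexing labels per episode matters. The same underlying class can be label 0 in one episode and label 3 in the next, so the model has to read class identity from the support set rather than from its weights.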

The model processes the support set, then classifies the query examples; the loss is the classification loss on those queries. Weights are updated after each episode (or mini-batch of episodes), over many episodes.
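To make the query classification loss concrete, here is a sketch of one episode's loss. It assumes a Prototypical-Networks-style classifier: each class prototype is the mean support embedding, logits are negative squared Euclidean distances, and the loss is cross-entropy averaged over the queries. Matching Networks themselves use an attention-based classifier instead. `episode_loss` is a hypothetical name, and only the forward loss is computed; a real implementation would backpropagate it through `embed` with an autodiff framework.

```python
import math


def mean(vectors):
    """Element-wise mean of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sq_dist(a, b):
    """Squared Euclidean distance."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def episode_loss(embed, support, query, n_way):
    """Mean cross-entropy of the query set under a nearest-prototype classifier.

    `embed` maps an example to a list of floats; in training it would be the
    learnable network whose weights the loss updates.
    """
    # One prototype per episode class: the mean embedding of its support examples.
    protos = [mean([embed(x) for x, y in support if y == c]) for c in range(n_way)]
    total = 0.0
    for x, y in query:
        z = embed(x)
        logits = [-sq_dist(z, p) for p in protos]  # nearer prototype -> larger logit
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[y]  # equals -log softmax(logits)[y]
    return total / len(query)
```

With the identity embedding, `episode_loss(lambda v: v, [([0.0, 0.0], 0), ([4.0, 4.0], 1)], [([0.1, 0.0], 0), ([3.9, 4.0], 1)], 2)` is nearly zero because each query sits next to its own prototype. Swapping the two query labels makes the loss large (about 31).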

This trains the model to do exactly what it will need to do at test time: use a small support set to classify queries. The training classes and test classes are disjoint, but the task structure is identical.

Why not just train on all data at once? Episodic training teaches the model to use the support set as context for classifying the query. Standard, non-episodic training teaches the model to classify a fixed set of training classes with knowledge stored in its weights; it is never explicitly trained to exploit a handful of examples from classes it has never seen. Episodic training makes exactly that skill the training objective.

Relation Networks and Learnable Similarity

Prototypical Networks use fixed Euclidean distance. What if the right similarity metric for the task isn’t Euclidean?

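Relation Networks (Sung et al., 2018): Concatenate the query embedding with a class-level embedding (for K > 1, the element-wise sum of that class's support embeddings) and pass the pair through a relation module, a small learned network that outputs a similarity score between 0 and 1. Across training episodes, the relation module learns what makes two examples similar.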

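$$r(q, c) = g_\psi\big([f_\phi(q),\, p_c]\big), \qquad p_c = \sum_{k=1}^{K} f_\phi(x_{c,k})$$

where $g_\psi$ is the relation module, $[\cdot, \cdot]$ denotes concatenation, and $x_{c,k}$ is the $k$-th support example of class $c$. Because $g_\psi$ is trained jointly with the embedding $f_\phi$, the similarity metric itself is learned rather than fixed.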
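Here is a sketch of the relation-scoring step, assuming the query and class embeddings are already computed and flattened. The original relation module uses convolutional blocks followed by fully connected layers. The small ReLU MLP with a sigmoid output used here (class `RelationModule`, function `classify`) is a simplifying, hypothetical stand-in, and its weights are random rather than trained.

```python
import math
import random


class RelationModule:
    """Tiny MLP g_psi mapping a concatenated [query, class] embedding to a score in (0, 1).

    Simplified stand-in (assumption): flat vectors and fully connected layers only.
    """

    def __init__(self, embed_dim, hidden=8, seed=0):
        rng = random.Random(seed)
        self.w1 = [[rng.gauss(0, 0.5) for _ in range(2 * embed_dim)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0, 0.5) for _ in range(hidden)]
        self.b2 = 0.0

    def score(self, query_emb, class_emb):
        x = list(query_emb) + list(class_emb)  # concatenation [f(q), p_c]
        h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)  # ReLU hidden layer
             for row, b in zip(self.w1, self.b1)]
        logit = sum(w * hi for w, hi in zip(self.w2, h)) + self.b2
        return 1.0 / (1.0 + math.exp(-logit))  # sigmoid -> relation score


def classify(module, query_emb, class_embs):
    """Predict the class whose embedding gets the highest learned relation score."""
    scores = [module.score(query_emb, p) for p in class_embs]
    return max(range(len(scores)), key=scores.__getitem__), scores
```

In the paper, training regresses the score toward 1 for matching query/class pairs and toward 0 otherwise with a mean-squared-error loss. That objective is what turns $g_\psi$ into a learned similarity metric. Untrained, as here, its predictions are arbitrary.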

Cross-attention matching (newer approaches): Use transformer-style cross-attention so that query features attend over support features, letting the model focus on the features that discriminate between the classes present in this particular support set. Methods in this direction are among the strongest performers on standard few-shot benchmarks.
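The core of cross-attention matching is scaled dot-product attention from a query feature over support features. The sketch below is a deliberately simplified, hypothetical version. It uses a single head with no learned query/key/value projections, and it attends over whole-example vectors rather than, for example, spatial feature maps. Each class's support set is summarized in a query-dependent way, and classes are scored by negative squared distance to that summary.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def cross_attend(query, keys, values):
    """Single-head scaled dot-product attention of one query vector over a set.

    Returns the attention-weighted sum of `values` and the attention weights.
    """
    d = len(query)
    weights = softmax([sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys])
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
    return out, weights


def cross_attention_scores(query, support_by_class):
    """Score each class by how close the query is to its query-aligned support summary."""
    scores = []
    for support in support_by_class:
        # Keys and values are the raw support embeddings (no learned projections here).
        aligned, _ = cross_attend(query, support, support)
        scores.append(-sum((a - q) ** 2 for a, q in zip(aligned, query)))
    return scores
```

Unlike a fixed prototype, the class summary changes with each query: support examples that resemble the query get more weight. This is what lets attention-based matchers emphasize the features that matter for the comparison at hand.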

MAML Second-Order Optimization Detail

MAML optimizes:

$$\theta^* = \arg\min_\theta \mathcal{L}_{out}\big(\theta - \alpha \nabla_\theta \mathcal{L}_{in}(\theta)\big)$$

The gradient with respect to $\theta$ requires differentiating through the inner gradient step, i.e. computing $\nabla_\theta \mathcal{L}_{out}(\theta')$ where $\theta' = \theta - \alpha \nabla_\theta \mathcal{L}_{in}(\theta)$:

$$\nabla_\theta \mathcal{L}_{out}(\theta') = \nabla_{\theta'} \mathcal{L}_{out}(\theta') \cdot \nabla_\theta \theta' = \nabla_{\theta'} \mathcal{L}_{out}(\theta') \cdot \big(I - \alpha \nabla^2_\theta \mathcal{L}_{in}(\theta)\big)$$

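The Hessian term $\nabla^2_\theta \mathcal{L}_{in}$ requires second-order derivatives. Autodiff frameworks avoid forming the full Hessian by computing a Hessian-vector product, but this still means backpropagating through the inner update, which is expensive in compute and memory for large models.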
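The chain-rule result can be checked numerically on a toy problem. This sketch assumes hypothetical one-dimensional quadratic losses $\mathcal{L}_{in}(\theta) = \tfrac{a}{2}(\theta - b)^2$ and $\mathcal{L}_{out}(\theta') = \tfrac{c}{2}(\theta' - d)^2$. For these, the Hessian of $\mathcal{L}_{in}$ is just the constant $a$, so the exact meta-gradient is $c(\theta' - d)(1 - \alpha a)$. A finite-difference derivative of the composed function serves as an independent check.

```python
def inner_update(theta, alpha, a, b):
    """One inner SGD step on L_in(theta) = 0.5 * a * (theta - b)**2."""
    return theta - alpha * a * (theta - b)


def outer_loss(theta_prime, c, d):
    """L_out(theta') = 0.5 * c * (theta' - d)**2."""
    return 0.5 * c * (theta_prime - d) ** 2


def maml_meta_grad(theta, alpha, a, b, c, d):
    """Exact meta-gradient: grad_{theta'} L_out(theta') * (1 - alpha * d2 L_in / dtheta2)."""
    theta_prime = inner_update(theta, alpha, a, b)
    grad_outer = c * (theta_prime - d)  # dL_out/dtheta' at the adapted parameters
    hessian_in = a                      # second derivative of L_in is the constant a
    return grad_outer * (1.0 - alpha * hessian_in)


def finite_diff_meta_grad(theta, alpha, a, b, c, d, eps=1e-6):
    """Central finite difference of theta -> L_out(inner_update(theta)), as a check."""
    def composed(t):
        return outer_loss(inner_update(t, alpha, a, b), c, d)
    return (composed(theta + eps) - composed(theta - eps)) / (2 * eps)
```

For θ = 1, α = 0.1, a = 4, b = 3, c = 1, d = 0, the adapted parameter is θ' = 1.8. The meta-gradient is therefore 1.8 × (1 − 0.4) = 1.08, which matches the finite-difference estimate.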

First-Order MAML (FOMAML): Drop the Hessian term:

$$\nabla_\theta \mathcal{L}_{out}(\theta') \approx \nabla_{\theta'} \mathcal{L}_{out}(\theta')$$

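This approximation removes the extra backward pass through the inner update (the original MAML paper reports roughly a 33% speed-up) and saves memory, yet it loses surprisingly little accuracy. Finn et al. (2017) suggest why: ReLU networks are locally almost linear, so second derivatives are often close to zero.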
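To see what dropping the Hessian term changes, the sketch below meta-trains a single scalar parameter on two hypothetical quadratic tasks, each with $\mathcal{L}_{in} = \mathcal{L}_{out} = \tfrac{a_i}{2}(\theta - b_i)^2$. It runs once with the exact meta-gradient and once with the FOMAML approximation. Everything here (the tasks, the learning rate, and `meta_train`) is illustrative.

```python
def meta_train(tasks, alpha, first_order, lr=0.1, steps=500, theta=0.0):
    """Gradient descent on sum_i L_out_i(theta - alpha * grad L_in_i(theta)).

    Each hypothetical task i is (a_i, b_i) with L_in_i = L_out_i = 0.5 * a_i * (theta - b_i)**2,
    so the Hessian of L_in_i is the constant a_i.
    """
    for _ in range(steps):
        grad = 0.0
        for a, b in tasks:
            theta_prime = theta - alpha * a * (theta - b)  # inner step
            g = a * (theta_prime - b)                       # grad of L_out at theta'
            if not first_order:
                g *= 1.0 - alpha * a                        # the (I - alpha * Hessian) factor
            grad += g
        theta -= lr * grad
    return theta


if __name__ == "__main__":
    tasks = [(1.0, 0.0), (4.0, 1.0)]
    print(meta_train(tasks, alpha=0.1, first_order=False))  # exact MAML
    print(meta_train(tasks, alpha=0.1, first_order=True))   # FOMAML
```

With tasks $(a, b) = (1, 0)$ and $(4, 1)$ and α = 0.1, exact MAML converges to θ = 0.64 and FOMAML to θ ≈ 0.727. The exact gradient weights task i by $a_i(1 - \alpha a_i)^2$ and the first-order one by $a_i(1 - \alpha a_i)$, so the two agree when $\alpha a_i$ is small. Here the curvature is deliberately large ($\alpha a = 0.4$ for the second task) to make the gap visible.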

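iMAML (Rajeswaran et al., 2019): Rather than backpropagating through explicit inner gradient steps, add a proximal regularizer $\tfrac{\lambda}{2}\|\phi - \theta\|^2$ to the inner problem and solve it approximately for the task parameters $\phi^*$ with any optimizer. Then use the implicit function theorem to compute the outer gradient:

$$\nabla_\theta \mathcal{L}_{out}(\phi^*) = \Big(I + \tfrac{1}{\lambda}\nabla^2_\phi \mathcal{L}_{in}(\phi^*)\Big)^{-1} \nabla_\phi \mathcal{L}_{out}(\phi^*)$$

The linear system is solved approximately with conjugate gradient, which needs only Hessian-vector products, and memory no longer grows with the number of inner steps. This is more accurate than FOMAML but more complex to implement.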
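Here is a small sketch of the iMAML outer gradient on a hypothetical two-parameter quadratic task. The adapted task parameters are written $\phi$, and the inner problem is $\phi^* = \arg\min_\phi \mathcal{L}_{in}(\phi) + \tfrac{\lambda}{2}\|\phi - \theta\|^2$, whose meta-gradient is $(I + \tfrac{1}{\lambda}\nabla^2 \mathcal{L}_{in}(\phi^*))^{-1}\nabla \mathcal{L}_{out}(\phi^*)$. Gradient descent approximately solves the inner problem, and conjugate gradient solves the linear system using only matrix-vector products. The explicit matrix `A`, the quadratic losses, and all function names are assumptions for illustration; a real implementation would obtain Hessian-vector products from an autodiff framework.

```python
def matvec(m, v):
    """Dense matrix-vector product for lists of lists."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def conjugate_gradient(apply_a, rhs, iters=50, tol=1e-20):
    """Solve apply_a(x) = rhs for a symmetric positive-definite operator, using only matvecs."""
    x = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - apply_a(x), with x = 0 initially
    p = list(r)    # search direction
    rs = dot(r, r)
    for _ in range(iters):
        if rs < tol:
            break
        ap = apply_a(p)
        step = rs / dot(p, ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


def imaml_meta_grad(theta, A, b, d, lam, inner_steps=2000, inner_lr=0.05):
    """Implicit meta-gradient for one hypothetical quadratic task.

    Inner: phi* = argmin_phi 0.5 (phi - b)^T A (phi - b) + lam/2 * ||phi - theta||^2
    Outer: L_out(phi) = 0.5 * ||phi - d||^2
    """
    # 1) Approximately solve the regularized inner problem (any optimizer would do).
    phi = list(theta)
    for _ in range(inner_steps):
        diff = [p - bi for p, bi in zip(phi, b)]
        grad = [h + lam * (p - t) for h, p, t in zip(matvec(A, diff), phi, theta)]
        phi = [p - inner_lr * g for p, g in zip(phi, grad)]

    # 2) Gradient of the outer loss at the inner solution.
    g_out = [p - di for p, di in zip(phi, d)]

    # 3) Solve (I + A / lam) x = g_out with CG. Only Hessian-vector products
    #    (here A @ v) are needed, and no inner step is backpropagated through.
    def apply_system(v):
        return [vi + hv / lam for vi, hv in zip(v, matvec(A, v))]

    return conjugate_gradient(apply_system, g_out)
```

Because the outer gradient comes from the implicit function theorem, nothing is backpropagated through the inner loop. Memory therefore does not grow with `inner_steps`; only the final $\phi^*$ matters.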

Induction Heads and In-Context Learning Mechanism

Olsson et al. (2022) “In-context Learning and Induction Heads” identified the specific mechanism enabling ICL in transformers.

Induction heads: Attention heads that implement the function: given $[A, B, …, A]$, predict $B$. They work by:

  1. Previous-token head: copies information about the previous token into the current token's position
  2. Induction head: uses that information to attend to the position just after an earlier occurrence of the current token, and copies the token found there
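The [A, B, …, A] → B behavior can be written down as a plain lookup rule. This is a behavioral caricature only, not a transformer (the function `induction_predict` is hypothetical). It matches tokens exactly, whereas real induction heads implement the rule with attention and may also match similar rather than identical tokens.

```python
def induction_predict(tokens):
    """Rule-based caricature of an induction head: [A, B, ..., A] -> B.

    Scan backwards for the most recent earlier occurrence of the current (last)
    token and return the token that followed it; None if there is no match.
    """
    current = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]  # the token that came after the earlier occurrence
    return None
```

For example, `induction_predict(["cat", "chat", "dog", "chien", "cat"])` returns `"chat"`: the last token `cat` appeared earlier, followed by `chat`.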

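This suggests how the model can complete few-shot patterns: given [(in1, out1), (in2, out2), (in3, ?)], an induction-like mechanism can match in3 against earlier inputs and copy the output that followed the closest match. Because in3 rarely repeats an earlier input exactly, this relies on fuzzy matching of similar rather than identical tokens, which Olsson et al. argue induction heads can also perform.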

Induction heads form during a phase change early in training (after a few billion training tokens in the models Olsson et al. studied), visible as a distinct bump in the training loss curve. Before this point, ICL ability is limited; afterwards, it improves sharply.

The scale dependence: Induction heads plausibly explain basic ICL (pattern completion). More sophisticated ICL (multi-step reasoning, novel operations) appears to require additional mechanisms that emerge mainly at larger scale, consistent with the progressive emergence of capabilities documented in GPT-3 and later models.

Chain-of-Thought as Few-Shot Learning

Wei et al. (2022) “Chain-of-Thought Prompting Elicits Reasoning in Large Language Models” showed that providing reasoning examples in few-shot prompts dramatically improves performance on complex tasks.

Standard few-shot:

Q: Roger has 5 tennis balls. He buys 2 more cans of 3 balls each. How many does he have?
A: 11
Q: The cafeteria had 23 apples. 20 were used for lunch, 6 more were bought. How many remain?
A:

Chain-of-thought few-shot:

Q: Roger has 5 tennis balls. He buys 2 more cans of 3 balls each. How many does he have?
A: Roger starts with 5 balls. 2 cans × 3 balls = 6 balls. Total: 5 + 6 = 11. The answer is 11.
Q: The cafeteria had 23 apples. 20 were used for lunch, 6 more were bought. How many remain?
A:
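Programmatically, the two prompt styles differ only in whether each exemplar's answer is preceded by its worked reasoning. Here is a minimal sketch; the `build_prompt` function and the `q`/`rationale`/`answer` field names are hypothetical.

```python
def build_prompt(examples, question, chain_of_thought):
    """Format few-shot exemplars as Q/A pairs, then append the new question.

    examples: list of dicts with hypothetical keys "q", "rationale", "answer".
    """
    lines = []
    for ex in examples:
        lines.append(f"Q: {ex['q']}")
        if chain_of_thought:
            # CoT exemplar: reasoning first, then the answer in a fixed phrasing.
            lines.append(f"A: {ex['rationale']} The answer is {ex['answer']}.")
        else:
            lines.append(f"A: {ex['answer']}")
    lines.append(f"Q: {question}")
    lines.append("A:")  # the model continues from here
    return "\n".join(lines)


EXEMPLARS = [{
    "q": "Roger has 5 tennis balls. He buys 2 more cans of 3 balls each. How many does he have?",
    "rationale": "Roger starts with 5 balls. 2 cans × 3 balls = 6 balls. Total: 5 + 6 = 11.",
    "answer": 11,
}]
QUESTION = "The cafeteria had 23 apples. 20 were used for lunch, 6 more were bought. How many remain?"
```

`build_prompt(EXEMPLARS, QUESTION, chain_of_thought=True)` reproduces the chain-of-thought prompt, and `chain_of_thought=False` reproduces the standard one. Ending every rationale with the same "The answer is N." phrasing also makes the final answer easy to locate in the model's output.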

Accuracy on math word problems (GSM8K) in Wei et al.: standard few-shot ~18%, chain-of-thought few-shot ~57% (PaLM 540B). The model is learning a reasoning process from the examples, not just an answer format.

Zero-shot CoT: “Let’s think step by step” appended to a prompt elicits chain-of-thought reasoning without examples. This suggests the reasoning capability exists in the pretrained model; the prompt activates it.
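Here is a sketch of zero-shot CoT in code. The trigger phrase is placed after the answer cue, as in the usual `Q: … A: Let's think step by step.` format. The number-extraction heuristic in `extract_final_number` is a hypothetical post-processing step (Kojima et al. instead use a second prompt to extract the answer), and calling an actual model is left out.

```python
import re


def zero_shot_cot_prompt(question):
    """No exemplars: just the question, the answer cue, and the trigger phrase."""
    return f"Q: {question}\nA: Let's think step by step."


def extract_final_number(completion):
    """Heuristic parser: return the last integer or decimal in the text, else None."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    return numbers[-1] if numbers else None
```

For a completion such as "23 - 20 = 3. 3 + 6 = 9. The answer is 9.", the parser returns `"9"`. Taking the last number works because step-by-step reasoning usually ends with the final result.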

A common explanation: CoT examples shift the model's output distribution toward sequences that contain intermediate computation steps. Pretraining on mathematical and logical text has taught the model to generate step-by-step solutions, and CoT few-shot examples activate this behavior.

Scaling Laws for Few-Shot Learning

How does ICL performance scale with model size and context length?

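Model size: ICL accuracy on MMLU and similar benchmarks tends to increase roughly log-linearly with parameter count, so each doubling buys a similar increment. As a rough guide, doubling model size gives ~5–10% relative improvement on hard tasks, though the gain varies widely by task.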

Number of examples: ICL accuracy increases with K (here, the number of in-context examples) up to the context length limit, but with diminishing returns. Going from 0-shot to 5-shot often gives 15–30% improvement; going from 5-shot to 20-shot gives 5–10% more.

Context length: For retrieval-augmented few-shot, longer contexts allow more examples. Models with context windows of 128k tokens or more (GPT-4 Turbo at 128k, Claude 3 at 200k) can accommodate hundreds of few-shot examples, approaching the performance of fine-tuned models on some tasks.
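When the context window is the limiting resource, assembling a many-shot prompt is a packing problem. The sketch below greedily adds examples in priority order (for instance, already ranked by retrieval similarity) until the next one would exceed the budget. Whitespace splitting is a crude, assumed stand-in for the model's real tokenizer, and `pack_examples` is a hypothetical name.

```python
def approx_tokens(text):
    """Crude token estimate (whitespace split); a real system would use the model's tokenizer."""
    return len(text.split())


def pack_examples(examples, query, max_tokens, count=approx_tokens):
    """Greedily add formatted examples, in order, until the next one would overflow the budget.

    `examples` are (input, output) pairs, assumed pre-sorted by priority.
    Returns the prompt and how many examples fit. The query always goes last.
    """
    tail = f"Input: {query}\nOutput:"
    budget = max_tokens - count(tail)  # reserve room for the query itself
    blocks = []
    for x, y in examples:
        block = f"Input: {x}\nOutput: {y}\n"
        cost = count(block)
        if cost > budget:
            break  # stop rather than skip, so higher-priority examples always come first
        blocks.append(block)
        budget -= cost
    return "".join(blocks) + tail, len(blocks)
```

Stopping at the first example that does not fit preserves the ranking. Skipping it and trying shorter ones instead would fit more examples at the cost of including lower-priority content.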

One thing to remember: Few-shot learning is converging on two complementary approaches — powerful pretrained models with in-context learning for flexibility, and lightweight parameter-efficient fine-tuning (LoRA) when task-specific performance is critical — and the line between them is increasingly blurred.


See Also

  • Contrastive Learning: How AI learns which things are alike (and which are not) without any labels, creating the representations behind image search and face recognition.
  • Data Augmentation: How AI systems make do with less data by creating variations of what they have, the training trick that prevented ImageNet models from memorizing training examples.
  • LoRA Fine-Tuning: How AI companies adapt massive models to specific tasks by training only a tiny fraction of the parameters, the technique making custom AI affordable.
  • Reinforcement Learning Fundamentals: How AI learns from trial, error, and rewards, the technique that beat the world chess champion, solved protein folding, and is now teaching robots to walk.
  • Self-Supervised Learning: How AI learned to teach itself from unlabeled data, the technique that let GPT and BERT learn from the entire internet without any human labeling.