Contrastive Learning — Deep Dive
InfoNCE: Mutual Information Maximization
The NT-Xent loss used in SimCLR is a variant of InfoNCE (Noise Contrastive Estimation for mutual information), introduced in CPC (Contrastive Predictive Coding, van den Oord et al., 2018).
InfoNCE provides a lower bound on mutual information:
$$I(X; Y) \geq \mathbb{E}\left[\log \frac{e^{f(x, y)}}{\frac{1}{N}\sum_{j=1}^N e^{f(x, y_j)}}\right]$$
Where $f(x, y)$ is a scoring function (e.g., dot product of normalized representations), and $\{y_j\}$ are $N-1$ negatives plus 1 positive.
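The bound is capped at $\log N$: the ratio inside the expectation can never exceed $N$, so InfoNCE cannot certify more than $\log N$ nats of mutual information, no matter how good $f$ is. The implication: more negatives raise that ceiling and generally tighten the bound, which tends to go along with better representation quality. This is the usual theoretical justification for why SimCLR and MoCo benefit from large numbers of negatives.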
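To make the role of $N$ concrete, here is a minimal sketch that evaluates one sample of the quantity inside the expectation. The function name `infonce_term` and the scores are hypothetical, chosen only for illustration. Even when the positive is scored far above every negative, the value stays at or below $\log N$.

```python
import math

def infonce_term(pos_score, neg_scores):
    """One sample of log( e^{f(x,y)} / ((1/N) * sum_j e^{f(x,y_j)}) ).

    pos_score:  score f(x, y) of the positive pair
    neg_scores: scores f(x, y_j) of the N-1 negatives
    """
    scores = [pos_score] + list(neg_scores)
    n = len(scores)
    m = max(scores)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return pos_score - (log_sum - math.log(n))

# Hypothetical scores: a near-perfect critic that separates the positive cleanly.
for n in (2, 16, 256):
    value = infonce_term(50.0, [0.0] * (n - 1))
    print(f"N={n:4d}  term={value:.4f}  log N={math.log(n):.4f}")
```

With a small $N$ the estimate saturates quickly, which is one way to read the preference for large batches or memory queues.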
The augmentation invariance interpretation: From an information-theoretic view, contrastive SSL learns representations $z = f(x)$ that maximize a lower bound on the mutual information between the representations of two augmented views of the same image, which pushes those views toward similar representations. The augmentation policy thereby defines which information is “irrelevant” (anything the augmentations change, to which the representation becomes invariant) and which is “useful” (whatever is shared across augmented views of the same image).
Uniformity-Alignment Analysis
Wang & Isola (2020) “Understanding Contrastive Representation Learning through Alignment and Uniformity” decomposed contrastive loss into two geometric properties:
Alignment: Normalized representations of positive pairs should be close: $$\mathcal{L}_{align} = \mathbb{E}_{(x,x^+) \sim p_{pos}}\left[\|f(x) - f(x^+)\|_2^2\right]$$
Uniformity: The overall distribution of representations on the unit hypersphere should be approximately uniform (maximum entropy distribution): $$\mathcal{L}_{uniform} = \log \mathbb{E}_{x, y \sim p_{data}}\left[e^{-2\|f(x) - f(y)\|_2^2}\right]$$
(Gaussian potential, measuring average pairwise similarity — want this to be small)
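NT-Xent approximately minimizes alignment + uniformity simultaneously (Wang & Isola show this holds asymptotically as the number of negatives grows). This framework explains why contrastive learning produces useful representations: uniformity ensures the representation space isn’t wasted (embeddings spread over the hypersphere instead of crowding unrelated concepts into one region), and alignment ensures invariance to the variation between positive views of the same image.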
Empirically: as training progresses, alignment improves rapidly first, then uniformity improves. Both metrics correlate with downstream linear probe performance.
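Both metrics are cheap to track during training. The following is a minimal PyTorch sketch of the two quantities defined above, assuming the embeddings are already L2-normalized. The function names and the defaults `alpha=2` and `t=2` are illustrative choices that match the exponents in the formulas.

```python
import torch
import torch.nn.functional as F

def align_loss(x, y, alpha=2):
    # x, y: (N, D) L2-normalized embeddings; row i of x and row i of y form a positive pair
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()

def uniform_loss(x, t=2):
    # x: (N, D) L2-normalized embeddings
    # torch.pdist returns the distances for all pairs i < j
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

# Illustrative usage with random embeddings (not real model outputs)
torch.manual_seed(0)
z1 = F.normalize(torch.randn(128, 32), dim=1)
z2 = F.normalize(z1 + 0.1 * torch.randn(128, 32), dim=1)
print(align_loss(z1, z2).item(), uniform_loss(z1).item())
```

Lower is better for both: a fully collapsed encoder would score 0 on uniformity (its maximum), while a well-spread one is strongly negative.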
Collapse Analysis: When Contrastive Learning Fails
Dimensional collapse (partial collapse): Representations use only a subset of the available dimensions — concentrating variance in a few directions while most dimensions are unused. This wastes representational capacity.
Hua et al. (2021) showed that even SimCLR can suffer dimensional collapse at small batch sizes or with insufficient augmentation. Symptoms: low effective rank of the representation matrix.
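One way to check for the low-effective-rank symptom is to look at the singular value spectrum of a batch of embeddings. The sketch below uses one common definition of effective rank, the exponential of the entropy of the normalized singular values. The function name and the choice of this particular definition are assumptions, not something the cited work prescribes.

```python
import torch

def effective_rank(z, eps=1e-12):
    """exp(entropy of normalized singular values) of the centered (N, D) embedding matrix."""
    z = z - z.mean(dim=0, keepdim=True)   # center each dimension
    s = torch.linalg.svdvals(z)           # singular values, shape (min(N, D),)
    p = s / (s.sum() + eps)               # normalize to a probability vector
    entropy = -(p * torch.log(p + eps)).sum()
    return torch.exp(entropy).item()

# Illustrative comparison on synthetic data
torch.manual_seed(0)
healthy = torch.randn(512, 32)                        # variance in all 32 dimensions
collapsed = torch.randn(512, 2) @ torch.randn(2, 32)  # variance confined to a 2-D subspace
print(effective_rank(healthy), effective_rank(collapsed))
```

A value far below the embedding dimension suggests that most directions carry almost no variance.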
Solutions:
- VICReg’s covariance regularization directly penalizes collapse by minimizing off-diagonal covariance
- Whitening: periodically whiten the representation matrix to redistribute variance
- Larger projection head dimensionality
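As an illustration of the first item, here is a minimal sketch of a VICReg-style covariance penalty: the sum of squared off-diagonal entries of the batch covariance matrix, divided by the embedding dimension. The function name is hypothetical, and the full VICReg objective also includes variance and invariance terms that are omitted here.

```python
import torch

def covariance_penalty(z):
    """Sum of squared off-diagonal covariance entries of z (N, D), divided by D."""
    n, d = z.shape
    z = z - z.mean(dim=0)                 # center each dimension over the batch
    cov = (z.T @ z) / (n - 1)             # (D, D) sample covariance
    off_diag = cov - torch.diag(torch.diagonal(cov))
    return off_diag.pow(2).sum() / d

# Illustrative comparison on synthetic data
torch.manual_seed(0)
decorrelated = torch.randn(4096, 8)             # independent dimensions
redundant = torch.randn(4096, 1).repeat(1, 8)   # every dimension is a copy of the first
print(covariance_penalty(decorrelated).item(), covariance_penalty(redundant).item())
```

Redundant dimensions produce large off-diagonal covariance, so the penalty pushes the encoder to spread information across dimensions.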
Mode collapse in BYOL/momentum methods: Can occur if the momentum encoder lags too far behind — the target representations become stale, and the predictor learns a trivial mapping that happens to match the stale targets. Manifests as all representations converging to a small region. Prevention: keep momentum $m$ below 0.999 during early training.
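A common way to keep the momentum moderate early on is a cosine schedule that starts at a base value and rises toward 1 by the end of training. The sketch below pairs such a schedule with the exponential moving average update of the target network. The names `byol_momentum` and `ema_update` are hypothetical, and `base_m=0.996` is an assumed starting value (the commonly cited BYOL setting).

```python
import math
import torch

def byol_momentum(step, total_steps, base_m=0.996):
    """Cosine schedule: equals base_m at step 0 and reaches 1.0 at total_steps."""
    return 1.0 - (1.0 - base_m) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0

@torch.no_grad()
def ema_update(online, target, m):
    """target <- m * target + (1 - m) * online, parameter by parameter."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(m).add_(p_online, alpha=1.0 - m)

# Illustrative usage with toy modules standing in for the two encoders
online = torch.nn.Linear(4, 4)
target = torch.nn.Linear(4, 4)
target.load_state_dict(online.state_dict())   # target starts as a copy of online
for step in range(3):
    # ... an optimizer step on `online` would go here ...
    ema_update(online, target, byol_momentum(step, total_steps=1000))
```

The target is never updated by backpropagation in this setup; it only tracks the online network through the moving average.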
Supervised Contrastive Learning
Khosla et al. (2020) “Supervised Contrastive Learning” extended contrastive SSL to the supervised setting. With class labels available, all examples of the same class are positive pairs (not just augmentations):
$$\mathcal{L}_{SupCon} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{e^{z_i \cdot z_p / \tau}}{\sum_{a \in A(i)} e^{z_i \cdot z_a / \tau}}$$
Where $P(i)$ is the set of positives for anchor $i$ (same-class examples in the batch) and $A(i)$ is all other examples.
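The formula above can be sketched in PyTorch as follows. This is a minimal single-device version, not the reference implementation: the function name, the default `temperature=0.1`, and the decision to skip anchors that have no positive in the batch are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def supcon_loss(z, labels, temperature=0.1):
    """z: (B, D) embeddings, labels: (B,) integer class ids. Returns the sum over anchors."""
    z = F.normalize(z, dim=1)
    b = z.shape[0]
    self_mask = torch.eye(b, dtype=torch.bool, device=z.device)

    # A(i): every example in the batch except the anchor itself
    sim = (z @ z.T / temperature).masked_fill(self_mask, float('-inf'))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    # P(i): examples sharing the anchor's label, excluding the anchor
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    n_pos = pos_mask.sum(dim=1)

    # Sum log-probabilities over positives only (torch.where avoids touching the -inf diagonal)
    pos_log_prob = torch.where(pos_mask, log_prob, torch.zeros_like(log_prob)).sum(dim=1)

    valid = n_pos > 0                      # anchors with at least one positive
    return (-pos_log_prob[valid] / n_pos[valid]).sum()

# Illustrative usage with random embeddings and made-up labels
torch.manual_seed(0)
z = torch.randn(8, 16)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3])
print(supcon_loss(z, labels).item())
```

Dividing the result by the number of anchors is a common variation that keeps the loss scale independent of batch size. When every class has exactly two examples, the loss reduces to the NT-Xent form.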
Supervised contrastive learning outperforms cross-entropy training in several settings:
- More robust to label noise (cluster structure doesn’t depend on a single label)
- Better transferability (representations are organized by semantic similarity, not arbitrary class indices)
- More consistent improvement from larger batches
Performance on ImageNet: SupCon ResNet-50: 78.7% vs. cross-entropy: 77.6% top-1 (a 1.1 percentage point improvement). That is a meaningful gap on a well-saturated benchmark.
Connection to Supervised Learning: The Contrastive-Softmax Duality
Standard cross-entropy classification can be rewritten in a contrastive form. The softmax loss for class $c$:
$$\mathcal{L}_{CE} = -\log \frac{e^{w_c \cdot f(x)}}{\sum_{c'} e^{w_{c'} \cdot f(x)}}$$
This has the same form as InfoNCE, with the class weight vectors $w_c$ acting as keys: the “positive” is the prototype of the correct class and the “negatives” are the prototypes of the wrong classes. Unlike the in-batch keys of SimCLR, these prototypes persist across batches and are learned directly as parameters.
This equivalence shows that supervised classification is implicitly doing contrastive learning — learning to align example representations with their class prototype while pushing away from other class prototypes. The distinction between “supervised” and “contrastive” learning is less fundamental than it appears.
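The prototype view can be checked numerically. The sketch below writes the loss in contrastive form (positive score against a log-sum-exp over all keys) and compares it with PyTorch's cross-entropy on the same logits. The function name and the bias-free linear classifier are assumptions made to keep the example small.

```python
import torch
import torch.nn.functional as F

def prototype_contrastive_loss(features, prototypes, targets):
    """InfoNCE-style loss with class prototypes as keys.

    features:   (B, D) example representations f(x)
    prototypes: (C, D) class weight vectors w_c
    targets:    (B,) integer class ids
    """
    scores = features @ prototypes.T                      # (B, C) score against every prototype
    pos = scores.gather(1, targets[:, None]).squeeze(1)   # score against the own-class prototype
    return (torch.logsumexp(scores, dim=1) - pos).mean()

# Illustrative check on random data
torch.manual_seed(0)
features = torch.randn(6, 5)
prototypes = torch.randn(3, 5)
targets = torch.tensor([0, 1, 2, 0, 1, 2])
contrastive = prototype_contrastive_loss(features, prototypes, targets)
softmax_ce = F.cross_entropy(features @ prototypes.T, targets)
print(contrastive.item(), softmax_ce.item())
```

The two values agree up to floating-point error, since both compute the same negative log-softmax.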
Practical Implementation Notes
Effective contrastive learning requires careful implementation:
```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1, z2, temperature=0.07):
    # z1, z2: (batch_size, embed_dim), L2-normalized
    batch_size = z1.shape[0]

    # Concatenate all embeddings
    z = torch.cat([z1, z2], dim=0)  # (2N, D)

    # Similarity matrix
    sim = torch.mm(z, z.T) / temperature  # (2N, 2N)

    # Positive pairs: (i, i+N) and (i+N, i)
    labels = torch.arange(batch_size, device=z.device)
    labels = torch.cat([labels + batch_size, labels])

    # Mask self-similarity so an embedding is never its own candidate
    mask = torch.eye(2 * batch_size, device=z.device).bool()
    sim.masked_fill_(mask, float('-inf'))

    return F.cross_entropy(sim, labels)
```
Key implementation details:
- L2-normalize embeddings before computing similarity (cosine similarity)
- Gather negatives from all GPUs in distributed training (not just the local batch)
- Use stop_gradient correctly in BYOL-style methods: gradient should flow only through the online network
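For the stop-gradient point in the list above, here is a minimal PyTorch sketch of a BYOL-style regression loss in which the target branch is detached. The function name and argument names are hypothetical. In PyTorch the stop-gradient is expressed with `.detach()`.

```python
import torch
import torch.nn.functional as F

def byol_style_loss(online_pred, target_proj):
    """2 - 2 * cosine similarity between online predictions and detached targets."""
    p = F.normalize(online_pred, dim=1)
    t = F.normalize(target_proj.detach(), dim=1)   # stop-gradient on the target branch
    return (2 - 2 * (p * t).sum(dim=1)).mean()

# Illustrative check that no gradient reaches the target branch
torch.manual_seed(0)
online_pred = torch.randn(4, 8, requires_grad=True)
target_proj = torch.randn(4, 8, requires_grad=True)
byol_style_loss(online_pred, target_proj).backward()
print(online_pred.grad is not None, target_proj.grad is None)
```

If the `.detach()` call is dropped, both branches receive gradients and the target no longer acts as a slowly moving reference.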
One thing to remember: Contrastive learning’s deep connection to mutual information maximization explains why it works: it learns representations that are maximally informative about the data-generating factors while being invariant to augmentation-defined irrelevant variation — and this principle generalizes naturally across any domain where you can define meaningful positive pairs.