Graph Neural Networks — Deep Dive
Expressiveness Limits: The Weisfeiler-Leman Test
A fundamental question about GNNs: when can two non-isomorphic graphs produce the same representation? If they can, the GNN can’t distinguish them.
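The Weisfeiler-Leman (WL) graph isomorphism test is a classical heuristic for testing graph isomorphism. It can prove that two graphs are not isomorphic, but it cannot prove that they are. It works by color refinement: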
- Assign identical colors to all nodes
- Iteratively update each node’s color based on its own color and a multiset of neighbor colors
- Two graphs are “WL-distinct” if the final color histograms differ
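The three steps above can be sketched in a few lines of plain Python. This is a minimal illustration, not a reference implementation: the function name `wl_histogram`, the helper `ring`, and the dict-of-lists graph format are assumptions made for this example, and colors are kept as nested tuples (which grow quickly) instead of being hashed or relabeled as a practical implementation would do.

```python
from collections import Counter

def wl_histogram(adj, rounds=3):
    """Run 1-WL color refinement and return the final color histogram.

    adj: dict mapping each node to a list of its neighbors (undirected graph).
    """
    colors = {v: () for v in adj}  # step 1: every node starts with the same color
    for _ in range(rounds):
        # step 2: new color = (own color, sorted multiset of neighbor colors)
        colors = {
            v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
            for v in adj
        }
    # step 3: graphs are compared through their color histograms
    return Counter(colors.values())

def ring(n, offset=0):
    """Cycle on n nodes labeled offset .. offset + n - 1."""
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

hexagon = ring(6)
two_triangles = {**ring(3), **ring(3, offset=3)}
path = {i: [j for j in (i - 1, i + 1) if 0 <= j < 6] for i in range(6)}

# Non-isomorphic, yet every node has degree 2, so WL cannot separate them.
same = wl_histogram(hexagon) == wl_histogram(two_triangles)
# The path has two degree-1 endpoints, so the histograms differ.
differ = wl_histogram(hexagon) != wl_histogram(path)
```

The hexagon and the pair of triangles are a standard example of non-isomorphic graphs with identical WL histograms, which is exactly the kind of pair a message-passing GNN cannot tell apart.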
Xu et al. (2019) “How Powerful are Graph Neural Networks?” proved:
Theorem: Any message-passing GNN is at most as powerful as the WL test in distinguishing graph structures. Two non-isomorphic graphs that WL cannot distinguish will produce the same representation in any message-passing GNN.
Furthermore, a specific GNN achieves maximum power (equals WL): the Graph Isomorphism Network (GIN):
$$h_v^{(k)} = \text{MLP}^{(k)}\left((1 + \epsilon^{(k)}) h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\right)$$
GIN’s sum aggregation (not mean or max) preserves the multiset structure, making it maximally powerful within the message-passing framework.
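To see why the choice of aggregator matters, the sketch below compares sum, mean, and max on two different neighbor multisets with scalar features. The names `aggregate` and `gin_update` are hypothetical, and the MLP is replaced by an identity function purely for illustration.

```python
def aggregate(neighbor_feats, how):
    """Aggregate a multiset of scalar neighbor features."""
    if how == "sum":
        return sum(neighbor_feats)
    if how == "mean":
        return sum(neighbor_feats) / len(neighbor_feats)
    if how == "max":
        return max(neighbor_feats)
    raise ValueError(f"unknown aggregator: {how}")

def gin_update(h_v, neighbor_feats, eps=0.0, mlp=lambda x: x):
    """GIN-style update: MLP((1 + eps) * h_v + sum of neighbor features).

    The identity default for `mlp` is a placeholder assumption.
    """
    return mlp((1.0 + eps) * h_v + sum(neighbor_feats))

a = [1.0, 2.0]            # two neighbors
b = [1.0, 1.0, 2.0, 2.0]  # four neighbors, same proportions

mean_collides = aggregate(a, "mean") == aggregate(b, "mean")  # both 1.5
max_collides = aggregate(a, "max") == aggregate(b, "max")     # both 2.0
sum_separates = aggregate(a, "sum") != aggregate(b, "sum")    # 3.0 vs 6.0
```

Mean and max collapse the two multisets to the same value, while sum keeps them apart, so `gin_update(1.0, a)` and `gin_update(1.0, b)` differ as well.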
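Beyond WL: Higher-order WL tests (k-WL) color $k$-tuples of nodes instead of single nodes. GNNs modeled on them can distinguish more graph structures, but they must store a state for each of the $n^k$ tuples, so cost grows at least as $O(n^k)$. 3-WL can distinguish almost all practical molecular graphs. Numbering conventions differ between papers: under the "folklore" convention, the 2-dimensional test is tractable and significantly more expressive than 1-WL (standard MPNN), whereas under the other common convention 2-WL is only as expressive as 1-WL and the first gain appears at 3-WL.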
Over-Smoothing and Over-Squashing
Over-smoothing: With deep GNNs (many message passing layers), all node representations converge to similar values. The node’s representation becomes dominated by the global graph signal rather than local structure.
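Intuitively: with $k$ layers, each node aggregates from its $k$-hop neighborhood. Once $k$ exceeds the graph's diameter, every node sees the entire graph. In the simplified setting where this is usually analyzed (linear propagation with a normalized adjacency matrix on a connected graph), repeated aggregation behaves like a random walk: the representations converge exponentially fast to a limit determined by the walk's stationary distribution, which carries almost no information about local structure.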
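The collapse is easy to reproduce on a toy graph. The sketch below applies plain mean aggregation with self-loops (no learned weights, no nonlinearity, which is a simplifying assumption) to a 5-node path and tracks how far apart the node features remain. The names `mean_propagate` and `spread` are made up for this example.

```python
def mean_propagate(adj, h, layers):
    """Apply `layers` rounds of mean aggregation over each node and its neighbors."""
    for _ in range(layers):
        h = {
            v: (h[v] + sum(h[u] for u in adj[v])) / (1 + len(adj[v]))
            for v in adj
        }
    return h

def spread(h):
    """Gap between the largest and smallest node feature."""
    return max(h.values()) - min(h.values())

path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
h0 = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0}  # only node 0 carries signal

# Spread after 0, 2, 10, and 50 layers: it shrinks toward zero.
spreads = [spread(mean_propagate(path, h0, k)) for k in (0, 2, 10, 50)]
```

After a few dozen layers every node holds nearly the same value, so a downstream classifier can no longer tell the nodes apart.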
Mitigation:
- Residual/skip connections across GNN layers (similar to ResNet)
- Normalization per node (LayerNorm)
- DropEdge: randomly remove edges during training to prevent over-smoothing
- GRAND (Graph Random Neural Networks): uses random propagation augmentation
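Of these mitigations, DropEdge is the simplest to sketch: each training epoch sees a randomly thinned copy of the edge list. The helper names `drop_edge` and `to_adjacency` and the edge-list format are assumptions for illustration; the published method also re-normalizes the adjacency matrix after dropping, which is omitted here.

```python
import random

def drop_edge(edges, p, rng):
    """Keep each undirected edge independently with probability 1 - p."""
    return [edge for edge in edges if rng.random() >= p]

def to_adjacency(num_nodes, edges):
    """Build a symmetric adjacency dict from an undirected edge list."""
    adj = {v: [] for v in range(num_nodes)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    return adj

edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)]
rng = random.Random(0)  # fixed seed so the sketch is reproducible

# A fresh random subgraph would be drawn like this at every training epoch.
epoch_edges = drop_edge(edges, p=0.4, rng=rng)
epoch_adj = to_adjacency(4, epoch_edges)
```

At evaluation time the full edge list is used, much like dropout is disabled at inference.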
Over-squashing: Long-range information must be squeezed through narrow graph bottlenecks — exponentially more information from farther neighborhoods must be compressed into the same-size representation.
The aggregation at a node after $k$ steps contains information from $O(\text{degree}^k)$ nodes compressed into a fixed-size vector. For a graph with maximum degree $d$ and $k=5$ layers: up to $d^5$ nodes’ information in one vector.
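The growth of the receptive field can be counted directly with a breadth-first search. The sketch below uses a complete binary tree as a stand-in for a graph with steady branching; `k_hop_size` and `binary_tree` are hypothetical helper names.

```python
def k_hop_size(adj, source, k):
    """Number of nodes within k hops of `source`, found by breadth-first search."""
    seen = {source}
    frontier = [source]
    for _ in range(k):
        next_frontier = []
        for v in frontier:
            for u in adj[v]:
                if u not in seen:
                    seen.add(u)
                    next_frontier.append(u)
        frontier = next_frontier
    return len(seen)

def binary_tree(depth):
    """Complete binary tree as an adjacency dict; node i has children 2i+1 and 2i+2."""
    n = 2 ** (depth + 1) - 1
    adj = {i: [] for i in range(n)}
    for i in range(n):
        for child in (2 * i + 1, 2 * i + 2):
            if child < n:
                adj[i].append(child)
                adj[child].append(i)
    return adj

tree = binary_tree(depth=6)
# Receptive field of the root after 0..5 message-passing layers.
sizes = [k_hop_size(tree, 0, k) for k in range(6)]  # [1, 3, 7, 15, 31, 63]
```

The receptive field roughly doubles with every layer here, while the vector that has to summarize it stays the same size.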
For tasks requiring long-range dependencies (e.g., protein function depending on spatially distant residues), MPNN is fundamentally limited. Graph transformers address this.
Graph Transformers
Applying attention directly over all pairs of nodes:
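$$z_i = \sum_{j \in \mathcal{V}} \alpha_{ij} W_V h_j, \quad \alpha_{ij} = \text{softmax}_j\left(\frac{(W_Q h_i) \cdot (W_K h_j)}{\sqrt{d_k}}\right)$$
Here $\mathcal{V}$ is the set of nodes and $W_Q$, $W_K$, $W_V$ are the learned query, key, and value projections.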
This gives $O(n^2)$ complexity but allows direct long-range interaction without bottlenecks.
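A single-head version of this attention can be written in plain Python. This is a minimal sketch: `matvec` and `global_attention` are hypothetical names, the projection matrices `Wq`, `Wk`, `Wv` are hand-picked rather than learned, and no edge information is used at all, which is the point of the construction.

```python
import math

def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def global_attention(H, Wq, Wk, Wv):
    """Single-head softmax attention over all node pairs (O(n^2) scores)."""
    Q = [matvec(Wq, h) for h in H]
    K = [matvec(Wk, h) for h in H]
    V = [matvec(Wv, h) for h in H]
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        alpha = [w / total for w in weights]
        out.append([sum(a * v[c] for a, v in zip(alpha, V)) for c in range(len(V[0]))])
    return out

identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three nodes, two features each

Z = global_attention(H, identity, identity, identity)
# With a zero query projection every score is 0, so attention is uniform
# and each output is the mean of the value vectors.
Z_uniform = global_attention(H, zeros, identity, identity)
```

Because the graph's edges never enter the computation, structure has to be injected separately, for example through positional encodings.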
Positional encodings for graphs: Transformers use position encodings to inject structure. For graphs, there’s no canonical node ordering. Options:
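- Laplacian eigenvectors: The eigenvectors belonging to the $k$ smallest non-trivial eigenvalues of the graph Laplacian serve as node positional encodings (each eigenvector is only defined up to sign, which the model has to tolerate)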
- Random walk positional encoding: $\text{PE}_i = [\text{RW}_{ii}^1, \text{RW}_{ii}^2, \ldots, \text{RW}_{ii}^k]$, where $\text{RW}_{ii}^t$ is the probability that a random walk started at node $i$ is back at $i$ after $t$ steps
- Graph distance encodings: Pairwise shortest-path distances between nodes
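The random walk encoding from the list above is straightforward to compute from powers of the transition matrix. The sketch below is a dense, unoptimized illustration; `matmul` and `random_walk_pe` are hypothetical names, and graphs are given as dicts of neighbor lists.

```python
def matmul(A, B):
    """Multiply two square matrices given as lists of rows."""
    n = len(A)
    return [[sum(A[i][m] * B[m][j] for m in range(n)) for j in range(n)] for i in range(n)]

def random_walk_pe(adj, k):
    """Return probabilities after 1..k steps of a uniform random walk, per node."""
    nodes = sorted(adj)
    idx = {v: i for i, v in enumerate(nodes)}
    n = len(nodes)
    # Transition matrix P = D^{-1} A
    P = [[0.0] * n for _ in range(n)]
    for v in nodes:
        for u in adj[v]:
            P[idx[v]][idx[u]] += 1.0 / len(adj[v])
    pe = {v: [] for v in nodes}
    Pt = [row[:] for row in P]  # holds P^t, starting at t = 1
    for _ in range(k):
        for v in nodes:
            pe[v].append(Pt[idx[v]][idx[v]])  # diagonal entry of P^t
        Pt = matmul(Pt, P)
    return pe

triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path3 = {0: [1], 1: [0, 2], 2: [1]}

pe_triangle = random_walk_pe(triangle, k=3)  # every node: [0.0, 0.5, 0.25]
pe_path = random_walk_pe(path3, k=3)         # center: [0.0, 1.0, 0.0]
```

The center of the path and a corner of the triangle both have degree 2, but their return probabilities differ, so the encoding exposes structure that the degree alone does not.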
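GraphGPS (Rampášek et al., 2022): Combines message passing (local, structure-aware) with attention (global, long-range) in each layer. Both branches read the same input features, and their outputs are summed and passed through an MLP (residual connections and normalization omitted here):
$$h_v^{MPNN} = \text{MPNN\_layer}(\{h_v\} \cup \{h_u : u \in \mathcal{N}(v)\})$$
$$h_v^{Attn} = \text{MultiHeadAttn}(h_v, H)$$
$$h_v^{out} = \text{MLP}(h_v^{MPNN} + h_v^{Attn})$$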
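GraphGPS reported competitive or better results than pure MPNN and pure transformer baselines on long-range graph benchmarks, although the margin depends on the task and on how carefully the baselines are tuned.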
Equivariant GNNs for 3D Molecular Structure
Standard GNNs ignore 3D atomic coordinates or treat them as features. For molecular property prediction (energy, forces, reaction rates), 3D geometry is essential.
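The symmetry requirement: if you rotate, reflect, or translate a molecule, scalar properties such as energy don't change (invariance), while vector properties such as forces transform along with the molecule (equivariance). The network must respect the E(3) group of transformations. For a vector-valued output such as forces:
$$f(R x + t) = R \, f(x)$$
for all orthogonal matrices $R \in O(3)$ (rotations and reflections) and translations $t \in \mathbb{R}^3$.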
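The symmetry requirement can be checked numerically on a toy model: a scalar output should be unchanged by a rigid transformation, and a vector output should rotate with the input. The sketch below uses a made-up energy (sum of squared pairwise distances) and its analytic forces; `energy`, `forces`, and `transform` are illustrative names, not part of any library.

```python
import math

def energy(X):
    """Toy invariant scalar: sum of squared pairwise distances."""
    return sum(
        sum((a - b) ** 2 for a, b in zip(X[i], X[j]))
        for i in range(len(X))
        for j in range(i + 1, len(X))
    )

def forces(X):
    """Toy equivariant vectors: negative gradient of `energy` per atom."""
    return [[2.0 * sum(xj[c] - xi[c] for xj in X) for c in range(3)] for xi in X]

def transform(X, R, t):
    """Apply x -> R x + t to every position in X."""
    return [[sum(R[r][c] * x[c] for c in range(3)) + t[r] for r in range(3)] for x in X]

theta = 0.7
c, s = math.cos(theta), math.sin(theta)
# Rotation about z combined with a reflection of z: orthogonal, determinant -1.
R = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, -1.0]]
t = [3.0, -1.0, 2.0]

X = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.5]]
X_moved = transform(X, R, t)

energy_gap = abs(energy(X_moved) - energy(X))
rotated_forces = transform(forces(X), R, [0.0, 0.0, 0.0])
force_gap = max(
    abs(a - b)
    for fa, fb in zip(forces(X_moved), rotated_forces)
    for a, b in zip(fa, fb)
)
```

Both gaps come out at floating-point noise level. An equivariant network is built so that this property holds by construction for every input, instead of having to be learned from augmented data.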
SE(3)-Transformer (Fuchs et al., 2020): Uses spherical harmonics as equivariant feature representations. Node features are tensors of different “types” (scalars, vectors, etc.) that transform predictably under rotation. Attention is computed equivariantly.
NequIP (Batzner et al., 2022): Uses E(3)-equivariant features for interatomic potential learning. It is reported to be far more data-efficient than non-equivariant models, matching their accuracy with up to three orders of magnitude less training data.
MACE (Batatia et al., 2022): Multi-Atomic Cluster Expansion with higher-order equivariant message passing. It became a leading architecture for foundation models in molecular and materials simulation, with MACE-MP-0 (2023) providing a pretrained general-purpose force field for materials chemistry.
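AlphaFold's IPA: Invariant Point Attention computes attention weights from features that are invariant to global rotations and translations (such as distances between points expressed in per-residue local frames), while the surrounding structure module applies equivariant updates to those residue frames. This combination helped AlphaFold predict protein 3D structure with near-experimental accuracy on many targets.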
Scalability: Sampling Strategies for Large Graphs
Full-batch GNN training requires loading the entire graph into GPU memory — infeasible for billion-node graphs.
Neighbor sampling (GraphSAGE, Hamilton et al., 2017): Sample a fixed number of neighbors at each hop. For a 2-layer GNN with 5 neighbors sampled per hop, each target node pulls in at most 5 first-hop and 25 second-hop nodes, so at most 30 sampled nodes per mini-batch item (vs. potentially thousands with full neighborhoods).
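A fan-out sampler of this kind fits in a few lines. The sketch below is an illustration under stated assumptions: `sample_neighborhood` is a hypothetical name, sampling is without replacement and capped at the node's degree, and duplicates across branches are not merged.

```python
import random

def sample_neighborhood(adj, target, fanouts, rng):
    """Sample a multi-hop neighborhood around `target`.

    fanouts: number of neighbors to draw per node at each hop, e.g. [5, 5].
    Returns one list of sampled nodes per hop, starting with [target].
    """
    layers = [[target]]
    for fanout in fanouts:
        next_layer = []
        for v in layers[-1]:
            neighbors = adj[v]
            next_layer.extend(rng.sample(neighbors, min(fanout, len(neighbors))))
        layers.append(next_layer)
    return layers

# Complete graph on 20 nodes: every node has 19 neighbors to choose from.
complete = {i: [j for j in range(20) if j != i] for i in range(20)}
rng = random.Random(0)

layers = sample_neighborhood(complete, target=0, fanouts=[5, 5], rng=rng)
hop_sizes = [len(layer) for layer in layers]  # [1, 5, 25]
```

The per-target cost is bounded by the product of the fan-outs regardless of how large the true neighborhoods are, which is what makes mini-batch training feasible on very large graphs.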
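Layer sampling (FastGCN, Chen et al., 2018): Sample a fixed number of nodes per layer globally (not per node), using importance sampling. This caps the cost of each layer regardless of depth, but nodes sampled in adjacent layers may share few or no edges, which leaves the sampled computation graph sparse.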
Cluster-GCN (Chiang et al., 2019): Partition the graph into clusters (using METIS or similar), then sample mini-batches from one cluster at a time. Within a cluster, full-batch GCN training is feasible and edges are dense. Inter-cluster edges are discarded for mini-batch computation but can be handled separately.
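The mini-batch construction can be sketched as taking the induced subgraph of one cluster. In the example below the partition is written by hand as a stand-in for the output of METIS, and `cluster_subgraph` and `count_cut_edges` are hypothetical helper names.

```python
def cluster_subgraph(adj, cluster):
    """Induced subgraph on `cluster`: edges leaving the cluster are dropped."""
    members = set(cluster)
    return {v: [u for u in adj[v] if u in members] for v in cluster}

def count_cut_edges(adj, clusters):
    """Number of undirected edges whose endpoints lie in different clusters."""
    owner = {v: i for i, cluster in enumerate(clusters) for v in cluster}
    crossing = sum(1 for v in adj for u in adj[v] if owner[v] != owner[u])
    return crossing // 2  # each undirected edge was counted from both ends

# Two triangles joined by a single bridge edge (2, 3).
graph = {
    0: [1, 2], 1: [0, 2], 2: [0, 1, 3],
    3: [2, 4, 5], 4: [3, 5], 5: [3, 4],
}
clusters = [[0, 1, 2], [3, 4, 5]]  # hand-written partition, not computed by METIS

batches = [cluster_subgraph(graph, cluster) for cluster in clusters]
dropped = count_cut_edges(graph, clusters)  # only the bridge is lost
```

A good partition keeps the number of dropped edges small relative to the edges kept inside each batch, which is why a min-cut style partitioner is the natural choice for this step.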
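Sampling at scale: PinSage (Ying et al., 2018), deployed at Pinterest on a graph with about 3 billion nodes, relies on random-walk-based neighbor sampling, not on Cluster-GCN. The clustering approach has a different appeal: each mini-batch is a coherent, densely connected neighborhood, so most edges touching the batch are actually used. Because a single cluster can be unrepresentative of the whole graph, the Cluster-GCN paper combines several randomly chosen clusters per batch to reduce the variance of the gradient estimates.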
One thing to remember: GNNs are fundamentally limited by the WL test in expressiveness, over-squashing in long-range communication, and neighborhood explosion in scalability — understanding these limits guides which GNN variant to choose for a given problem.