Graph Neural Networks with Python — Core Concepts
Why standard neural networks fail on graphs
Convolutional neural networks (CNNs) assume a grid structure: every interior pixel has the same fixed set of neighbors (4 or 8, depending on how adjacency is defined), so one kernel can slide over the whole image. Recurrent networks assume a sequence: each token has exactly one predecessor. Graphs have neither. A node might have 2 neighbors or 2,000, and there’s no fixed ordering, no grid, no sequence.
GNNs solve this by defining operations that work regardless of neighborhood size. The core mechanism is message passing: each node collects information from its neighbors, aggregates it, and updates its own representation.
The message passing framework
Every GNN layer performs three steps:
- Message — Each neighbor prepares a message (usually a transformed version of its feature vector).
- Aggregate — The node collects all messages from neighbors. Because neighbors have no inherent order, aggregation must be permutation-invariant; sum, mean, and max are the common choices.
- Update — The node combines its current features with the aggregated messages to produce a new feature vector.
Stack multiple layers to capture multi-hop information. After k layers, each node’s representation captures the structure of its k-hop neighborhood.
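The three steps can be sketched in plain Python on a four-node toy graph. This is a minimal illustration, not a real GNN layer: the identity message, mean aggregation, and fixed 50/50 update are assumptions chosen for readability. Real layers replace them with learnable transformations.

```python
# Minimal message-passing layer in pure Python (no GNN library).
# Assumptions: undirected graph stored as an adjacency dict, identity
# message function, mean aggregation, and a fixed 50/50 update rule.

def message(h_u):
    # Message step: the neighbor simply sends a copy of its features.
    return list(h_u)

def aggregate(messages):
    # Aggregate step: element-wise mean, which ignores neighbor order.
    n = len(messages)
    return [sum(vals) / n for vals in zip(*messages)]

def update(h_v, m_v):
    # Update step: blend the node's own features with the aggregated message.
    return [0.5 * a + 0.5 * b for a, b in zip(h_v, m_v)]

def message_passing_layer(adj, h):
    new_h = {}
    for v, neighbors in adj.items():
        msgs = [message(h[u]) for u in neighbors]
        new_h[v] = update(h[v], aggregate(msgs)) if msgs else list(h[v])
    return new_h

# Path graph 1 - 0 - 2 - 3 with 2-dimensional features.
adj = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0], 3: [0.0, 0.0]}

h1 = message_passing_layer(adj, h)   # each node has seen its 1-hop neighbors
h2 = message_passing_layer(adj, h1)  # now information from 2 hops away has arrived
print(h1[0], h2[1])                  # [0.75, 0.5] [0.625, 0.5]
```

After two layers, node 1 carries information from node 2 (two hops away) but is still unaffected by node 3, which is three hops away.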
Key architectures
GCN (Graph Convolutional Network)
The simplest and most widely used. Each layer computes:
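h_v' = σ( Σ_{u ∈ N(v) ∪ {v}} (1 / sqrt(d_u · d_v)) · W · h_u )
Here N(v) is the set of v’s neighbors and d_u is u’s degree counting the added self-loop. The node takes a degree-normalized sum of its neighbors’ features (including its own), which behaves like a weighted average. It then applies a learnable weight matrix W, shared by all nodes, and passes the result through an activation function σ. Replacing the symmetric normalization with a plain MEAN gives a common simplified variant. Introduced by Kipf and Welling (2017), GCN works well for homophilic graphs where connected nodes tend to share labels.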
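A GCN layer can be sketched in plain Python. The function name `gcn_layer`, the list-of-rows layout for W, and the choice of ReLU as σ are illustrative assumptions. The weight matrix is applied after aggregation, which gives the same result because both steps are linear.

```python
import math

def relu(x):
    return max(0.0, x)

def gcn_layer(adj, H, W):
    """One GCN layer: h_v' = relu(sum over u in N(v) + {v} of W h_u / sqrt(d_u d_v)).

    adj: dict node -> list of neighbors (undirected, no self-loops stored)
    H:   dict node -> feature list of length d_in
    W:   weight matrix as a list of d_out rows, each of length d_in
    """
    # Degree including the self-loop that GCN adds to every node.
    deg = {v: len(nbrs) + 1 for v, nbrs in adj.items()}
    out = {}
    for v, nbrs in adj.items():
        agg = [0.0] * len(H[v])
        for u in list(nbrs) + [v]:  # neighbors plus the node itself
            norm = 1.0 / math.sqrt(deg[u] * deg[v])
            agg = [a + norm * x for a, x in zip(agg, H[u])]
        # Apply the shared weight matrix, then the nonlinearity.
        out[v] = [relu(sum(w * a for w, a in zip(row, agg))) for row in W]
    return out

# Path graph 0 - 1 - 2, one feature per node, hypothetical 2x1 weights.
adj = {0: [1], 1: [0, 2], 2: [1]}
H = {0: [1.0], 1: [1.0], 2: [1.0]}
W = [[1.0], [-1.0]]
out = gcn_layer(adj, H, W)
print(out[0])  # second entry is 0.0 because ReLU clips the negative output
```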
GraphSAGE
Instead of using all neighbors, GraphSAGE samples a fixed number of neighbors per node. This makes it scalable to large graphs and — critically — inductive: it can generate embeddings for nodes not seen during training.
Aggregation options include mean, LSTM (on a random ordering), and pooling (max of transformed features).
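A mean-aggregator layer with fixed-size sampling can be sketched as follows. It follows the GraphSAGE pattern of concatenating a node’s own vector with the aggregated neighbor vector before the weight matrix, but it omits the per-layer normalization. The with-replacement fallback for low-degree nodes and all function names are assumptions. Because W never depends on node identity, the same weights apply to a node added after "training", which is what makes the approach inductive.

```python
import random

def sample_neighbors(adj, v, k, rng):
    """Draw a fixed-size sample of k neighbors of v.

    Assumption: sample without replacement when v has at least k neighbors,
    otherwise with replacement; an isolated node gets an empty sample.
    """
    nbrs = adj[v]
    if not nbrs:
        return []
    if len(nbrs) >= k:
        return rng.sample(nbrs, k)
    return [rng.choice(nbrs) for _ in range(k)]

def sage_mean_layer(adj, H, W, k, rng):
    """One GraphSAGE-style layer with a mean aggregator.

    W has shape d_out x (2 * d_in) because it acts on [h_v || neighbor mean].
    """
    out = {}
    for v in adj:
        sampled = sample_neighbors(adj, v, k, rng)
        if sampled:
            nbr_mean = [sum(col) / len(sampled) for col in zip(*(H[u] for u in sampled))]
        else:
            nbr_mean = [0.0] * len(H[v])
        z = H[v] + nbr_mean  # list concatenation: own features, then neighborhood
        out[v] = [max(0.0, sum(w * x for w, x in zip(row, z))) for row in W]  # ReLU
    return out

# Star graph: hub 0 with five leaves that all share the same features.
adj = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}
H = {0: [2.0], 1: [1.0], 2: [1.0], 3: [1.0], 4: [1.0], 5: [1.0]}
W = [[1.0, 1.0]]  # hypothetical weights: d_in = 1, d_out = 1
rng = random.Random(0)
out = sage_mean_layer(adj, H, W, k=2, rng=rng)  # the hub looks at only 2 of its 5 leaves

# Inductive use: a new node 6 joins the graph and reuses the same W.
adj[6] = [0]
adj[0].append(6)
H[6] = [0.5]
out_new = sage_mean_layer(adj, H, W, k=2, rng=rng)
print(out[0], out_new[6])  # [3.0] [2.5]
```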
GAT (Graph Attention Network)
Not all neighbors are equally important. GAT learns attention weights that assign different importance to different neighbors. A node attending to its neighbors might weight its closest collaborator at 0.6 and a distant acquaintance at 0.1.
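Attention is computed using a small neural network that takes both nodes’ transformed features as input, making it adaptive to the specific node pair. In the original GAT, this is a single learnable layer followed by LeakyReLU. A softmax over the node’s neighborhood (including the node itself) then turns the raw scores into weights that sum to 1.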
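Single-head GAT-style attention weights can be sketched in plain Python. The score for a pair (v, u) is LeakyReLU(a · [W h_v || W h_u]), normalized with a softmax over u ∈ N(v) ∪ {v}. The weight matrix W, the attention vector `a`, and the toy features are hypothetical values. A full layer would then use these weights in a weighted sum of the W h_u vectors.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def matvec(W, h):
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def gat_attention(adj, H, W, a):
    """Single-head attention weights alpha[v][u] for u in N(v) plus v itself."""
    Wh = {v: matvec(W, h) for v, h in H.items()}
    alpha = {}
    for v, nbrs in adj.items():
        cand = list(nbrs) + [v]
        # Raw score from a shared linear attention layer on [W h_v || W h_u].
        scores = [leaky_relu(sum(ai * x for ai, x in zip(a, Wh[v] + Wh[u]))) for u in cand]
        # Softmax over the neighborhood (max subtracted for numerical stability).
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alpha[v] = {u: e / total for u, e in zip(cand, exps)}
    return alpha

adj = {0: [1, 2], 1: [0], 2: [0]}
H = {0: [1.0, 0.0], 1: [1.0, 1.0], 2: [0.0, 0.0]}
W = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical: identity transform
a = [0.5, -0.5, 1.0, 1.0]     # hypothetical attention parameters
alpha = gat_attention(adj, H, W, a)
print(alpha[0])  # node 0 weights neighbor 1 highest, neighbor 2 lowest
```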
Task types
GNNs handle three levels of prediction:
Node classification
Predict a label for each node. Example: classifying users as bots or humans in a social network. The final node embeddings are passed through a classifier.
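One common classifier head is a single linear layer followed by a softmax over classes, sketched here under that assumption. The weights, biases, and the "human"/"bot" labels are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify_node(z, W, b, labels):
    """Linear layer + softmax on a final node embedding z."""
    logits = [sum(w * x for w, x in zip(row, z)) + bias for row, bias in zip(W, b)]
    probs = softmax(logits)
    best = max(range(len(labels)), key=lambda i: probs[i])
    return labels[best], probs

# Hypothetical trained head for 2 classes over 3-dimensional embeddings.
W = [[1.0, -0.5, 0.0],   # row for "human"
     [-1.0, 0.5, 2.0]]   # row for "bot"
b = [0.0, 0.0]
label, probs = classify_node([0.2, 0.1, 0.9], W, b, ["human", "bot"])
print(label)  # bot
```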
Edge prediction (link prediction)
Predict whether an edge should exist between two nodes. Example: recommending connections on LinkedIn. Score pairs of node embeddings and rank by likelihood.
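One simple way to score pairs is a dot-product decoder followed by a sigmoid, then sorting candidates by score. This sketch assumes that decoder; the embeddings are hypothetical stand-ins for the output of a trained GNN.

```python
import math

def edge_score(z_u, z_v):
    # Dot-product decoder: a larger value means an edge is considered more likely.
    dot = sum(a * b for a, b in zip(z_u, z_v))
    return 1.0 / (1.0 + math.exp(-dot))  # sigmoid maps the score into (0, 1)

def rank_candidates(Z, candidate_pairs):
    # Highest-scoring pairs first.
    return sorted(candidate_pairs, key=lambda p: edge_score(Z[p[0]], Z[p[1]]), reverse=True)

Z = {  # hypothetical final node embeddings
    "alice": [0.9, 0.1],
    "bob": [0.8, 0.2],
    "carol": [-0.7, 0.6],
}
ranked = rank_candidates(Z, [("alice", "carol"), ("alice", "bob")])
print(ranked[0])  # ('alice', 'bob')
```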
Graph classification
Predict a label for an entire graph. Example: predicting whether a molecular graph is toxic. Requires a readout step that aggregates all node embeddings into a single graph-level vector (global mean pool, global max pool, or attention-based).
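The mean and max readouts can be sketched in a few lines of plain Python. Both produce a fixed-size vector no matter how many nodes the graph has, and both ignore node order. The node embeddings below are hypothetical.

```python
def global_mean_pool(node_embs):
    # Average each embedding dimension across all nodes in the graph.
    n = len(node_embs)
    return [sum(col) / n for col in zip(*node_embs)]

def global_max_pool(node_embs):
    # Take the largest value of each embedding dimension across nodes.
    return [max(col) for col in zip(*node_embs)]

# Hypothetical final embeddings for the 3 atoms of one small molecule graph.
node_embs = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]
print(global_mean_pool(node_embs))  # [2.0, 2.0]
print(global_max_pool(node_embs))   # [3.0, 4.0]
```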
Python frameworks
Two dominant libraries:
- PyTorch Geometric (PyG) — Tightly integrated with PyTorch. Provides 70+ GNN layers, 100+ datasets, and utilities for mini-batching, sampling, and heterogeneous graphs.
- DGL (Deep Graph Library) — Framework-agnostic (works with PyTorch, TensorFlow, MXNet). Slightly more explicit message-passing API. Strong for heterogeneous graphs and large-scale training.
Both are production-quality and actively maintained.
Practical considerations
- Over-smoothing — Stacking too many GNN layers makes all node embeddings converge to the same vector. Two to three layers is typical; beyond four, use skip connections or jumping knowledge.
- Feature engineering — Node features matter enormously. Degree, centrality metrics, and one-hot encodings of categorical attributes all help.
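- Mini-batching — Standard mini-batch training assumes independent samples, but a node’s embedding depends on its multi-hop neighborhood, so nodes in a batch cannot be processed in isolation. Both PyG and DGL provide neighbor-sampling loaders that build a small sampled subgraph around each batch of target nodes.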
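Over-smoothing can be reproduced on a toy graph. This sketch repeatedly averages each node with its neighbors, which is a GCN layer with the weights and activation stripped out (a deliberate simplification). It tracks the gap between the largest and smallest node value, and that gap shrinks with every layer as the node values converge.

```python
def mean_smooth(adj, x):
    # One parameter-free GCN-style step: average over the node and its neighbors.
    return {v: (x[v] + sum(x[u] for u in nbrs)) / (len(nbrs) + 1) for v, nbrs in adj.items()}

def spread(x):
    # Gap between the most different node values; 0 means all nodes are identical.
    return max(x.values()) - min(x.values())

adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}  # path graph 0 - 1 - 2 - 3
x = {0: 1.0, 1: 0.0, 2: 0.0, 3: -1.0}         # initially distinct node values

spreads = [spread(x)]
for _ in range(30):                           # 30 stacked "layers"
    x = mean_smooth(adj, x)
    spreads.append(spread(x))
print([round(s, 3) for s in spreads[:5]], spreads[-1])
```

Skip connections and jumping knowledge counter this by letting later layers reuse earlier, less-smoothed representations.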
Common misconception
“GNNs are always better than simpler graph methods.” For many tasks, Node2Vec embeddings plus a random forest outperform GNNs — especially when you have limited labeled data. GNNs shine when you have rich node features and enough labels to train the network. Start simple, benchmark, then move to GNNs if the improvement justifies the complexity.
One thing to remember: GNNs extend deep learning to irregular graph structures through message passing — each node learns by listening to its neighbors.