Graph Neural Networks with Python — Deep Dive
Setting up PyTorch Geometric
PyG requires PyTorch as a backend. Installation depends on your CUDA version:
pip install torch torchvision
pip install torch-geometric
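# Optional compiled extensions for accelerated sparse ops (NeighborLoader needs pyg-lib or torch-sparse).
# The wheel index URL must match your installed PyTorch and CUDA versions — here torch 2.2.0 with CUDA 12.1: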
pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cu121.html
GCN implementation
A full GCN for node classification on the Cora citation dataset:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
# Cora citation dataset, stored/cached under /tmp/Cora.
dataset = Planetoid(name="Cora", root="/tmp/Cora")
# Its first (and, for Cora, only) graph: node features, edges, labels, split masks.
data = dataset[0]
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that outputs per-node logits.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden GCN layer.
        out_channels: Number of output logits per node (e.g. classes).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # Hidden layer: convolution -> ReLU -> dropout (active only in train mode).
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        # Output layer returns raw logits; cross_entropy applies softmax itself.
        return self.conv2(hidden, edge_index)
# 64-unit hidden layer; Adam with L2 weight decay applied to all parameters.
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)
def train():
    """Run one full-batch optimisation step on the training nodes.

    Returns:
        The training cross-entropy loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    # Only labelled training nodes contribute to the loss.
    mask = data.train_mask
    loss = F.cross_entropy(logits[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def evaluate():
    """Return ``[train_acc, val_acc, test_acc]`` of the current model.

    Runs a full-graph forward pass in eval mode (dropout disabled) and, for
    each split mask, computes the fraction of masked nodes whose argmax
    prediction equals the label.
    """
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=1)
    split_masks = (data.train_mask, data.val_mask, data.test_mask)
    return [
        int((pred[m] == data.y[m]).sum()) / int(m.sum())
        for m in split_masks
    ]
# Train for 200 full-batch epochs; report accuracies on epochs 0, 20, ..., 180.
for epoch in range(200):
    loss = train()
    if epoch % 20 != 0:
        continue
    train_acc, val_acc, test_acc = evaluate()
    print(
        f"Epoch {epoch:03d}, Loss: {loss:.4f}, "
        f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}"
    )
Expected performance: ~81% test accuracy with this basic 2-layer GCN on Cora.
GraphSAGE with neighbor sampling
For large graphs that don’t fit in GPU memory, use mini-batch training with neighbor sampling:
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
class GraphSAGE(torch.nn.Module):
    """Stack of ``num_layers`` SAGEConv layers for node classification.

    Every layer except the last is followed by ReLU and dropout (p=0.5);
    the last layer returns raw per-node logits.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every intermediate layer.
        out_channels: Number of output logits per node.
        num_layers: Total number of SAGEConv layers, at least 1. Typically
            equal to ``len(num_neighbors)`` of the NeighborLoader so that
            each layer consumes one sampled hop.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Layer widths in -> hidden -> ... -> hidden -> out; exactly
        # num_layers convolutions, so num_layers=1 maps in -> out directly.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index):
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x
# Two-hop neighbour sampling: up to 15 neighbours per seed node on the first
# hop, then up to 10 per node on the second hop. Seeds are the training
# nodes, reshuffled every epoch, 1024 seeds per mini-batch.
train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[15, 10],
    shuffle=True,
    batch_size=1024,
)
# Default two SAGEConv layers with a 256-unit hidden width; plain Adam.
model = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
def train_minibatch():
    """Train for one epoch over ``train_loader``.

    Each sampled mini-batch places its seed (target) nodes first, and there
    are ``batch.batch_size`` of them. The remaining nodes are sampled
    neighbours that only provide context, so the loss uses the seeds alone.

    Returns:
        The mean cross-entropy per seed node over the epoch. Each batch is
        weighted by its seed count, because the last batch may be smaller.
        Returns 0.0 if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        n_seed = batch.batch_size
        # Only compute loss on seed nodes (first batch_size nodes)
        loss = F.cross_entropy(out[:n_seed], batch.y[:n_seed])
        loss.backward()
        optimizer.step()
        # loss is a per-node mean; scale back to a sum before averaging.
        total_loss += loss.item() * n_seed
        total_examples += n_seed
    return total_loss / total_examples if total_examples else 0.0
Sampling strategies compared
| Strategy | Pros | Cons |
|---|---|---|
| Uniform neighbor sampling | Simple, fast | May miss important neighbors |
| Importance sampling | Focuses on high-degree nodes | Biased gradients need correction |
| Cluster-GCN | Batch entire subgraphs | Inter-cluster edges are lost |
| GraphSAINT | Unbiased estimation | Higher variance |
GAT with multi-head attention
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    """Two-layer graph attention network.

    The hidden layer runs ``heads`` attention heads whose outputs are
    concatenated, giving ``hidden_channels * heads`` features per node. The
    output layer uses a single head with ``concat=False`` and returns raw
    per-node logits. Dropout of 0.6 is applied to the layer inputs and is
    also passed to each GATConv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        # Input width is hidden_channels * heads because conv1 concatenates heads.
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            concat=False,
            dropout=0.6,
        )

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        hidden = F.elu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=0.6, training=self.training)
        return self.conv2(hidden, edge_index)
GAT typically uses 8 attention heads in hidden layers and 1 head in the output layer. Each head learns different importance patterns.
Heterogeneous graphs
Real-world graphs have multiple node and edge types. PyG handles this with HeteroData:
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv
# Toy heterogeneous graph with random features: 1000 users (64-d) and
# 5000 items (128-d). NOTE: this rebinds `data`, replacing the Cora graph.
data = HeteroData()
data["user"].x = torch.randn(1000, 64)
data["item"].x = torch.randn(5000, 128)
# edge_index row 0 holds source-node indices and row 1 destination-node
# indices. For "buys" the source is a user (0..999) and the destination is an
# item (0..4999), so each row is drawn from its own node type's range.
# Drawing both rows from [0, 1000) would leave items 1000..4999 without edges.
data["user", "buys", "item"].edge_index = torch.stack([
    torch.randint(0, 1000, (10000,)),
    torch.randint(0, 5000, (10000,)),
])
# "follows" links users to users, so both rows index users.
data["user", "follows", "user"].edge_index = torch.randint(0, 1000, (2, 5000))
class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE over the user/item relations.

    Each layer is a HeteroConv holding one SAGEConv per edge type
    (``user-buys->item`` and ``user-follows->user``). The ``(-1, -1)`` input
    sizes let every SAGEConv infer its source and destination feature widths
    lazily on the first forward pass. Node types with different input widths
    (e.g. 64-d users, 128-d items) therefore work without manual sizing.

    NOTE(review): there is no item->user relation, so user embeddings never
    see purchase information. Add reverse edges if that matters.

    Args:
        hidden_channels: Output width of the first layer (default 64).
        out_channels: Output width of the second layer (default 32).
    """

    def __init__(self, hidden_channels=64, out_channels=32):
        super().__init__()
        self.conv1 = self._make_layer(hidden_channels)
        self.conv2 = self._make_layer(out_channels)

    @staticmethod
    def _make_layer(channels):
        """Build one HeteroConv with a lazily-sized SAGEConv per edge type."""
        return HeteroConv({
            ("user", "buys", "item"): SAGEConv((-1, -1), channels),
            ("user", "follows", "user"): SAGEConv((-1, -1), channels),
        })

    def forward(self, x_dict, edge_index_dict):
        """Map ``{node_type: features}`` to ``{node_type: embeddings}``."""
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        return self.conv2(x_dict, edge_index_dict)
Link prediction with GNNs
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
class LinkPredictor(torch.nn.Module):
    """GCN encoder plus dot-product decoder for link prediction.

    ``encode`` produces node embeddings from the message-passing graph.
    ``decode`` scores each candidate pair in ``edge_label_index`` as the dot
    product of its two endpoint embeddings. The score is unbounded —
    presumably fed to a sigmoid / BCE-with-logits loss; confirm with caller.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def encode(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_label_index):
        # Row 0 = source nodes, row 1 = destination nodes of each candidate edge.
        z_src = z[edge_label_index[0]]
        z_dst = z[edge_label_index[1]]
        return (z_src * z_dst).sum(dim=1)  # dot product per candidate edge

    def forward(self, x, edge_index, edge_label_index):
        return self.decode(self.encode(x, edge_index), edge_label_index)
Debugging GNN training
Common failure modes
Loss doesn’t decrease: Check that edge_index is properly formatted (2 × E, long tensor). A common bug is transposing it.
All predictions are the same class: This happens with over-smoothing (too many layers) or when the graph is disconnected and some components have no labeled nodes.
Out-of-memory on GPU: Reduce num_neighbors in the sampler, decrease the hidden dimensions or the batch size, or enable automatic mixed precision with torch.autocast (the torch.amp package) for mixed-precision training.
Visualization for debugging
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
# Plot the Cora graph coloured by predicted class.
# `data` was rebound to the toy HeteroData above, so reload the homogeneous
# Cora graph explicitly instead of relying on the `data` name.
cora = dataset[0]
G = to_networkx(cora, to_undirected=True)
pos = nx.spring_layout(G, seed=42)
# `model` is the last model built in this walkthrough (the GraphSAGE
# instance). train_minibatch() is defined but not called above, so train it
# first for meaningful colours.
model.eval()
device = next(model.parameters()).device  # inputs must live where the model does
with torch.no_grad():
    logits = model(cora.x.to(device), cora.edge_index.to(device))
    # .cpu() is required before .numpy() when the model runs on a GPU.
    pred = logits.argmax(dim=1).cpu().numpy()
nx.draw(G, pos, node_color=pred, cmap=plt.cm.Set1, node_size=20, width=0.3)
plt.savefig("gnn_predictions.png", dpi=200)
plt.close()  # release the figure so repeated runs don't accumulate memory
Performance benchmarks
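Typical results on standard benchmarks (Cora dataset, 2708 nodes, 5429 edges). The parameter counts assume the reference configurations: 16 hidden units for GCN and GraphSAGE, and 8 heads × 8 hidden units for GAT. The larger hidden sizes used in the code above have more parameters; for example, the 64-unit GCN has about 92K.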
| Model | Test Accuracy | Parameters | Training Time |
|---|---|---|---|
| GCN (2 layer) | 81.5% | 23K | 2s |
| GraphSAGE | 82.0% | 46K | 3s |
| GAT (8 heads) | 83.0% | 92K | 5s |
| Node2Vec + LR | 78.0% | — | 10s |
On larger datasets (ogbn-products, 2.4M nodes), mini-batch GraphSAGE with 3 layers achieves ~79% accuracy while fitting in 16 GB GPU memory.
One thing to remember: Start with a 2-layer GCN as your baseline. If it works, try GAT for attention-weighted neighbors or GraphSAGE for scalability. Most real improvement comes from better features and data quality, not from architectural complexity.
See Also
- Python Community Detection How Python finds hidden groups in networks — friend circles, customer segments, and research clusters — just by looking at who connects to whom.
- Python Graph Embeddings How Python turns tangled webs of connections into neat lists of numbers that computers can actually work with.
- Python Link Prediction How Python guesses which connections are missing from a network — predicting future friendships, recommendations, and undiscovered relationships.
- Python Networkx Graph Analysis How Python maps connections between things — friends, roads, websites — and finds hidden patterns in those connections.
- Activation Functions Why neural networks need these tiny mathematical functions — and how ReLU's simplicity accidentally made deep learning possible.