Graph Neural Networks with Python — Deep Dive

Setting up PyTorch Geometric

PyG requires PyTorch as a backend. Installation depends on your CUDA version:

pip install torch torchvision
pip install torch-geometric
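# Optional: compiled sparse/scatter extensions. The wheel index URL must match
# your installed PyTorch and CUDA versions (this one is for torch 2.2.0 + CUDA 12.1):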
pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cu121.html

GCN implementation

A full GCN for node classification on the Cora citation dataset:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Download (on first use) and load Cora under /tmp/Cora; the single graph it
# contains is taken as one `Data` object with features, labels and split masks.
dataset = Planetoid(name="Cora", root="/tmp/Cora")
data = dataset[0]

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    ``forward`` returns one unnormalized score (logit) per output channel for
    every node; the softmax is left to ``F.cross_entropy`` in the loss.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden GCN layer.
        out_channels: Number of output scores per node (e.g. classes).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Layer 1: raw features -> hidden; layer 2: hidden -> per-class scores.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)

# 64 hidden units; Adam with L2 weight decay, as is usual for GCN on Cora.
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    """Run one full-batch optimization step on the training nodes.

    The whole graph is passed through the model, but the loss only uses
    nodes selected by ``data.train_mask``.

    Returns:
        The training loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    mask = data.train_mask
    loss = F.cross_entropy(logits[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate():
    """Compute accuracy on the train, validation and test node masks.

    Runs without gradient tracking and with the model in eval mode, so
    training-only layers such as dropout are disabled.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of Python floats.
    """
    model.eval()
    predicted = model(data.x, data.edge_index).argmax(dim=1)
    splits = (data.train_mask, data.val_mask, data.test_mask)
    return [
        int((predicted[split] == data.y[split]).sum()) / int(split.sum())
        for split in splits
    ]

# Train for a fixed number of epochs, logging every 20 epochs and always after
# the final one, so the last printed accuracies reflect the fully trained model
# (epoch 199 is not a multiple of 20 and would otherwise never be reported).
num_epochs = 200
for epoch in range(num_epochs):
    loss = train()
    if epoch % 20 == 0 or epoch == num_epochs - 1:
        train_acc, val_acc, test_acc = evaluate()
        print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, "
              f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}")

Expected performance: ~81% test accuracy with this basic 2-layer GCN on Cora.

GraphSAGE with neighbor sampling

For large graphs that don’t fit in GPU memory, use mini-batch training with neighbor sampling:

from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader

class GraphSAGE(torch.nn.Module):
    """Stack of SAGEConv layers for node classification.

    Works for full-batch input or for sampled subgraphs from a
    ``NeighborLoader``. ReLU and dropout (p=0.5) follow every layer except
    the last, whose raw outputs are returned.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every hidden layer.
        out_channels: Number of output scores per node (e.g. classes).
        num_layers: Total number of SAGEConv layers; must be >= 1. With
            ``num_layers=1`` a single layer maps inputs directly to outputs.
            With neighbor sampling this is typically ``len(num_neighbors)``.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Channel sizes at each layer boundary: in -> hidden * (L - 1) -> out.
        # Building from this list means num_layers=1 yields exactly one
        # in->out layer instead of silently creating two.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index):
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x

# Per-hop fan-out: up to 15 neighbors of each seed node, then up to 10
# neighbors of each of those (one hop per GraphSAGE layer). Seeds are the
# training nodes, reshuffled every epoch, 1024 per mini-batch.
train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[15, 10],
    shuffle=True,
    batch_size=1024,
)

# Rebinds `model` and `optimizer`: from here on they refer to GraphSAGE.
model = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes,
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def train_minibatch():
    """Run one epoch of mini-batch training over ``train_loader``.

    Each batch is a sampled subgraph whose first ``batch.batch_size`` nodes
    are the seed (training) nodes. The remaining nodes are sampled neighbors
    that only provide message-passing context, so the loss uses the seeds alone.

    Returns:
        Mean training loss per seed node over the epoch, or 0.0 if the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_seeds = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        num_seeds = batch.batch_size
        loss = F.cross_entropy(out[:num_seeds], batch.y[:num_seeds])
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its seed count. The last batch is
        # usually smaller, and an unweighted mean of batch means would
        # over-weight its nodes.
        total_loss += loss.item() * num_seeds
        total_seeds += num_seeds
    return total_loss / total_seeds if total_seeds else 0.0

Sampling strategies compared

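| Strategy | Pros | Cons |
|---|---|---|
| Uniform neighbor sampling | Simple, fast | May miss important neighbors |
| Importance sampling | Focuses on high-degree (more informative) nodes | Samples must be reweighted to keep gradients unbiased |
| Cluster-GCN | Batches entire subgraphs (graph clusters) | Edges between clusters are lost within a batch |
| GraphSAINT | Unbiased estimation via normalization coefficients | Higher variance; normalization coefficients must be precomputed |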

GAT with multi-head attention

from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network for node classification.

    The hidden layer runs ``heads`` attention heads and concatenates their
    outputs, giving width ``hidden_channels * heads``. The output layer uses a
    single head with ``concat=False``, so it returns ``out_channels`` scores
    per node. Feature dropout of 0.6 is applied before each layer. Both
    GATConv layers also receive ``dropout=0.6``, which per the PyG docs
    applies to the attention coefficients.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        concat_width = hidden_channels * heads  # heads are concatenated in conv1
        self.conv2 = GATConv(
            concat_width, out_channels, heads=1, concat=False, dropout=0.6
        )

    def forward(self, x, edge_index):
        h = F.dropout(x, p=0.6, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=0.6, training=self.training)
        return self.conv2(h, edge_index)

GAT typically uses 8 attention heads in hidden layers and 1 head in the output layer. Each head learns different importance patterns.

Heterogeneous graphs

Real-world graphs have multiple node and edge types. PyG handles this with HeteroData:

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv

# Toy heterogeneous graph: 1,000 users with 64-d features and 5,000 items
# with 128-d features. NOTE: this rebinds `data`; the Cora graph remains
# available as dataset[0].
data = HeteroData()
data["user"].x = torch.randn(1000, 64)
data["item"].x = torch.randn(5000, 128)
# PyG convention: row 0 indexes source nodes, row 1 destination nodes.
# Sources are users (0..999) but destinations are items (0..4999). Drawing
# both rows from range(1000) would leave items 1000-4999 with no edges.
data["user", "buys", "item"].edge_index = torch.stack([
    torch.randint(0, 1000, (10000,)),  # buyer (user) ids
    torch.randint(0, 5000, (10000,)),  # purchased item ids
])
data["user", "follows", "user"].edge_index = torch.randint(0, 1000, (2, 5000))

class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE over the "buys" and "follows" relations.

    Each relation gets its own SAGEConv in each layer, wrapped in a
    HeteroConv. The layers output 64 and then 32 channels. ``(-1, -1)`` asks
    SAGEConv to infer source and destination input sizes lazily on the first
    forward pass, which matters because node types can have different
    feature widths.
    """

    def __init__(self):
        super().__init__()
        relations = [("user", "buys", "item"), ("user", "follows", "user")]
        self.conv1 = HeteroConv({rel: SAGEConv((-1, -1), 64) for rel in relations})
        self.conv2 = HeteroConv({rel: SAGEConv((-1, -1), 32) for rel in relations})

    def forward(self, x_dict, edge_index_dict):
        hidden = {
            node_type: F.relu(feats)
            for node_type, feats in self.conv1(x_dict, edge_index_dict).items()
        }
        return self.conv2(hidden, edge_index_dict)
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score

class LinkPredictor(torch.nn.Module):
    """Link prediction with a two-layer GCN encoder and a dot-product decoder.

    ``encode`` embeds every node. ``decode`` scores each candidate edge
    ``(src, dst)`` in ``edge_label_index`` as the dot product of its two
    endpoint embeddings. Scores are unbounded logits: apply a sigmoid, or
    train with ``binary_cross_entropy_with_logits``, to treat them as
    probabilities.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def encode(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        return self.conv2(h, edge_index)

    def decode(self, z, edge_label_index):
        # Unpacking requires exactly two rows: source ids, then destination ids.
        src_ids, dst_ids = edge_label_index
        src_emb = z[src_ids]
        dst_emb = z[dst_ids]
        return (src_emb * dst_emb).sum(dim=1)

    def forward(self, x, edge_index, edge_label_index):
        return self.decode(self.encode(x, edge_index), edge_label_index)

Debugging GNN training

Common failure modes

Loss doesn’t decrease: Check that edge_index is properly formatted: shape [2, num_edges] with dtype torch.long. A common bug is passing it transposed, as [num_edges, 2].

All predictions are the same class: This happens with over-smoothing (too many layers) or when the graph is disconnected and some components have no labeled nodes.

Out-of-memory on GPU: Reduce num_neighbors in the sampler, decrease hidden dimensions, or use mixed-precision training via torch.autocast / torch.amp (the older torch.cuda.amp entry points are deprecated).

Visualization for debugging

from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

# Plot the Cora graph, coloring each node by the current model's predicted class.
# `data` was rebound to a HeteroData object in the heterogeneous-graph example,
# so take the homogeneous Cora graph straight from the dataset instead.
cora = dataset[0]
G = to_networkx(cora, to_undirected=True)
pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout

# Color nodes by prediction
model.eval()
with torch.no_grad():
    # Run on whatever device the model lives on, then bring the result back
    # to the CPU, since .numpy() fails on CUDA tensors.
    device = next(model.parameters()).device
    logits = model(cora.x.to(device), cora.edge_index.to(device))
    pred = logits.argmax(dim=1).cpu().numpy()

nx.draw(G, pos, node_color=pred, cmap=plt.cm.Set1, node_size=20, width=0.3)
plt.savefig("gnn_predictions.png", dpi=200)

Performance benchmarks

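Typical results on standard benchmarks (Cora dataset, 2708 nodes, 5429 edges). Parameter counts are for the standard paper configurations: 16 hidden units for GCN and GraphSAGE, and 8 heads × 8 hidden units for GAT. They do not reflect the wider models in the snippets above; for example, the 64-unit GCN has about 92K parameters. Training times are indicative and hardware-dependent.

| Model | Test Accuracy | Parameters | Training Time |
|---|---|---|---|
| GCN (2 layer) | 81.5% | 23K | 2s |
| GraphSAGE | 82.0% | 46K | 3s |
| GAT (8 heads) | 83.0% | 92K | 5s |
| Node2Vec + LR | 78.0% | — | 10s |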

On larger datasets (ogbn-products, 2.4M nodes), mini-batch GraphSAGE with 3 layers achieves ~79% accuracy while fitting in 16 GB GPU memory.

One thing to remember: Start with a 2-layer GCN as your baseline. If it works, try GAT for attention-weighted neighbors or GraphSAGE for scalability. Most real improvement comes from better features and data quality, not from architectural complexity.

Tags: python, machine-learning, deep-learning

See Also

  • Python Community Detection — How Python finds hidden groups in networks — friend circles, customer segments, and research clusters — just by looking at who connects to whom.
  • Python Graph Embeddings — How Python turns tangled webs of connections into neat lists of numbers that computers can actually work with.
  • Python Link Prediction — How Python guesses which connections are missing from a network — predicting future friendships, recommendations, and undiscovered relationships.
  • Python NetworkX Graph Analysis — How Python maps connections between things — friends, roads, websites — and finds hidden patterns in those connections.
  • Activation Functions — Why neural networks need these tiny mathematical functions — and how ReLU's simplicity accidentally made deep learning possible.