Graph Neural Networks Basics

When to Use This Skill

Use this skill when you need to:

  • Work with graph-structured data (molecules, social networks, citations)
  • Understand why CNN/RNN don't work on graphs
  • Learn message passing framework
  • Choose between GCN, GraphSAGE, GAT
  • Decide if GNN is appropriate (vs simple model)
  • Implement permutation-invariant aggregations

Do NOT use this skill for:

  • Sequential data (use RNN/Transformer)
  • Grid data (use CNN)
  • High-level architecture selection (use using-neural-architectures)

Core Principle

Graphs have irregular structure, so architectures built for grids (CNN) or sequences (RNN) don't apply.

GNN solution: Message passing

  • Nodes aggregate information from neighbors
  • Multiple layers = multi-hop neighborhoods
  • Permutation invariant (order doesn't matter)

Critical question: Does graph structure actually help? (Test: Compare with/without edges)

Part 1: Why GNN (Not CNN/RNN)

Problem: Graph Structure

Graph components:

  • Nodes: Entities (atoms, users, papers)
  • Edges: Relationships (bonds, friendships, citations)
  • Features: Node/edge attributes

Key property: Irregular structure

  • Each node has variable number of neighbors
  • No fixed spatial arrangement
  • Permutation invariant (node order doesn't matter)

Why CNN Doesn't Work

CNN assumption: Regular grid structure

Example: Image (2D grid)

Every interior pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■]  ← Center pixel has 8 neighbors (fixed!)
[■][■][■]

CNN kernel: 3×3 (fixed size, fixed positions)

Graph reality: Irregular neighborhoods

Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor (H)

No fixed kernel size or position!

CNN limitations:

  • Requires fixed-size neighborhoods → Graphs have variable-size
  • Assumes spatial locality → Graphs have arbitrary connectivity
  • Depends on node ordering → Should be permutation invariant

Why RNN Doesn't Work

RNN assumption: Sequential structure

Example: Text (1D sequence)

"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies

Graph reality: No inherent sequence

Social network:
A — B — C
|       |
D ——————E

What's the "sequence"? A→B→C? A→D→E? No natural ordering!

RNN limitations:

  • Requires sequential order → Graphs have no natural order
  • Processes one element at a time → Graphs have parallel connections
  • Order-dependent → Should be permutation invariant

GNN Solution

Key innovation: Message passing on graph structure

  • Operate directly on nodes and edges
  • Variable-size neighborhoods (handled naturally)
  • Permutation invariant aggregations

Part 2: Message Passing Framework

Core Mechanism

Message passing in 3 steps:

1. Aggregate neighbor messages

# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages)  # e.g., mean, sum, max

2. Update node representation

# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated)  # e.g., neural network

3. Repeat for L layers

  • Layer 1: Node sees 1-hop neighbors
  • Layer 2: Node sees 2-hop neighbors
  • Layer L: Node sees L-hop neighborhood
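
Putting the three steps together, here is a minimal sketch of one message-passing layer in plain PyTorch, assuming a dense adjacency matrix, mean aggregation, and a linear update (class and variable names are illustrative):

import torch
import torch.nn as nn

class MeanMessagePassing(nn.Module):
    """One message-passing layer: mean-aggregate neighbors, then update."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.w_self = nn.Linear(in_dim, out_dim)    # transform own features
        self.w_neigh = nn.Linear(in_dim, out_dim)   # transform aggregated messages

    def forward(self, h, adj):
        # h: (N, in_dim) node features; adj: (N, N) 0/1 adjacency matrix
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # avoid divide-by-zero
        aggregated = (adj @ h) / deg                     # step 1: mean over neighbors
        return torch.relu(self.w_self(h) + self.w_neigh(aggregated))  # step 2: update

# Tiny graph A-B, B-C (undirected). Stacking this layer twice gives
# each node a 2-hop receptive field (step 3).
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
h = torch.randn(3, 8)
layer = MeanMessagePassing(8, 16)
print(layer(h, adj).shape)  # torch.Size([3, 16])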

Concrete Example: Social Network

Task: Predict user interests

Graph:

     B (sports)
     |
A ---+--- C (cooking)
     |
     D (music)

Layer 1: 1-hop neighbors

# Node A aggregates from direct friends
h_A_layer1 = update(
    h_A,
    aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests

Layer 2: 2-hop neighbors (friends of friends)

# B's friends: E, F
# C's friends: G, H
# D's friends: I

h_A_layer2 = update(
    h_A_layer1,
    aggregate([h_B', h_C', h_D'])  # h_B' includes E, F
)
# Now h_A includes friends-of-friends!

Key insight: More layers = larger receptive field (L-hop neighborhood)

Permutation Invariance

Critical property: Same graph → same output (regardless of node ordering)

Example:

Graph: A-B, B-C

Node list 1: [A, B, C]
Node list 2: [C, B, A]

Output MUST be identical! (Same graph, different ordering)

Invariant aggregations:

  • Mean: mean([1, 2, 3]) == mean([3, 2, 1])
  • Sum: sum([1, 2, 3]) == sum([3, 2, 1])
  • Max: max([1, 2, 3]) == max([3, 2, 1])

NOT invariant:

  • LSTM: LSTM([1, 2, 3]) != LSTM([3, 2, 1])
  • Concatenate: [1, 2, 3] != [3, 2, 1]

Implementation:

# CORRECT: Permutation invariant
def aggregate(neighbor_features):
    return torch.mean(neighbor_features, dim=0)

# WRONG: Order-dependent!
def aggregate(neighbor_features):
    return LSTM(neighbor_features)  # Output depends on order
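
A quick check of this property, using random features (a sketch you can run as-is):

import torch

neighbors = torch.randn(5, 16)             # 5 neighbors, 16-dim features
shuffled = neighbors[torch.randperm(5)]    # same neighbors, shuffled order

# Mean aggregation gives (numerically) identical output either way
assert torch.allclose(neighbors.mean(dim=0), shuffled.mean(dim=0), atol=1e-6)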

Part 3: GNN Architectures

Architecture 1: GCN (Graph Convolutional Network)

Key idea: Spectral convolution on graphs (simplified)

Formula:

h_i^(l+1) = σ( Σ_{j ∈ N(i)} W^(l) h_j^(l) / √(|N(i)| · |N(j)|) )

# Normalize each message by degree: √(deg(i) · deg(j))

Aggregation: Weighted mean (degree-normalized)
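
For concreteness, here is what that symmetric degree normalization looks like on a small dense adjacency matrix; this is a sketch of what GCNConv computes internally (on sparse edge_index), not code you need to write yourself:

import torch

adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
adj_hat = adj + torch.eye(3)                  # add self-loops
deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)   # 1 / √deg
norm_adj = deg_inv_sqrt[:, None] * adj_hat * deg_inv_sqrt[None, :]
# One GCN layer is then: σ(norm_adj @ H @ W)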

Properties:

  • Transductive (needs full graph at training)
  • Computationally efficient
  • Good baseline

When to use:

  • Full graph available at training time
  • Starting point (simplest GNN)
  • Small to medium graphs

Implementation:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: Node features (N, in_channels)
        # edge_index: Graph connectivity (2, E)

        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        # Layer 2
        x = self.conv2(x, edge_index)

        return x

Architecture 2: GraphSAGE

Key idea: Sample and aggregate (inductive learning)

Formula:

# Sample fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)

# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})

# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])

Aggregation: Mean, max, or LSTM (but mean/max preferred for invariance)

Key innovation: Sampling

  • Sample fixed number of neighbors (e.g., 10)
  • Makes computation tractable for large graphs
  • Enables inductive learning (generalizes to unseen nodes)

When to use:

  • Large graphs (millions of nodes)
  • Need inductive capability (new nodes appear)
  • Training on subset, testing on full graph

Implementation:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
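
Note that SAGEConv itself aggregates over all neighbors; the sampling described above comes from the data loader. A sketch using PyTorch Geometric's NeighborLoader (available in recent versions; batch size and fan-outs are illustrative):

from torch_geometric.loader import NeighborLoader

# Each batch is a sampled subgraph: up to 10 neighbors at hop 1 and
# 10 more at hop 2 for every seed node
loader = NeighborLoader(
    data,                          # a torch_geometric.data.Data object
    num_neighbors=[10, 10],        # fan-out per layer
    batch_size=128,
    input_nodes=data.train_mask,   # seed nodes to compute loss on
)

for batch in loader:
    out = model(batch.x, batch.edge_index)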

Architecture 3: GAT (Graph Attention Network)

Key idea: Learn attention weights for neighbors

Formula:

# Attention scores
α_ij = attention(h_i, h_j)  # How important is neighbor j to node i?

# Normalize (softmax)
α_ij = softmax_j(α_ij)

# Weighted aggregation
h_i^(l+1) = σ( Σ_{j ∈ N(i)} α_ij W h_j^(l) )

Key innovation: Learned neighbor importance

  • Not all neighbors equally important
  • Attention mechanism decides weights
  • Multi-head attention (like Transformer)

When to use:

  • Neighbors have varying importance
  • Need interpretability (attention weights)
  • Have sufficient data (attention needs more data to learn)

Implementation:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)  # conv1 concatenates its heads

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
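
Since GAT's interpretability comes from its attention coefficients, it helps to know you can read them out. A sketch, assuming model is the GAT defined above and h is the output of its first layer:

# model: the GAT above; h: output of model.conv1 (shape: N, hidden_channels * heads)
out, (att_edge_index, alpha) = model.conv2(
    h, edge_index, return_attention_weights=True
)
# alpha holds one coefficient per edge (per head): how much each source
# node contributed to its target's update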

Architecture Comparison

| Feature | GCN | GraphSAGE | GAT |
|---|---|---|---|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |

Decision Tree

Starting out / Small graph:
→ GCN (simplest baseline)

Large graph (millions of nodes):
→ GraphSAGE (sampling enables scalability)

Need inductive learning (new nodes):
→ GraphSAGE or GAT

Neighbors have different importance:
→ GAT (attention learns importance)

Need interpretability:
→ GAT (attention weights explain predictions)

Production deployment:
→ GraphSAGE (most robust and scalable)

Part 4: When NOT to Use GNN

Critical Question

Does graph structure actually help?

Test: Compare model with and without edges

# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)

# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)

# Decision:
if gnn_accuracy - mlp_accuracy < 0.02:  # less than 2 percentage points
    print("Graph structure doesn't help much")
    print("Use simpler model (MLP or XGBoost)")
else:
    print("Graph structure adds value")
    print("Use GNN")

Scenarios Where GNN Doesn't Help

1. Node features dominate

User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth the complexity!)

2. Sparse graphs

Graph with 1000 nodes, 100 edges (0.01% density):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to MLP

3. Random graph structure

If edges are random (no homophily):
- Neighbor labels uncorrelated
- Aggregation adds noise
- Simple model better

When GNN DOES Help

Molecular property prediction

  • Structure is PRIMARY signal
  • Atom types + bonds determine properties
  • GNN: Huge improvement over fingerprints

Citation networks

  • Paper quality correlated with neighbors
  • "You are what you cite"
  • Clear homophily

Social recommendation

  • Friends have similar preferences
  • Graph structure informative
  • GNN: Moderate to large improvement

Knowledge graphs

  • Entities connected by relations
  • Multi-hop reasoning valuable
  • GNN captures complex patterns

Decision Framework

1. Start simple:
   - Try MLP or XGBoost on node features
   - Establish baseline performance

2. Check graph structure value:
   - Does edge information correlate with target?
   - Is there homophily (similar nodes connected)?
   - Test: Remove edges, compare performance (sketched in code after this list)

3. Use GNN if:
   - Graph structure adds >2-5% accuracy
   - Structure is interpretable (not random)
   - Have enough nodes for GNN to learn

4. Stick with simple if:
   - Node features alone sufficient
   - Graph structure weak/random
   - Small dataset (< 1000 nodes)
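
A minimal sketch of the edge-removal test from step 2, assuming a hypothetical train_and_eval helper that trains a fixed GNN and returns test accuracy:

import torch

# `train_and_eval` is a hypothetical helper: trains the same GNN on the
# given features/edges/labels and returns test accuracy
empty_edges = torch.empty((2, 0), dtype=torch.long)   # graph with no edges

acc_with_edges = train_and_eval(data.x, data.edge_index, data.y)
acc_without = train_and_eval(data.x, empty_edges, data.y)

# If removing edges barely hurts, the GNN is effectively an MLP and a
# simpler model is the better choice
print(f"Gain from structure: {acc_with_edges - acc_without:.3f}")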

Part 5: Practical Implementation

Using PyTorch Geometric

Installation:

pip install torch-geometric

Basic workflow:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# 1. Create graph data
x = torch.tensor([[feature1], [feature2], ...])  # Node features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Edges (COO format)
y = torch.tensor([label1, label2, ...])  # Node labels

data = Data(x=x, edge_index=edge_index, y=y)

# 2. Define model
class GNN(nn.Module):
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_features, 64)
        self.conv2 = GCNConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 3. Train
model = GNN(in_features, num_classes)  # set to your feature/label dimensions
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

Edge Index Format

COO (Coordinate) format:

# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
    [0, 1, 2],  # Source nodes
    [1, 2, 0]   # Target nodes
])

# For undirected graph, include both directions:
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 0],  # Source
    [1, 0, 2, 1, 0, 2]   # Target
])
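
Rather than writing both directions by hand, PyTorch Geometric provides a utility for this (a sketch; the input edges are the same triangle as above):

import torch
from torch_geometric.utils import to_undirected

# Adds the reverse of every edge (and removes duplicates)
edge_index = to_undirected(torch.tensor([[0, 1, 2], [1, 2, 0]]))
# edge_index now contains all six directed edges of the triangle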

Mini-batching Graphs

Problem: Graphs have different sizes

Solution: Batch graphs as one large disconnected graph

from torch_geometric.loader import DataLoader

# Create dataset
dataset = [Data(...), Data(...), ...]

# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch contains multiple graphs as one large graph
    # batch.batch: maps each node to the index of its graph
    out = model(batch.x, batch.edge_index)
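
For graph-level tasks (e.g., molecular property prediction), batch.batch is also what you pool over to get one embedding per graph; a sketch using global_mean_pool:

from torch_geometric.nn import global_mean_pool

# out: node embeddings for the whole batch; pooling with batch.batch
# collapses them into one vector per graph
graph_embeddings = global_mean_pool(out, batch.batch)  # (num_graphs, dim)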

Part 6: Common Mistakes

Mistake 1: LSTM Aggregation

Symptom: Different outputs for the same graph when nodes are reordered
Fix: Use mean/sum/max aggregation (permutation invariant)

Mistake 2: Forgetting Edge Direction

Symptom: Information flows the wrong way
Fix: For undirected graphs, add edges in both directions

Mistake 3: Too Many Layers

Symptom: Performance degrades (over-smoothing)
Fix: Use 2-3 layers (most graphs have small diameter)
Explanation: With too many layers, all node representations converge to the same vector

Mistake 4: Not Testing Simple Baseline

Symptom: Complex GNN with minimal improvement over baselines
Fix: Always test an MLP on node features first

Mistake 5: Using GNN on Euclidean Data

Symptom: A CNN/RNN would work better
Fix: Use GNN only for irregular graph structure (not grids or sequences)

Part 7: Summary

Quick Reference

When to use GNN:

  • Graph-structured data (molecules, social networks, citations)
  • Irregular neighborhoods (not grid/sequence)
  • Graph structure informative (test this!)

Architecture selection:

Start: GCN (simplest)
Large graph: GraphSAGE (scalable)
Inductive learning: GraphSAGE or GAT
Neighbor importance: GAT (attention)

Key principles:

  • Message passing: Aggregate neighbors + Update node
  • Permutation invariance: Use mean/sum/max (not LSTM)
  • Test baseline: MLP first, GNN if structure helps
  • Layers: 2-3 sufficient (more = over-smoothing)

Implementation:

  • PyTorch Geometric: Standard library
  • COO format: Edge index as 2×E tensor
  • Batching: Merge graphs into one large graph

Next Steps

After mastering this skill:

  • transformer-architecture-deepdive: Understand attention (used in GAT)
  • architecture-design-principles: Design principles for graph architectures
  • Advanced GNNs: Graph Transformers, Equivariant GNNs

Remember: Not all graph data needs a GNN. Test whether graph structure actually helps (compare against an MLP baseline)!