# Graph Neural Networks Basics
## When to Use This Skill
Use this skill when you need to:
- ✅ Work with graph-structured data (molecules, social networks, citations)
- ✅ Understand why CNN/RNN don't work on graphs
- ✅ Learn message passing framework
- ✅ Choose between GCN, GraphSAGE, GAT
- ✅ Decide if GNN is appropriate (vs simple model)
- ✅ Implement permutation-invariant aggregations
**Do NOT use this skill for:**
- ❌ Sequential data (use RNN/Transformer)
- ❌ Grid data (use CNN)
- ❌ High-level architecture selection (use `using-neural-architectures`)
## Core Principle
**Graphs have irregular structure.** CNN (grid) and RNN (sequence) don't work.
**GNN solution:** Message passing
- Nodes aggregate information from neighbors
- Multiple layers = multi-hop neighborhoods
- Permutation invariant (order doesn't matter)
**Critical question:** Does graph structure actually help? (Test: Compare with/without edges)
## Part 1: Why GNN (Not CNN/RNN)
### Problem: Graph Structure
**Graph components:**
- **Nodes**: Entities (atoms, users, papers)
- **Edges**: Relationships (bonds, friendships, citations)
- **Features**: Node/edge attributes
**Key property:** Irregular structure
- Each node has variable number of neighbors
- No fixed spatial arrangement
- Permutation invariant (node order doesn't matter)
### Why CNN Doesn't Work
**CNN assumption:** Regular grid structure
**Example:** Image (2D grid)
```
Every pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■] ← Center pixel has 8 neighbors (fixed!)
[■][■][■]
CNN kernel: 3×3 (fixed size, fixed positions)
```
**Graph reality:** Irregular neighborhoods
```
Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor (H)
No fixed kernel size or position!
```
**CNN limitations:**
- Requires fixed-size neighborhoods → Graphs have variable-size
- Assumes spatial locality → Graphs have arbitrary connectivity
- Depends on node ordering → Should be permutation invariant
### Why RNN Doesn't Work
**RNN assumption:** Sequential structure
**Example:** Text (1D sequence)
```
"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies
```
**Graph reality:** No inherent sequence
```
Social network:
A — B — C
|       |
D ——————E
What's the "sequence"? A→B→C? A→D→E? No natural ordering!
```
**RNN limitations:**
- Requires sequential order → Graphs have no natural order
- Processes one element at a time → Graphs have parallel connections
- Order-dependent → Should be permutation invariant
### GNN Solution
**Key innovation:** Message passing on graph structure
- Operate directly on nodes and edges
- Variable-size neighborhoods (handled naturally)
- Permutation invariant aggregations
## Part 2: Message Passing Framework
### Core Mechanism
**Message passing in 3 steps:**
**1. Aggregate neighbor messages**
```python
# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages) # e.g., mean, sum, max
```
**2. Update node representation**
```python
# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated) # e.g., neural network
```
**3. Repeat for L layers**
- Layer 1: Node sees 1-hop neighbors
- Layer 2: Node sees 2-hop neighbors
- Layer L: Node sees L-hop neighborhood
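
Putting the three steps together, here is a minimal sketch of one mean-aggregation message-passing layer in plain PyTorch. The `adj_list` adjacency format and layer sizes are illustrative assumptions, not part of any specific library:

```python
import torch
import torch.nn as nn

class MeanMessagePassingLayer(nn.Module):
    """One round of message passing: aggregate neighbor features, then update."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Update step: combine a node's own features with the aggregated message
        self.update = nn.Linear(2 * in_dim, out_dim)

    def forward(self, h, adj_list):
        # h: (num_nodes, in_dim) node features
        # adj_list: list of neighbor-index lists, one per node (assumed input format)
        aggregated = torch.zeros_like(h)
        for i, neighbors in enumerate(adj_list):
            if neighbors:
                # Permutation-invariant mean over this node's neighbors
                aggregated[i] = h[neighbors].mean(dim=0)
        # Update: concatenate self features with the aggregated message, then transform
        return torch.relu(self.update(torch.cat([h, aggregated], dim=-1)))
```

Stacking L of these layers gives each node a view of its L-hop neighborhood; libraries such as PyTorch Geometric implement the same pattern far more efficiently with sparse operations.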
### Concrete Example: Social Network
**Task:** Predict user interests
**Graph:**
```
     B (sports)
     |
A ---+--- C (cooking)
     |
     D (music)
```
**Layer 1: 1-hop neighbors**
```python
# Node A aggregates from direct friends
h_A_layer1 = update(
    h_A,
    aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests
```
**Layer 2: 2-hop neighbors (friends of friends)**
```python
# B's friends: E, F
# C's friends: G, H
# D's friends: I
h_A_layer2 = update(
    h_A_layer1,
    aggregate([h_B', h_C', h_D'])  # h_B' already includes E, F
)
# Now h_A includes friends-of-friends!
```
**Key insight:** More layers = larger receptive field (L-hop neighborhood)
### Permutation Invariance
**Critical property:** Same graph → same output (regardless of node ordering)
**Example:**
```python
Graph: A-B, B-C
Node list 1: [A, B, C]
Node list 2: [C, B, A]
Output MUST be identical! (Same graph, different ordering)
```
**Invariant aggregations:**
- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])`
- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])`
- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])`
**NOT invariant:**
- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])`
- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]`
**Implementation:**
```python
# CORRECT: Permutation invariant
def aggregate(neighbor_features):
    return torch.mean(neighbor_features, dim=0)

# WRONG: Order-dependent!
def aggregate(neighbor_features):
    return LSTM(neighbor_features)  # Output depends on order
```
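
A quick way to sanity-check invariance is to shuffle the neighbor rows and confirm the aggregation output is unchanged (a minimal sketch with random features):

```python
import torch

neighbor_features = torch.randn(5, 16)            # 5 neighbors, 16-dim features each
perm = torch.randperm(5)                           # a random reordering of the neighbors
original = torch.mean(neighbor_features, dim=0)
shuffled = torch.mean(neighbor_features[perm], dim=0)
assert torch.allclose(original, shuffled)          # mean aggregation ignores neighbor order
```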
## Part 3: GNN Architectures
### Architecture 1: GCN (Graph Convolutional Network)
**Key idea:** Spectral convolution on graphs (simplified)
**Formula:**
```python
h_i^(l+1) = σ( Σ_{j ∈ N(i)} W^(l) h_j^(l) / √(|N(i)| · |N(j)|) )
# Normalize by degree: √(deg(i) * deg(j))
```
**Aggregation:** Weighted mean (degree-normalized)
**Properties:**
- Transductive (needs full graph at training)
- Computationally efficient
- Good baseline
**When to use:**
- Full graph available at training time
- Starting point (simplest GNN)
- Small to medium graphs
**Implementation:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: Node features (N, in_channels)
        # edge_index: Graph connectivity (2, E)
        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # Layer 2
        x = self.conv2(x, edge_index)
        return x
```
### Architecture 2: GraphSAGE
**Key idea:** Sample and aggregate (inductive learning)
**Formula:**
```python
# Sample fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)
# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})
# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])
```
**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance)
**Key innovation:** Sampling
- Sample fixed number of neighbors (e.g., 10)
- Makes computation tractable for large graphs
- Enables inductive learning (generalizes to unseen nodes)
**When to use:**
- Large graphs (millions of nodes)
- Need inductive capability (new nodes appear)
- Training on subset, testing on full graph
**Implementation:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
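
The sampling itself is usually handled by the data loader rather than the layer. A minimal sketch using `NeighborLoader` from recent versions of PyTorch Geometric, assuming a `Data` object named `data` with a boolean `train_mask` (as in the standard benchmark datasets):

```python
from torch_geometric.loader import NeighborLoader

# Sample at most 10 neighbors per node at each of the 2 GNN layers
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],       # one entry per message-passing layer
    batch_size=128,
    input_nodes=data.train_mask,  # seed nodes the sampled subgraphs are built around
)

for batch in loader:
    out = model(batch.x, batch.edge_index)  # the model only sees the sampled subgraph
```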
### Architecture 3: GAT (Graph Attention Network)
**Key idea:** Learn attention weights for neighbors
**Formula:**
```python
# Attention scores
α_ij = attention(h_i, h_j) # How important is neighbor j to node i?
# Normalize (softmax)
α_ij = softmax_j(α_ij)
# Weighted aggregation
h_i^(l+1) = σ( Σ_{j ∈ N(i)} α_ij W h_j^(l) )
```
**Key innovation:** Learned neighbor importance
- Not all neighbors equally important
- Attention mechanism decides weights
- Multi-head attention (like Transformer)
**When to use:**
- Neighbors have varying importance
- Need interpretability (attention weights)
- Have sufficient data (attention needs more data to learn)
**Implementation:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
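
For interpretability, `GATConv` can return its learned attention coefficients alongside the output. A small sketch, assuming a trained layer `conv1` and inputs `x`, `edge_index`:

```python
# Ask the layer for its attention coefficients alongside the output
out, (att_edge_index, alpha) = conv1(x, edge_index, return_attention_weights=True)

# alpha: (num_edges, num_heads) attention weights (self-loops are included by default)
print(alpha.mean(dim=1))  # average importance of each edge across heads
```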
### Architecture Comparison
| Feature | GCN | GraphSAGE | GAT |
|---------|-----|-----------|-----|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |
### Decision Tree
```
Starting out / Small graph:
  → GCN (simplest baseline)

Large graph (millions of nodes):
  → GraphSAGE (sampling enables scalability)

Need inductive learning (new nodes):
  → GraphSAGE or GAT

Neighbors have different importance:
  → GAT (attention learns importance)

Need interpretability:
  → GAT (attention weights explain predictions)

Production deployment:
  → GraphSAGE (most robust and scalable)
```
## Part 4: When NOT to Use GNN
### Critical Question
**Does graph structure actually help?**
**Test:** Compare model with and without edges
```python
# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)

# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)

# Decision:
if gnn_accuracy - mlp_accuracy < 0.02:  # less than 2 percentage points of gain
    print("Graph structure doesn't help much")
    print("Use simpler model (MLP or XGBoost)")
else:
    print("Graph structure adds value")
    print("Use GNN")
```
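
An even simpler ablation keeps the same GNN but removes the edges, so any remaining gap reflects the node features alone. A sketch, assuming a PyG `data` object and a model with a `forward(x, edge_index)` signature like the GCN class above:

```python
import torch

# Same model, same node features, but no edges: the GNN degenerates to a per-node MLP
empty_edges = torch.empty((2, 0), dtype=torch.long)
out_without_graph = model(data.x, empty_edges)
out_with_graph = model(data.x, data.edge_index)
# Compare accuracy of the two outputs on a held-out mask to see whether edges add value
```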
### Scenarios Where GNN Doesn't Help
**1. Node features dominate**
```
User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth complexity!)
```
**2. Sparse graphs**
```
Graph with 1000 nodes, 100 edges (0.01% density):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to MLP
```
**3. Random graph structure**
```
If edges are random (no homophily):
- Neighbor labels uncorrelated
- Aggregation adds noise
- Simple model better
```
### When GNN DOES Help
**Molecular property prediction**
- Structure is PRIMARY signal
- Atom types + bonds determine properties
- GNN: Huge improvement over fingerprints
**Citation networks**
- Paper quality correlated with neighbors
- "You are what you cite"
- Clear homophily
**Social recommendation**
- Friends have similar preferences
- Graph structure informative
- GNN: Moderate to large improvement
**Knowledge graphs**
- Entities connected by relations
- Multi-hop reasoning valuable
- GNN captures complex patterns
### Decision Framework
```
1. Start simple:
   - Try MLP or XGBoost on node features
   - Establish baseline performance

2. Check graph structure value:
   - Does edge information correlate with target?
   - Is there homophily (similar nodes connected)?
   - Test: Remove edges, compare performance

3. Use GNN if:
   - Graph structure adds >2-5% accuracy
   - Structure is interpretable (not random)
   - Have enough nodes for GNN to learn

4. Stick with simple if:
   - Node features alone sufficient
   - Graph structure weak/random
   - Small dataset (< 1000 nodes)
```
## Part 5: Practical Implementation
### Using PyTorch Geometric
**Installation:**
```bash
pip install torch-geometric
```
**Basic workflow:**
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# 1. Create graph data
x = torch.tensor([[feature1], [feature2], ...])    # Node features (N, in_features)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Edges (COO format)
y = torch.tensor([label1, label2, ...])            # Node labels
data = Data(x=x, edge_index=edge_index, y=y)

# 2. Define model
class GNN(nn.Module):
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_features, 64)
        self.conv2 = GCNConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 3. Train
model = GNN(in_features, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    # data.train_mask: boolean mask selecting the training nodes
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
```
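
After training, evaluation typically switches to eval mode and measures accuracy on a held-out mask. A sketch, assuming `data.test_mask` exists (as in the standard Planetoid-style splits):

```python
# Evaluate on the held-out test nodes
model.eval()
with torch.no_grad():
    pred = model(data).argmax(dim=1)  # predicted class per node
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    accuracy = int(correct) / int(data.test_mask.sum())
print(f"Test accuracy: {accuracy:.3f}")
```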
### Edge Index Format
**COO (Coordinate) format:**
```python
# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
[0, 1, 2], # Source nodes
[1, 2, 0] # Target nodes
])
# For undirected graph, include both directions:
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 0], # Source
[1, 0, 2, 1, 0, 2] # Target
])
```
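
Instead of writing both directions by hand, PyTorch Geometric's `to_undirected` utility can mirror a directed edge list (a small sketch):

```python
import torch
from torch_geometric.utils import to_undirected

directed = torch.tensor([[0, 1, 2],
                         [1, 2, 0]])
edge_index = to_undirected(directed)  # adds the reverse of every edge (deduplicated)
```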
### Mini-batching Graphs
**Problem:** Graphs have different sizes
**Solution:** Batch graphs as one large disconnected graph
```python
from torch_geometric.data import DataLoader

# Create dataset
dataset = [Data(...), Data(...), ...]

# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch contains multiple graphs merged into one large disconnected graph
    # batch.batch: vector indicating which graph each node belongs to
    out = model(batch.x, batch.edge_index)
```
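
For graph-level tasks, the `batch.batch` vector is what lets a readout layer pool node outputs back into one vector per graph. A minimal sketch using `global_mean_pool`, assuming `model` is a GNN encoder returning per-node embeddings:

```python
from torch_geometric.nn import global_mean_pool

for batch in loader:
    node_out = model(batch.x, batch.edge_index)          # (total_nodes_in_batch, hidden)
    graph_out = global_mean_pool(node_out, batch.batch)  # (num_graphs_in_batch, hidden)
    # graph_out can now feed a classifier head for per-graph labels
```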
## Part 6: Common Mistakes
### Mistake 1: LSTM Aggregation
**Symptom:** Different outputs for same graph with reordered nodes
**Fix:** Use mean/sum/max aggregation (permutation invariant)
### Mistake 2: Forgetting Edge Direction
**Symptom:** Information flows wrong way
**Fix:** For undirected graphs, add edges in both directions
### Mistake 3: Too Many Layers
**Symptom:** Performance degrades, over-smoothing
**Fix:** Use 2-3 layers (most graphs have small diameter)
**Explanation:** Too many layers → all nodes converge to same representation
### Mistake 4: Not Testing Simple Baseline
**Symptom:** Complex GNN with minimal improvement
**Fix:** Always test MLP on node features first
### Mistake 5: Using GNN on Euclidean Data
**Symptom:** CNN/RNN would work better
**Fix:** Use GNN only for irregular graph structure (not grids/sequences)
## Part 7: Summary
### Quick Reference
**When to use GNN:**
- Graph-structured data (molecules, social networks, citations)
- Irregular neighborhoods (not grid/sequence)
- Graph structure informative (test this!)
**Architecture selection:**
```
Start: GCN (simplest)
Large graph: GraphSAGE (scalable)
Inductive learning: GraphSAGE or GAT
Neighbor importance: GAT (attention)
```
**Key principles:**
- Message passing: Aggregate neighbors + Update node
- Permutation invariance: Use mean/sum/max (not LSTM)
- Test baseline: MLP first, GNN if structure helps
- Layers: 2-3 sufficient (more = over-smoothing)
**Implementation:**
- PyTorch Geometric: Standard library
- COO format: Edge index as 2×E tensor
- Batching: Merge graphs into one large graph
## Next Steps
After mastering this skill:
- `transformer-architecture-deepdive`: Understand attention (used in GAT)
- `architecture-design-principles`: Design principles for graph architectures
- Advanced GNNs: Graph Transformers, Equivariant GNNs
**Remember:** Not all graph data needs GNN. Test if graph structure actually helps! (Compare with MLP baseline)