Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 09:00:00 +08:00
commit 955d5c6743
12 changed files with 6996 additions and 0 deletions

View File

@@ -0,0 +1,625 @@
# Graph Neural Networks Basics
## When to Use This Skill
Use this skill when you need to:
- ✅ Work with graph-structured data (molecules, social networks, citations)
- ✅ Understand why CNN/RNN don't work on graphs
- ✅ Learn message passing framework
- ✅ Choose between GCN, GraphSAGE, GAT
- ✅ Decide if GNN is appropriate (vs simple model)
- ✅ Implement permutation-invariant aggregations
**Do NOT use this skill for:**
- ❌ Sequential data (use RNN/Transformer)
- ❌ Grid data (use CNN)
- ❌ High-level architecture selection (use `using-neural-architectures`)
## Core Principle
**Graphs have irregular structure.** CNN (grid) and RNN (sequence) don't work.
**GNN solution:** Message passing
- Nodes aggregate information from neighbors
- Multiple layers = multi-hop neighborhoods
- Permutation invariant (order doesn't matter)
**Critical question:** Does graph structure actually help? (Test: Compare with/without edges)
## Part 1: Why GNN (Not CNN/RNN)
### Problem: Graph Structure
**Graph components:**
- **Nodes**: Entities (atoms, users, papers)
- **Edges**: Relationships (bonds, friendships, citations)
- **Features**: Node/edge attributes
**Key property:** Irregular structure
- Each node has variable number of neighbors
- No fixed spatial arrangement
- Permutation invariant (node order doesn't matter)
### Why CNN Doesn't Work
**CNN assumption:** Regular grid structure
**Example:** Image (2D grid)
```
Every pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■] ← Center pixel has 8 neighbors (fixed!)
[■][■][■]
CNN kernel: 3×3 (fixed size, fixed positions)
```
**Graph reality:** Irregular neighborhoods
```
Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor (H)
No fixed kernel size or position!
```
**CNN limitations:**
- Requires fixed-size neighborhoods → Graphs have variable-size
- Assumes spatial locality → Graphs have arbitrary connectivity
- Depends on node ordering → Should be permutation invariant
### Why RNN Doesn't Work
**RNN assumption:** Sequential structure
**Example:** Text (1D sequence)
```
"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies
```
**Graph reality:** No inherent sequence
```
Social network:
A — B — C
| |
D ——————E
What's the "sequence"? A→B→C? A→D→E? No natural ordering!
```
**RNN limitations:**
- Requires sequential order → Graphs have no natural order
- Processes one element at a time → Graphs have parallel connections
- Order-dependent → Should be permutation invariant
### GNN Solution
**Key innovation:** Message passing on graph structure
- Operate directly on nodes and edges
- Variable-size neighborhoods (handled naturally)
- Permutation invariant aggregations
## Part 2: Message Passing Framework
### Core Mechanism
**Message passing in 3 steps:**
**1. Aggregate neighbor messages**
```python
# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages) # e.g., mean, sum, max
```
**2. Update node representation**
```python
# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated) # e.g., neural network
```
**3. Repeat for L layers**
- Layer 1: Node sees 1-hop neighbors
- Layer 2: Node sees 2-hop neighbors
- Layer L: Node sees L-hop neighborhood
### Concrete Example: Social Network
**Task:** Predict user interests
**Graph:**
```
B (sports)
|
A ---+--- C (cooking)
|
D (music)
```
**Layer 1: 1-hop neighbors**
```python
# Node A aggregates from direct friends
h_A_layer1 = update(
h_A,
aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests
```
**Layer 2: 2-hop neighbors (friends of friends)**
```python
# B's friends: E, F
# C's friends: G, H
# D's friends: I
h_A_layer2 = update(
h_A_layer1,
aggregate([h_B', h_C', h_D']) # h_B' includes E, F
)
# Now h_A includes friends-of-friends!
```
**Key insight:** More layers = larger receptive field (L-hop neighborhood)
### Permutation Invariance
**Critical property:** Same graph → same output (regardless of node ordering)
**Example:**
```python
Graph: A-B, B-C
Node list 1: [A, B, C]
Node list 2: [C, B, A]
Output MUST be identical! (Same graph, different ordering)
```
**Invariant aggregations:**
- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])`
- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])`
- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])`
**NOT invariant:**
- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])`
- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]`
**Implementation:**
```python
# CORRECT: Permutation invariant
def aggregate(neighbor_features):
return torch.mean(neighbor_features, dim=0)
# WRONG: Order-dependent!
def aggregate(neighbor_features):
return LSTM(neighbor_features) # Output depends on order
```
## Part 3: GNN Architectures
### Architecture 1: GCN (Graph Convolutional Network)
**Key idea:** Spectral convolution on graphs (simplified)
**Formula:**
```python
h_i^(l+1) = σ(_{jN(i)} W^(l) h_j^(l) / (|N(i)| |N(j)|))
# Normalize by degree (√(deg(i) * deg(j)))
```
**Aggregation:** Weighted mean (degree-normalized)
**Properties:**
- Transductive (needs full graph at training)
- Computationally efficient
- Good baseline
**When to use:**
- Full graph available at training time
- Starting point (simplest GNN)
- Small to medium graphs
**Implementation:**
```python
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# x: Node features (N, in_channels)
# edge_index: Graph connectivity (2, E)
# Layer 1
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
# Layer 2
x = self.conv2(x, edge_index)
return x
```
### Architecture 2: GraphSAGE
**Key idea:** Sample and aggregate (inductive learning)
**Formula:**
```python
# Sample fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)
# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})
# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])
```
**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance)
**Key innovation:** Sampling
- Sample fixed number of neighbors (e.g., 10)
- Makes computation tractable for large graphs
- Enables inductive learning (generalizes to unseen nodes)
**When to use:**
- Large graphs (millions of nodes)
- Need inductive capability (new nodes appear)
- Training on subset, testing on full graph
**Implementation:**
```python
from torch_geometric.nn import SAGEConv
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x
```
### Architecture 3: GAT (Graph Attention Network)
**Key idea:** Learn attention weights for neighbors
**Formula:**
```python
# Attention scores
α_ij = attention(h_i, h_j) # How important is neighbor j to node i?
# Normalize (softmax)
α_ij = softmax_j(α_ij)
# Weighted aggregation
h_i^(l+1) = σ(_{jN(i)} α_ij W h_j^(l))
```
**Key innovation:** Learned neighbor importance
- Not all neighbors equally important
- Attention mechanism decides weights
- Multi-head attention (like Transformer)
**When to use:**
- Neighbors have varying importance
- Need interpretability (attention weights)
- Have sufficient data (attention needs more data to learn)
**Implementation:**
```python
from torch_geometric.nn import GATConv
class GAT(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
super().__init__()
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x
```
### Architecture Comparison
| Feature | GCN | GraphSAGE | GAT |
|---------|-----|-----------|-----|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |
### Decision Tree
```
Starting out / Small graph:
→ GCN (simplest baseline)
Large graph (millions of nodes):
→ GraphSAGE (sampling enables scalability)
Need inductive learning (new nodes):
→ GraphSAGE or GAT
Neighbors have different importance:
→ GAT (attention learns importance)
Need interpretability:
→ GAT (attention weights explain predictions)
Production deployment:
→ GraphSAGE (most robust and scalable)
```
## Part 4: When NOT to Use GNN
### Critical Question
**Does graph structure actually help?**
**Test:** Compare model with and without edges
```python
# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)
# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)
# Decision:
if gnn_accuracy - mlp_accuracy < 2%:
print("Graph structure doesn't help much")
print("Use simpler model (MLP or XGBoost)")
else:
print("Graph structure adds value")
print("Use GNN")
```
### Scenarios Where GNN Doesn't Help
**1. Node features dominate**
```
User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth complexity!)
```
**2. Sparse graphs**
```
Graph with 1000 nodes, 100 edges (0.01% density):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to MLP
```
**3. Random graph structure**
```
If edges are random (no homophily):
- Neighbor labels uncorrelated
- Aggregation adds noise
- Simple model better
```
### When GNN DOES Help
**Molecular property prediction**
- Structure is PRIMARY signal
- Atom types + bonds determine properties
- GNN: Huge improvement over fingerprints
**Citation networks**
- Paper quality correlated with neighbors
- "You are what you cite"
- Clear homophily
**Social recommendation**
- Friends have similar preferences
- Graph structure informative
- GNN: Moderate to large improvement
**Knowledge graphs**
- Entities connected by relations
- Multi-hop reasoning valuable
- GNN captures complex patterns
### Decision Framework
```
1. Start simple:
- Try MLP or XGBoost on node features
- Establish baseline performance
2. Check graph structure value:
- Does edge information correlate with target?
- Is there homophily (similar nodes connected)?
- Test: Remove edges, compare performance
3. Use GNN if:
- Graph structure adds >2-5% accuracy
- Structure is interpretable (not random)
- Have enough nodes for GNN to learn
4. Stick with simple if:
- Node features alone sufficient
- Graph structure weak/random
- Small dataset (< 1000 nodes)
```
## Part 5: Practical Implementation
### Using PyTorch Geometric
**Installation:**
```bash
pip install torch-geometric
```
**Basic workflow:**
```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# 1. Create graph data
x = torch.tensor([[feature1], [feature2], ...]) # Node features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) # Edges (COO format)
y = torch.tensor([label1, label2, ...]) # Node labels
data = Data(x=x, edge_index=edge_index, y=y)
# 2. Define model
class GNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(in_features, 64)
self.conv2 = GCNConv(64, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 3. Train
model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(200):
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[train_mask], data.y[train_mask])
loss.backward()
optimizer.step()
```
### Edge Index Format
**COO (Coordinate) format:**
```python
# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
[0, 1, 2], # Source nodes
[1, 2, 0] # Target nodes
])
# For undirected graph, include both directions:
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 0], # Source
[1, 0, 2, 1, 0, 2] # Target
])
```
### Mini-batching Graphs
**Problem:** Graphs have different sizes
**Solution:** Batch graphs as one large disconnected graph
```python
from torch_geometric.data import DataLoader
# Create dataset
dataset = [Data(...), Data(...), ...]
# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
# batch contains multiple graphs as one large graph
# batch.batch: Indicator which nodes belong to which graph
out = model(batch.x, batch.edge_index)
```
## Part 6: Common Mistakes
### Mistake 1: LSTM Aggregation
**Symptom:** Different outputs for same graph with reordered nodes
**Fix:** Use mean/sum/max aggregation (permutation invariant)
### Mistake 2: Forgetting Edge Direction
**Symptom:** Information flows wrong way
**Fix:** For undirected graphs, add edges in both directions
### Mistake 3: Too Many Layers
**Symptom:** Performance degrades, over-smoothing
**Fix:** Use 2-3 layers (most graphs have small diameter)
**Explanation:** Too many layers → all nodes converge to same representation
### Mistake 4: Not Testing Simple Baseline
**Symptom:** Complex GNN with minimal improvement
**Fix:** Always test MLP on node features first
### Mistake 5: Using GNN on Euclidean Data
**Symptom:** CNN/RNN would work better
**Fix:** Use GNN only for irregular graph structure (not grids/sequences)
## Part 7: Summary
### Quick Reference
**When to use GNN:**
- Graph-structured data (molecules, social networks, citations)
- Irregular neighborhoods (not grid/sequence)
- Graph structure informative (test this!)
**Architecture selection:**
```
Start: GCN (simplest)
Large graph: GraphSAGE (scalable)
Inductive learning: GraphSAGE or GAT
Neighbor importance: GAT (attention)
```
**Key principles:**
- Message passing: Aggregate neighbors + Update node
- Permutation invariance: Use mean/sum/max (not LSTM)
- Test baseline: MLP first, GNN if structure helps
- Layers: 2-3 sufficient (more = over-smoothing)
**Implementation:**
- PyTorch Geometric: Standard library
- COO format: Edge index as 2×E tensor
- Batching: Merge graphs into one large graph
## Next Steps
After mastering this skill:
- `transformer-architecture-deepdive`: Understand attention (used in GAT)
- `architecture-design-principles`: Design principles for graph architectures
- Advanced GNNs: Graph Transformers, Equivariant GNNs
**Remember:** Not all graph data needs GNN. Test if graph structure actually helps! (Compare with MLP baseline)