# Graph Neural Networks Basics

## When to Use This Skill

Use this skill when you need to:
- ✅ Work with graph-structured data (molecules, social networks, citations)
- ✅ Understand why CNN/RNN don't work on graphs
- ✅ Learn the message passing framework
- ✅ Choose between GCN, GraphSAGE, GAT
- ✅ Decide if a GNN is appropriate (vs a simple model)
- ✅ Implement permutation-invariant aggregations

**Do NOT use this skill for:**
- ❌ Sequential data (use RNN/Transformer)
- ❌ Grid data (use CNN)
- ❌ High-level architecture selection (use `using-neural-architectures`)

## Core Principle

**Graphs have irregular structure.** CNNs assume a grid and RNNs assume a sequence, so neither works directly on graphs.

**GNN solution:** Message passing
- Nodes aggregate information from their neighbors
- Multiple layers = multi-hop neighborhoods
- Permutation invariant (node order doesn't matter)

**Critical question:** Does graph structure actually help? (Test: compare with/without edges)

## Part 1: Why GNN (Not CNN/RNN)

### Problem: Graph Structure

**Graph components:**
- **Nodes**: Entities (atoms, users, papers)
- **Edges**: Relationships (bonds, friendships, citations)
- **Features**: Node/edge attributes

**Key property:** Irregular structure
- Each node has a variable number of neighbors
- No fixed spatial arrangement
- Permutation invariant (node order doesn't matter)

### Why CNN Doesn't Work

**CNN assumption:** Regular grid structure

**Example:** Image (2D grid)
```
Every interior pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■]  ← Center pixel has 8 neighbors (fixed!)
[■][■][■]

CNN kernel: 3×3 (fixed size, fixed positions)
```

**Graph reality:** Irregular neighborhoods (e.g., atoms in a molecule)
```
Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor  (H)

No fixed kernel size or position!
```

**CNN limitations:**
- Requires fixed-size neighborhoods → graphs have variable-size neighborhoods
- Assumes spatial locality → graphs have arbitrary connectivity
- Depends on node ordering → should be permutation invariant

### Why RNN Doesn't Work

**RNN assumption:** Sequential structure

**Example:** Text (1D sequence)
```
"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies
```

**Graph reality:** No inherent sequence
```
Social network:
A — B — C
|       |
D ————— E

What's the "sequence"? A→B→C? A→D→E? No natural ordering!
```

**RNN limitations:**
- Requires sequential order → graphs have no natural order
- Processes one element at a time → graphs have parallel connections
- Order-dependent → should be permutation invariant

### GNN Solution

**Key innovation:** Message passing on graph structure
- Operate directly on nodes and edges
- Variable-size neighborhoods (handled naturally)
- Permutation-invariant aggregations

## Part 2: Message Passing Framework

### Core Mechanism

**Message passing in 3 steps** (a minimal implementation sketch follows after the steps):

**1. Aggregate neighbor messages**
```python
# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages)  # e.g., mean, sum, max
```

**2. Update node representation**
```python
# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated)  # e.g., neural network
```

**3. Repeat for L layers**
- Layer 1: Node sees 1-hop neighbors
- Layer 2: Node sees 2-hop neighbors
- Layer L: Node sees L-hop neighborhood
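To make the three steps concrete, here is a minimal, self-contained sketch of one message-passing layer in plain PyTorch: mean aggregation over neighbors, then a learned update. The class name `SimpleMessagePassing` and the dense-adjacency input are illustrative choices for this sketch, not a standard API.

```python
import torch
from torch import nn

class SimpleMessagePassing(nn.Module):
    """One message-passing layer: aggregate neighbor features (mean), then update."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Update step: combine a node's own features with the aggregated message
        self.update_fn = nn.Linear(2 * in_dim, out_dim)

    def forward(self, h, adj):
        # h:   (N, in_dim) node features
        # adj: (N, N) adjacency matrix, adj[i, j] = 1 if j is a neighbor of i
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # avoid division by zero for isolated nodes
        aggregated = (adj @ h) / deg                     # mean over neighbors (permutation invariant)
        return torch.relu(self.update_fn(torch.cat([h, aggregated], dim=-1)))

# Tiny example: 3 nodes, edges A-B and B-C
h = torch.randn(3, 4)
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
layer = SimpleMessagePassing(4, 8)
print(layer(h, adj).shape)  # torch.Size([3, 8])
```

Stacking two such layers gives each node a 2-hop receptive field, which is exactly the "more layers = larger neighborhood" behaviour described above.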
### Concrete Example: Social Network

**Task:** Predict user interests

**Graph:**
```
    B (sports)
    |
A --+-- C (cooking)
    |
    D (music)
```

**Layer 1: 1-hop neighbors**
```python
# Node A aggregates from direct friends
h_A_layer1 = update(
    h_A,
    aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests
```

**Layer 2: 2-hop neighbors (friends of friends)**
```python
# B's friends: E, F
# C's friends: G, H
# D's friends: I
h_A_layer2 = update(
    h_A_layer1,
    aggregate([h_B', h_C', h_D'])  # h_B' includes E, F
)
# Now h_A includes friends-of-friends!
```

**Key insight:** More layers = larger receptive field (L-hop neighborhood)

### Permutation Invariance

**Critical property:** Same graph → same output (regardless of node ordering)

**Example:**
```python
Graph: A-B, B-C

Node list 1: [A, B, C]
Node list 2: [C, B, A]

Output MUST be identical! (Same graph, different ordering)
```

**Invariant aggregations:**
- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])`
- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])`
- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])`

**NOT invariant:**
- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])`
- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]`

**Implementation:**
```python
# CORRECT: Permutation invariant
def aggregate(neighbor_features):
    return torch.mean(neighbor_features, dim=0)

# WRONG: Order-dependent!
def aggregate(neighbor_features):
    return LSTM(neighbor_features)  # Output depends on order
```

## Part 3: GNN Architectures

### Architecture 1: GCN (Graph Convolutional Network)

**Key idea:** Spectral convolution on graphs (simplified)

**Formula:**
```python
h_i^(l+1) = σ(∑_{j∈N(i)} W^(l) h_j^(l) / √(|N(i)| |N(j)|))
# Normalize by degree (√(deg(i) * deg(j)))
```

**Aggregation:** Weighted mean (degree-normalized)

**Properties:**
- Transductive (needs the full graph at training time)
- Computationally efficient
- Good baseline

**When to use:**
- Full graph available at training time
- Starting point (simplest GNN)
- Small to medium graphs

**Implementation:**
```python
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: Node features (N, in_channels)
        # edge_index: Graph connectivity (2, E)

        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        # Layer 2
        x = self.conv2(x, edge_index)
        return x
```

### Architecture 2: GraphSAGE

**Key idea:** Sample and aggregate (inductive learning)

**Formula:**
```python
# Sample a fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)

# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})

# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])
```

**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance)

**Key innovation:** Sampling
- Sample a fixed number of neighbors (e.g., 10)
- Makes computation tractable for large graphs
- Enables inductive learning (generalizes to unseen nodes; see the loader sketch after the implementation below)

**When to use:**
- Large graphs (millions of nodes)
- Need inductive capability (new nodes appear)
- Training on a subset, testing on the full graph

**Implementation:**
```python
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
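The class above runs full-batch; in practice the sampling that GraphSAGE is known for comes from the data loader. Below is a mini-batch training sketch assuming a recent PyTorch Geometric version with `NeighborLoader`, and a `data`/`dataset` pair from a standard node-classification benchmark (e.g., PyG's Planetoid); check the argument names against your installed version.

```python
import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader

# data: a torch_geometric.data.Data object with x, edge_index, y, train_mask
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],        # sample 10 neighbors per node, per layer (2 layers)
    batch_size=128,                # 128 seed nodes per mini-batch
    input_nodes=data.train_mask,   # draw seed nodes from the training set only
    shuffle=True,
)

model = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # The first `batch.batch_size` nodes are the seed nodes; the rest are sampled context
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```

Because the subgraph around each seed node has bounded size, memory stays flat even when the full graph has millions of nodes, and the same loader setup works for unseen nodes at inference time (the inductive case).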
### Architecture 3: GAT (Graph Attention Network)

**Key idea:** Learn attention weights for neighbors

**Formula:**
```python
# Attention scores
α_ij = attention(h_i, h_j)  # How important is neighbor j to node i?

# Normalize (softmax over neighbors j)
α_ij = softmax_j(α_ij)

# Weighted aggregation
h_i^(l+1) = σ(∑_{j∈N(i)} α_ij W h_j^(l))
```

**Key innovation:** Learned neighbor importance
- Not all neighbors are equally important
- An attention mechanism decides the weights
- Multi-head attention (like the Transformer)

**When to use:**
- Neighbors have varying importance
- Need interpretability (attention weights)
- Have sufficient data (attention needs more data to learn)

**Implementation:**
```python
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```

### Architecture Comparison

| Feature | GCN | GraphSAGE | GAT |
|---------|-----|-----------|-----|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |

### Decision Tree

```
Starting out / Small graph:
    → GCN (simplest baseline)

Large graph (millions of nodes):
    → GraphSAGE (sampling enables scalability)

Need inductive learning (new nodes):
    → GraphSAGE or GAT

Neighbors have different importance:
    → GAT (attention learns importance)

Need interpretability:
    → GAT (attention weights explain predictions)

Production deployment:
    → GraphSAGE (most robust and scalable)
```

## Part 4: When NOT to Use GNN

### Critical Question

**Does graph structure actually help?**

**Test:** Compare the model with and without edges (a concrete sketch follows below)

```python
# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)

# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)

# Decision:
if gnn_accuracy - mlp_accuracy < 0.02:  # less than 2 percentage points
    print("Graph structure doesn't help much")
    print("Use simpler model (MLP or XGBoost)")
else:
    print("Graph structure adds value")
    print("Use GNN")
```
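A minimal sketch of that edge-ablation test, assuming a PyG `data`/`dataset` pair (e.g., a Planetoid benchmark with `train_mask`/`test_mask`) and the `GCN` class defined above; the `run()` helper is illustrative. Because `GCNConv` adds self-loops by default, training with an empty `edge_index` reduces the model to a per-node feature transform, which is one way to approximate the structure-free baseline.

```python
import torch
import torch.nn.functional as F

def run(edge_index, epochs=200):
    # Train the same 2-layer GCN with the given connectivity and report test accuracy
    model = GCN(dataset.num_features, 64, dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, edge_index)
        F.cross_entropy(out[data.train_mask], data.y[data.train_mask]).backward()
        optimizer.step()
    model.eval()
    pred = model(data.x, edge_index).argmax(dim=1)
    return (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

no_edges = torch.empty((2, 0), dtype=torch.long)  # structure removed (self-loops only)
acc_without = run(no_edges)
acc_with = run(data.edge_index)
print(f"without edges: {acc_without:.3f}, with edges: {acc_with:.3f}")
```

If the two accuracies land within a couple of points of each other, the edges carry little signal and a plain MLP or XGBoost baseline is the better starting point.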
### Scenarios Where GNN Doesn't Help

**1. Node features dominate**
```
User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth the complexity!)
```

**2. Sparse graphs**
```
Graph with 1,000 nodes but only 100 edges (most nodes isolated):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to an MLP
```

**3. Random graph structure**
```
If edges are random (no homophily):
- Neighbor labels are uncorrelated
- Aggregation adds noise
- A simple model does better
```

### When GNN DOES Help

✅ **Molecular property prediction**
- Structure is the PRIMARY signal
- Atom types + bonds determine properties
- GNN: Huge improvement over fingerprints

✅ **Citation networks**
- Paper quality correlated with neighbors
- "You are what you cite"
- Clear homophily

✅ **Social recommendation**
- Friends have similar preferences
- Graph structure informative
- GNN: Moderate to large improvement

✅ **Knowledge graphs**
- Entities connected by relations
- Multi-hop reasoning valuable
- GNN captures complex patterns

### Decision Framework

```
1. Start simple:
   - Try MLP or XGBoost on node features
   - Establish baseline performance

2. Check graph structure value:
   - Does edge information correlate with the target?
   - Is there homophily (similar nodes connected)?
   - Test: Remove edges, compare performance

3. Use GNN if:
   - Graph structure adds >2-5% accuracy
   - Structure is interpretable (not random)
   - Have enough nodes for the GNN to learn

4. Stick with the simple model if:
   - Node features alone are sufficient
   - Graph structure is weak/random
   - Small dataset (< 1000 nodes)
```

## Part 5: Practical Implementation

### Using PyTorch Geometric

**Installation:**
```bash
pip install torch-geometric
```

**Basic workflow:**
```python
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# 1. Create graph data (feature1, label1, ... are placeholders for your values)
x = torch.tensor([[feature1], [feature2], ...])    # Node features (N, num_features)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Edges (COO format)
y = torch.tensor([label1, label2, ...])            # Node labels

data = Data(x=x, edge_index=edge_index, y=y)

# 2. Define model
class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(in_features, 64)   # in_features / num_classes come from your dataset
        self.conv2 = GCNConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 3. Train
model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    # Assumes data.train_mask marks the training nodes
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
```

### Edge Index Format

**COO (Coordinate) format:**
```python
# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
    [0, 1, 2],  # Source nodes
    [1, 2, 0]   # Target nodes
])

# For an undirected graph, include both directions:
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 0],  # Source
    [1, 0, 2, 1, 0, 2]   # Target
])
```

### Mini-batching Graphs

**Problem:** Graphs have different sizes

**Solution:** Batch graphs as one large disconnected graph

```python
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG versions

# Create dataset
dataset = [Data(...), Data(...), ...]

# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch contains multiple graphs merged into one large graph
    # batch.batch: vector telling which graph each node belongs to
    out = model(batch.x, batch.edge_index)
```
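For graph-level tasks (one label per graph rather than per node), the `batch` vector is what lets you pool node embeddings back into one vector per graph. A minimal sketch using PyG's `global_mean_pool`; the `GraphClassifier` class itself is an illustrative example, not a library component.

```python
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node-level message passing
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Pool node embeddings into one embedding per graph using the batch vector
        x = global_mean_pool(x, batch)   # (num_graphs, hidden_channels)
        return self.lin(x)               # (num_graphs, num_classes)

# Usage inside the loop above:
#   out = model(batch.x, batch.edge_index, batch.batch)
#   loss = F.cross_entropy(out, batch.y)
```

Mean pooling is itself permutation invariant, so the whole pipeline (message passing plus readout) respects node ordering end to end.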
## Part 6: Common Mistakes

### Mistake 1: LSTM Aggregation

**Symptom:** Different outputs for the same graph with reordered nodes

**Fix:** Use mean/sum/max aggregation (permutation invariant)

### Mistake 2: Forgetting Edge Direction

**Symptom:** Information flows the wrong way

**Fix:** For undirected graphs, add edges in both directions

### Mistake 3: Too Many Layers

**Symptom:** Performance degrades, over-smoothing

**Fix:** Use 2-3 layers (most graphs have a small diameter)

**Explanation:** Too many layers → all nodes converge to the same representation

### Mistake 4: Not Testing a Simple Baseline

**Symptom:** Complex GNN with minimal improvement

**Fix:** Always test an MLP on node features first

### Mistake 5: Using GNN on Euclidean Data

**Symptom:** CNN/RNN would work better

**Fix:** Use GNN only for irregular graph structure (not grids/sequences)

## Part 7: Summary

### Quick Reference

**When to use GNN:**
- Graph-structured data (molecules, social networks, citations)
- Irregular neighborhoods (not grid/sequence)
- Graph structure is informative (test this!)

**Architecture selection:**
```
Start:               GCN (simplest)
Large graph:         GraphSAGE (scalable)
Inductive learning:  GraphSAGE or GAT
Neighbor importance: GAT (attention)
```

**Key principles:**
- Message passing: Aggregate neighbors + update node
- Permutation invariance: Use mean/sum/max (not LSTM)
- Test baseline: MLP first, GNN only if structure helps
- Layers: 2-3 sufficient (more = over-smoothing)

**Implementation:**
- PyTorch Geometric: Standard library
- COO format: Edge index as 2×E tensor
- Batching: Merge graphs into one large graph

## Next Steps

After mastering this skill:
- `transformer-architecture-deepdive`: Understand attention (used in GAT)
- `architecture-design-principles`: Design principles for graph architectures
- Advanced GNNs: Graph Transformers, Equivariant GNNs

**Remember:** Not all graph data needs a GNN. Test whether graph structure actually helps! (Compare with an MLP baseline)