# PyTorch Geometric Neural Network Layers Reference
This document provides a reference for the neural network layers available in `torch_geometric.nn`.
## Layer Capability Flags
When selecting layers, consider these capability flags:
- **SparseTensor**: Supports `torch_sparse.SparseTensor` format for efficient sparse operations
- **edge_weight**: Handles one-dimensional edge weight data
- **edge_attr**: Processes multi-dimensional edge feature information
- **Bipartite**: Works with bipartite graphs (different source/target node dimensions)
- **Static**: Operates on static graphs with batched node features
- **Lazy**: Enables initialization without specifying input channel dimensions
## Convolutional Layers
### Standard Graph Convolutions
**GCNConv** - Graph Convolutional Network layer
- Implements spectral graph convolution with symmetric normalization
- Supports: SparseTensor, edge_weight, Lazy (symmetric normalization requires a square adjacency, so no bipartite support)
- Use for: Citation networks, social networks, general graph learning
- Example: `GCNConv(in_channels, out_channels, improved=False, cached=True)`
**SAGEConv** - GraphSAGE layer
- Inductive learning via neighborhood sampling and aggregation
- Supports: SparseTensor, Bipartite, Lazy
- Use for: Large graphs, inductive learning, heterogeneous features
- Example: `SAGEConv(in_channels, out_channels, aggr='mean')`
**GATConv** - Graph Attention Network layer
- Multi-head attention mechanism for adaptive neighbor weighting
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Tasks requiring variable neighbor importance
- Example: `GATConv(in_channels, out_channels, heads=8, dropout=0.6)`
**GraphConv** - Simple graph convolution (Morris et al.)
- Basic message passing with optional edge weights
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Baseline models, simple graph structures
- Example: `GraphConv(in_channels, out_channels, aggr='add')`
**GINConv** - Graph Isomorphism Network layer
- As expressive as the Weisfeiler-Lehman (1-WL) graph isomorphism test
- Supports: Bipartite
- Use for: Graph classification, molecular property prediction
- Example: `GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))`
**TransformerConv** - Graph Transformer layer
- Combines graph structure with transformer attention
- Supports: SparseTensor, edge_attr, Bipartite, Lazy
- Use for: Long-range dependencies, complex graphs
- Example: `TransformerConv(in_channels, out_channels, heads=8, beta=True)`
**ChebConv** - Chebyshev spectral graph convolution
- Uses Chebyshev polynomials for efficient spectral filtering
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Spectral graph learning, efficient convolutions
- Example: `ChebConv(in_channels, out_channels, K=3)`
**SGConv** - Simplified Graph Convolution
- Pre-computes fixed number of propagation steps
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Fast training, shallow models
- Example: `SGConv(in_channels, out_channels, K=2)`
**APPNP** - Approximate Personalized Propagation of Neural Predictions
- Separates feature transformation from propagation
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Deep propagation without oversmoothing
- Example: `APPNP(K=10, alpha=0.1)`
**ARMAConv** - ARMA graph convolution
- Uses ARMA filters for graph filtering
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Advanced spectral methods
- Example: `ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)`
**GATv2Conv** - Improved Graph Attention Network
- Fixes static attention computation issue in GAT
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Better attention learning than original GAT
- Example: `GATv2Conv(in_channels, out_channels, heads=8)`
**SuperGATConv** - Self-supervised Graph Attention
- Adds self-supervised attention mechanism
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Self-supervised learning, limited labels
- Example: `SuperGATConv(in_channels, out_channels, heads=8)`
**GMMConv** - Gaussian Mixture Model Convolution
- Uses Gaussian kernels in pseudo-coordinate space
- Supports: Bipartite
- Use for: Point clouds, spatial data
- Example: `GMMConv(in_channels, out_channels, dim=3, kernel_size=5)`
**SplineConv** - Spline-based convolution
- B-spline basis functions for spatial filtering
- Supports: Bipartite
- Use for: Irregular grids, continuous spaces
- Example: `SplineConv(in_channels, out_channels, dim=2, kernel_size=5)`
**NNConv** - Neural Network Convolution
- Edge features processed by neural networks
- Supports: edge_attr, Bipartite
- Use for: Rich edge features, molecular graphs
- Example: `NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')`
**CGConv** - Crystal Graph Convolution
- Designed for crystalline materials
- Supports: Bipartite
- Use for: Materials science, crystal structures
- Example: `CGConv(in_channels, dim=3, batch_norm=True)`
**EdgeConv** - Edge Convolution (Dynamic Graph CNN)
- Dynamically computes edges based on feature space
- Supports: Static
- Use for: Point clouds, dynamic graphs
- Example: `EdgeConv(nn=edge_nn, aggr='max')`
**PointNetConv** - PointNet++ convolution
- Local and global feature learning for point clouds
- Use for: 3D point cloud processing
- Example: `PointNetConv(local_nn, global_nn)`
**ResGatedGraphConv** - Residual Gated Graph Convolution
- Gating mechanism with residual connections
- Supports: edge_attr, Bipartite, Lazy
- Use for: Deep GNNs, complex features
- Example: `ResGatedGraphConv(in_channels, out_channels)`
**GENConv** - Generalized Graph Convolution
- Generalizes multiple GNN variants
- Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy
- Use for: Flexible architecture exploration
- Example: `GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)`
**FiLMConv** - Feature-wise Linear Modulation
- Messages are scaled and shifted conditioned on target-node features
- Supports: Bipartite, Lazy
- Use for: Relational graphs, adaptive message passing
- Example: `FiLMConv(in_channels, out_channels, num_relations=5)`
**PANConv** - Path integral based convolution (PAN)
- Aggregates over multi-hop paths via a maximal entropy transition matrix
- Supports: SparseTensor, Lazy
- Use for: Complex connectivity patterns
- Example: `PANConv(in_channels, out_channels, filter_size=3)`
**ClusterGCNConv** - Cluster-GCN convolution
- Efficient training via graph clustering
- Supports: edge_attr, Lazy
- Use for: Very large graphs
- Example: `ClusterGCNConv(in_channels, out_channels)`
**MFConv** - Molecular Fingerprint Convolution
- Aggregation with degree-specific weight matrices, from neural molecular fingerprints
- Supports: SparseTensor, Lazy
- Use for: Molecular graphs, degree-dependent aggregation
- Example: `MFConv(in_channels, out_channels)`
**RGCNConv** - Relational Graph Convolution
- Handles multiple edge types
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Knowledge graphs, heterogeneous graphs
- Example: `RGCNConv(in_channels, out_channels, num_relations=10)`
**FAConv** - Frequency Adaptive Convolution
- Adaptive filtering in spectral domain
- Supports: SparseTensor, Lazy
- Use for: Spectral graph learning
- Example: `FAConv(in_channels, eps=0.1, dropout=0.5)`
### Molecular and 3D Convolutions
**SchNet** - Continuous-filter convolutional network
- Complete model designed for molecular dynamics
- Use for: Molecular property prediction, 3D molecules
- Example: `SchNet(hidden_channels=128, num_filters=64, num_interactions=6)`
**DimeNet** - Directional Message Passing
- Uses directional information and angles
- Use for: 3D molecular structures, chemical properties
- Example: `DimeNet(hidden_channels=128, out_channels=1, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6)`
**PointTransformerConv** - Point cloud transformer
- Transformer for 3D point clouds
- Use for: 3D vision, point cloud segmentation
- Example: `PointTransformerConv(in_channels, out_channels)`
### Hypergraph Convolutions
**HypergraphConv** - Hypergraph convolution
- Operates on hyperedges (edges connecting multiple nodes)
- Supports: Lazy
- Use for: Multi-way relationships, chemical reactions
- Example: `HypergraphConv(in_channels, out_channels)`
**HGTConv** - Heterogeneous Graph Transformer
- Transformer for heterogeneous graphs with multiple types
- Supports: Lazy
- Use for: Heterogeneous networks, knowledge graphs
- Example: `HGTConv(in_channels, out_channels, metadata, heads=8)`
## Aggregation Operators
**Aggregation** - Base aggregation class
- Flexible aggregation across nodes
**SumAggregation** - Sum aggregation
- Example: `SumAggregation()`
**MeanAggregation** - Mean aggregation
- Example: `MeanAggregation()`
**MaxAggregation** - Max aggregation
- Example: `MaxAggregation()`
**SoftmaxAggregation** - Softmax-weighted aggregation
- Learnable attention weights
- Example: `SoftmaxAggregation(learn=True)`
**PowerMeanAggregation** - Power mean aggregation
- Learnable power parameter
- Example: `PowerMeanAggregation(learn=True)`
**LSTMAggregation** - LSTM-based aggregation
- Sequential processing of neighbors
- Example: `LSTMAggregation(in_channels, out_channels)`
**SetTransformerAggregation** - Set Transformer aggregation
- Transformer for permutation-invariant aggregation
- Example: `SetTransformerAggregation(channels, heads=4)`
**MultiAggregation** - Multiple aggregations
- Combines multiple aggregation methods
- Example: `MultiAggregation(['mean', 'max', 'std'])`
## Pooling Layers
### Global Pooling
**global_mean_pool** - Global mean pooling
- Averages node features per graph
- Example: `global_mean_pool(x, batch)`
**global_max_pool** - Global max pooling
- Max over node features per graph
- Example: `global_max_pool(x, batch)`
**global_add_pool** - Global sum pooling
- Sums node features per graph
- Example: `global_add_pool(x, batch)`
**global_sort_pool** - Global sort pooling
- Sorts and concatenates top-k nodes
- Example: `global_sort_pool(x, batch, k=30)`
**GlobalAttention** - Global attention pooling
- Learnable attention weights for aggregation
- Example: `GlobalAttention(gate_nn)`
**Set2Set** - Set2Set pooling
- LSTM-based attention mechanism
- Example: `Set2Set(in_channels, processing_steps=3)`
### Hierarchical Pooling
**TopKPooling** - Top-k pooling
- Keeps top-k nodes based on projection scores
- Example: `TopKPooling(in_channels, ratio=0.5)`
**SAGPooling** - Self-Attention Graph Pooling
- Uses self-attention for node selection
- Example: `SAGPooling(in_channels, ratio=0.5)`
**ASAPooling** - Adaptive Structure Aware Pooling
- Structure-aware node selection
- Example: `ASAPooling(in_channels, ratio=0.5)`
**PANPooling** - Path integral based pooling
- Scores nodes for top-k selection using PAN's maximal entropy transition matrix
- Example: `PANPooling(in_channels, ratio=0.5)`
**EdgePooling** - Edge contraction pooling
- Pools by contracting edges
- Example: `EdgePooling(in_channels)`
**MemPooling** - Memory-based pooling
- Learnable cluster assignments
- Example: `MemPooling(in_channels, out_channels, heads=4, num_clusters=10)`
**avg_pool** / **max_pool** - Average/Max pool with clustering
- Pools nodes within clusters
- Example: `avg_pool(cluster, data)`
## Normalization Layers
**BatchNorm** - Batch normalization
- Normalizes features across batch
- Example: `BatchNorm(in_channels)`
**LayerNorm** - Layer normalization
- Normalizes features per sample
- Example: `LayerNorm(in_channels)`
**InstanceNorm** - Instance normalization
- Normalizes per sample and graph
- Example: `InstanceNorm(in_channels)`
**GraphNorm** - Graph normalization
- Graph-specific normalization
- Example: `GraphNorm(in_channels)`
**PairNorm** - Pair normalization
- Prevents oversmoothing in deep GNNs
- Example: `PairNorm(scale_individually=False)`
**MessageNorm** - Message normalization
- Normalizes messages during passing
- Example: `MessageNorm(learn_scale=True)`
**DiffGroupNorm** - Differentiable Group Normalization
- Learnable grouping for normalization
- Example: `DiffGroupNorm(in_channels, groups=10)`
## Model Architectures
### Pre-Built Models
**GCN** - Complete Graph Convolutional Network
- Multi-layer GCN with dropout
- Example: `GCN(in_channels, hidden_channels, num_layers, out_channels)`
**GraphSAGE** - Complete GraphSAGE model
- Multi-layer SAGE with dropout
- Example: `GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)`
**GIN** - Complete Graph Isomorphism Network
- Multi-layer GIN for graph classification
- Example: `GIN(in_channels, hidden_channels, num_layers, out_channels)`
**GAT** - Complete Graph Attention Network
- Multi-layer GAT with attention
- Example: `GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)`
**PNA** - Principal Neighbourhood Aggregation
- Combines multiple aggregators and scalers
- Example: `PNA(in_channels, hidden_channels, num_layers, out_channels)`
**EdgeCNN** - Edge Convolution CNN
- Dynamic graph CNN for point clouds
- Example: `EdgeCNN(in_channels, hidden_channels, num_layers=3, out_channels=out_channels)`
### Auto-Encoders
**GAE** - Graph Auto-Encoder
- Encodes graphs into latent space
- Example: `GAE(encoder)`
**VGAE** - Variational Graph Auto-Encoder
- Probabilistic graph encoding
- Example: `VGAE(encoder)`
**ARGA** - Adversarially Regularized Graph Auto-Encoder
- GAE with adversarial regularization
- Example: `ARGA(encoder, discriminator)`
**ARGVA** - Adversarially Regularized Variational Graph Auto-Encoder
- VGAE with adversarial regularization
- Example: `ARGVA(encoder, discriminator)`
### Knowledge Graph Embeddings
**TransE** - Translating embeddings
- Learns entity and relation embeddings
- Example: `TransE(num_nodes, num_relations, hidden_channels)`
**RotatE** - Rotational embeddings
- Embeddings in complex space
- Example: `RotatE(num_nodes, num_relations, hidden_channels)`
**ComplEx** - Complex embeddings
- Complex-valued embeddings
- Example: `ComplEx(num_nodes, num_relations, hidden_channels)`
**DistMult** - Bilinear diagonal model
- Simplified bilinear model
- Example: `DistMult(num_nodes, num_relations, hidden_channels)`
## Utility Layers
**Sequential** - Sequential container
- Chains multiple layers
- Example: `Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])`
**JumpingKnowledge** - Jumping knowledge connections
- Combines representations from all layers
- Modes: 'cat', 'max', 'lstm'
- Example: `JumpingKnowledge(mode='cat')`
**DeepGCNLayer** - Deep GCN layer wrapper
- Enables very deep GNNs with skip connections
- Example: `DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)`
**MLP** - Multi-layer perceptron
- Standard feedforward network
- Example: `MLP([in_channels, 64, 64, out_channels], dropout=0.5)`
**Linear** - Lazy linear layer
- Linear transformation with lazy initialization
- Example: `Linear(in_channels, out_channels, bias=True)`
## Dense Layers
For dense (non-sparse) graph representations:
**DenseGCNConv** - Dense GCN layer
**DenseSAGEConv** - Dense SAGE layer
**DenseGINConv** - Dense GIN layer
**DenseGraphConv** - Dense graph convolution
These are useful when working with small, fully-connected, or densely represented graphs.
## Usage Tips
1. **Start simple**: Begin with GCNConv or GATConv for most tasks
2. **Consider data type**: Use molecular layers (SchNet, DimeNet) for 3D structures
3. **Check capabilities**: Match layer capabilities to your data (edge features, bipartite, etc.)
4. **Deep networks**: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs
5. **Large graphs**: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling
6. **Heterogeneous**: Use RGCNConv, HGTConv, or to_hetero() conversion (see the sketch after this list)
7. **Lazy initialization**: Use lazy layers when input dimensions vary or are unknown
## Common Patterns
### Basic GNN
```python
import torch
from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        # Aggregate node embeddings into one vector per graph.
        return global_mean_pool(x, batch)
```
### Deep GNN with Normalization
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import (GCNConv, JumpingKnowledge, LayerNorm,
                                global_mean_pool)

class DeepGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(LayerNorm(hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(LayerNorm(hidden_channels))
        self.jk = JumpingKnowledge(mode='cat')
        # 'cat' concatenates every layer's output, so the final projection
        # maps num_layers * hidden_channels down to out_channels.
        self.lin = torch.nn.Linear(num_layers * hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, norm in zip(self.convs, self.norms):
            x = F.relu(norm(conv(x, edge_index)))
            xs.append(x)
        x = self.lin(self.jk(xs))
        return global_mean_pool(x, batch)
```