# PyTorch Geometric Neural Network Layers Reference

This document provides a comprehensive reference of the neural network layers available in `torch_geometric.nn`.

## Layer Capability Flags

When selecting layers, consider these capability flags:

- **SparseTensor**: Supports `torch_sparse.SparseTensor` format for efficient sparse operations
- **edge_weight**: Handles one-dimensional edge weight data
- **edge_attr**: Processes multi-dimensional edge feature information
- **Bipartite**: Works with bipartite graphs (different source/target node dimensions)
- **Static**: Operates on static graphs with batched node features
- **Lazy**: Enables initialization without specifying input channel dimensions (see the sketch below)

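For example, lazy initialization defers the weight shapes to the first forward pass when you pass `-1` as the input dimension; a minimal sketch with illustrative shapes:

```python
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(-1, 64)  # -1 = lazy: in_channels inferred on first call

x = torch.randn(10, 32)                      # 10 nodes, 32 features
edge_index = torch.tensor([[0, 1], [1, 0]])  # two directed edges
out = conv(x, edge_index)                    # weights materialize as 32 -> 64
```
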
## Convolutional Layers

### Standard Graph Convolutions

**GCNConv** - Graph Convolutional Network layer
- Implements spectral graph convolution with symmetric normalization
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Citation networks, social networks, general graph learning
- Example: `GCNConv(in_channels, out_channels, improved=False, cached=True)` (set `cached=True` only in transductive settings)

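A minimal forward pass with optional edge weights (all shapes and data here are illustrative):

```python
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32)
x = torch.randn(100, 16)                      # node features
edge_index = torch.randint(0, 100, (2, 500))  # random connectivity
edge_weight = torch.rand(500)                 # one scalar weight per edge
out = conv(x, edge_index, edge_weight)        # -> [100, 32]
```
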
**SAGEConv** - GraphSAGE layer
- Inductive learning via neighborhood sampling and aggregation
- Supports: SparseTensor, Bipartite, Lazy
- Use for: Large graphs, inductive learning, heterogeneous features
- Example: `SAGEConv(in_channels, out_channels, aggr='mean')`

**GATConv** - Graph Attention Network layer
- Multi-head attention mechanism for adaptive neighbor weighting
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Tasks requiring variable neighbor importance
- Note: with the default `concat=True`, the output dimension is `heads * out_channels` (see the sketch below)
- Example: `GATConv(in_channels, out_channels, heads=8, dropout=0.6)`

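A minimal sketch showing how the heads affect the output shape (shapes are illustrative):

```python
import torch
from torch_geometric.nn import GATConv

conv = GATConv(16, 8, heads=8, dropout=0.6)  # concat=True by default
x = torch.randn(50, 16)
edge_index = torch.randint(0, 50, (2, 200))
out = conv(x, edge_index)                    # -> [50, 8 * 8] = [50, 64]
```
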
**GraphConv** - Simple graph convolution (Morris et al.)
- Basic message passing with optional edge weights
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Baseline models, simple graph structures
- Example: `GraphConv(in_channels, out_channels, aggr='add')`

**GINConv** - Graph Isomorphism Network layer
- Maximally powerful GNN for graph isomorphism testing
- Supports: Bipartite
- Use for: Graph classification, molecular property prediction
- Example: `GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))`

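GINConv wraps a user-supplied MLP; a minimal sketch of the common two-layer choice (`train_eps` makes the self-weighting term learnable):

```python
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv

mlp = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))
conv = GINConv(mlp, train_eps=True)

x = torch.randn(30, 16)
edge_index = torch.randint(0, 30, (2, 100))
out = conv(x, edge_index)  # -> [30, 32]
```
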
**TransformerConv** - Graph Transformer layer
- Combines graph structure with transformer attention
- Supports: SparseTensor, Bipartite, Lazy
- Use for: Long-range dependencies, complex graphs
- Example: `TransformerConv(in_channels, out_channels, heads=8, beta=True)`

**ChebConv** - Chebyshev spectral graph convolution
- Uses Chebyshev polynomials for efficient spectral filtering
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Spectral graph learning, efficient convolutions
- Example: `ChebConv(in_channels, out_channels, K=3)`

**SGConv** - Simplified Graph Convolution
- Pre-computes a fixed number of propagation steps
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Fast training, shallow models
- Example: `SGConv(in_channels, out_channels, K=2)`

**APPNP** - Approximate Personalized Propagation of Neural Predictions
- Separates feature transformation from propagation
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Deep propagation without oversmoothing
- Example: `APPNP(K=10, alpha=0.1)`

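Because APPNP propagation itself is parameter-free, the usual pattern is to predict first and propagate afterwards; a minimal sketch:

```python
import torch
from torch_geometric.nn import APPNP

lin = torch.nn.Linear(16, 7)   # feature transformation / prediction
prop = APPNP(K=10, alpha=0.1)  # parameter-free propagation

x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 400))
out = prop(lin(x), edge_index)  # propagate predictions -> [100, 7]
```
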
**ARMAConv** - ARMA graph convolution
- Approximates graph filters with auto-regressive moving average (ARMA) blocks
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Advanced spectral methods
- Example: `ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)`

**GATv2Conv** - Improved Graph Attention Network
- Fixes the static attention computation issue in GAT
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Better attention learning than original GAT
- Example: `GATv2Conv(in_channels, out_channels, heads=8)`

**SuperGATConv** - Self-supervised Graph Attention
- Adds a self-supervised attention mechanism
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Self-supervised learning, limited labels
- Example: `SuperGATConv(in_channels, out_channels, heads=8)`

**GMMConv** - Gaussian Mixture Model Convolution
- Uses Gaussian kernels in pseudo-coordinate space
- Supports: Bipartite
- Use for: Point clouds, spatial data
- Example: `GMMConv(in_channels, out_channels, dim=3, kernel_size=5)`

**SplineConv** - Spline-based convolution
- B-spline basis functions for spatial filtering
- Supports: Bipartite
- Use for: Irregular grids, continuous spaces
- Example: `SplineConv(in_channels, out_channels, dim=2, kernel_size=5)`

**NNConv** - Neural Network Convolution
- Edge features processed by neural networks
- Supports: edge_attr, Bipartite
- Use for: Rich edge features, molecular graphs
- Example: `NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')`

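The `nn` argument maps each edge feature vector to a flattened `in_channels x out_channels` kernel; a minimal sketch (the edge network architecture is illustrative):

```python
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import NNConv

in_channels, out_channels, edge_dim = 16, 32, 4
# Produces one [in_channels * out_channels] kernel per edge.
edge_nn = Sequential(Linear(edge_dim, 64), ReLU(),
                     Linear(64, in_channels * out_channels))
conv = NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')

x = torch.randn(20, in_channels)
edge_index = torch.randint(0, 20, (2, 60))
edge_attr = torch.randn(60, edge_dim)
out = conv(x, edge_index, edge_attr)  # -> [20, 32]
```
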
**CGConv** - Crystal Graph Convolution
- Designed for crystalline materials
- Supports: Bipartite
- Use for: Materials science, crystal structures
- Example: `CGConv(in_channels, dim=3, batch_norm=True)`

**EdgeConv** - Edge Convolution (Dynamic Graph CNN)
- Applies an MLP to `[x_i, x_j - x_i]` per edge; `DynamicEdgeConv` additionally rebuilds the graph via k-NN in feature space
- Supports: Static
- Use for: Point clouds, dynamic graphs
- Example: `EdgeConv(nn=edge_nn, aggr='max')` (the MLP maps `2 * in_channels` to `out_channels`)

**PointNetConv** - PointNet++ convolution
- Local and global feature learning for point clouds
- Use for: 3D point cloud processing
- Example: `PointNetConv(local_nn, global_nn)`

**ResGatedGraphConv** - Residual Gated Graph Convolution
- Gating mechanism with residual connections
- Supports: edge_attr, Bipartite, Lazy
- Use for: Deep GNNs, complex features
- Example: `ResGatedGraphConv(in_channels, out_channels)`

**GENConv** - Generalized Graph Convolution
- Generalizes multiple GNN variants
- Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy
- Use for: Flexible architecture exploration
- Example: `GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)`

**FiLMConv** - Feature-wise Linear Modulation
- Modulates per-relation messages with scale and shift computed from target-node features (GNN-FiLM)
- Supports: Bipartite, Lazy
- Use for: Multi-relational graphs, adaptive message transformations
- Example: `FiLMConv(in_channels, out_channels, num_relations=5)`

**PANConv** - Path Integral Based Convolution
- Aggregates over multi-hop paths via a maximal entropy transition (MET) matrix
- Supports: SparseTensor, Lazy
- Use for: Complex connectivity patterns
- Example: `PANConv(in_channels, out_channels, filter_size=3)`

**ClusterGCNConv** - Cluster-GCN convolution
- Efficient training via graph clustering
- Supports: edge_attr, Lazy
- Use for: Very large graphs
- Example: `ClusterGCNConv(in_channels, out_channels)`

**MFConv** - Molecular Fingerprint Convolution
- Degree-specific weight matrices, from the neural molecular fingerprints paper
- Supports: SparseTensor, Lazy
- Use for: Molecular graphs, fingerprint-style representations
- Example: `MFConv(in_channels, out_channels)`

**RGCNConv** - Relational Graph Convolution
- Handles multiple edge types via per-relation weights (pass `edge_type` in `forward()`)
- Supports: SparseTensor, Lazy
- Use for: Knowledge graphs, heterogeneous graphs
- Example: `RGCNConv(in_channels, out_channels, num_relations=10)`

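A minimal sketch passing one relation id per edge:

```python
import torch
from torch_geometric.nn import RGCNConv

conv = RGCNConv(16, 32, num_relations=10)
x = torch.randn(40, 16)
edge_index = torch.randint(0, 40, (2, 120))
edge_type = torch.randint(0, 10, (120,))  # relation id of each edge
out = conv(x, edge_index, edge_type)      # -> [40, 32]
```
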
**FAConv** - Frequency Adaptive Graph Convolution
- Adaptively mixes low- and high-frequency signals in the spectral domain
- Supports: SparseTensor, Lazy
- Use for: Spectral graph learning
- Note: `forward()` also expects the initial node features `x_0`
- Example: `FAConv(channels, eps=0.1, dropout=0.5)`

### Molecular and 3D Convolutions

**SchNet** - Continuous-filter convolutional layer
- Designed for molecular dynamics
- Use for: Molecular property prediction, 3D molecules
- Example: `SchNet(hidden_channels=128, num_filters=64, num_interactions=6)`

**DimeNet** - Directional Message Passing
- Uses directional information and angles
- Use for: 3D molecular structures, chemical properties
- Example: `DimeNet(hidden_channels=128, out_channels=1, num_blocks=6)`

**PointTransformerConv** - Point cloud transformer
- Transformer-style attention for 3D point clouds
- Use for: 3D vision, point cloud segmentation
- Example: `PointTransformerConv(in_channels, out_channels)`

### Hypergraph Convolutions

**HypergraphConv** - Hypergraph convolution
- Operates on hyperedges (edges connecting multiple nodes)
- Supports: Lazy
- Use for: Multi-way relationships, chemical reactions
- Example: `HypergraphConv(in_channels, out_channels)`

**HGTConv** - Heterogeneous Graph Transformer
- Transformer for heterogeneous graphs with multiple node and edge types
- Supports: Lazy
- Use for: Heterogeneous networks, knowledge graphs
- Example: `HGTConv(in_channels, out_channels, metadata, heads=8)`

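HGTConv consumes per-type dictionaries; `metadata` is the `(node_types, edge_types)` pair, normally obtained from a `HeteroData` object via `data.metadata()`. A minimal sketch with hypothetical 'paper'/'author' types:

```python
import torch
from torch_geometric.nn import HGTConv

metadata = (['paper', 'author'],
            [('author', 'writes', 'paper'), ('paper', 'rev_writes', 'author')])
conv = HGTConv(in_channels=-1, out_channels=32, metadata=metadata, heads=4)

x_dict = {'paper': torch.randn(10, 16), 'author': torch.randn(5, 8)}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2], [0, 0, 3]]),
    ('paper', 'rev_writes', 'author'): torch.tensor([[0, 0, 3], [0, 1, 2]]),
}
out_dict = conv(x_dict, edge_index_dict)  # one [num_nodes, 32] tensor per node type
```
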
## Aggregation Operators

**Aggregation** - Base aggregation class
- Flexible aggregation across nodes

**SumAggregation** - Sum aggregation
- Example: `SumAggregation()`

**MeanAggregation** - Mean aggregation
- Example: `MeanAggregation()`

**MaxAggregation** - Max aggregation
- Example: `MaxAggregation()`

**SoftmaxAggregation** - Softmax-weighted aggregation
- Softmax weighting with an optionally learnable temperature
- Example: `SoftmaxAggregation(learn=True)`

**PowerMeanAggregation** - Power mean aggregation
- Learnable power parameter
- Example: `PowerMeanAggregation(learn=True)`

**LSTMAggregation** - LSTM-based aggregation
- Sequential processing of neighbors (not permutation invariant)
- Example: `LSTMAggregation(in_channels, out_channels)`

**SetTransformerAggregation** - Set Transformer aggregation
- Transformer for permutation-invariant aggregation
- Example: `SetTransformerAggregation(channels, heads=4)`

**MultiAggregation** - Multiple aggregations
- Combines multiple aggregation methods, as in the sketch below
- Example: `MultiAggregation(['mean', 'max', 'std'])`

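Many convolution layers accept these operators (or a list of them) directly through their `aggr` argument; a minimal sketch:

```python
import torch
from torch_geometric.nn import SAGEConv

# A list of aggregations is wrapped into a MultiAggregation internally.
conv = SAGEConv(16, 32, aggr=['mean', 'max', 'std'])
x = torch.randn(25, 16)
edge_index = torch.randint(0, 25, (2, 80))
out = conv(x, edge_index)  # -> [25, 32]
```
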
## Pooling Layers

### Global Pooling

**global_mean_pool** - Global mean pooling
- Averages node features per graph
- Example: `global_mean_pool(x, batch)`

**global_max_pool** - Global max pooling
- Takes the maximum over node features per graph
- Example: `global_max_pool(x, batch)`

**global_add_pool** - Global sum pooling
- Sums node features per graph
- Example: `global_add_pool(x, batch)`

**global_sort_pool** - Global sort pooling
- Sorts and concatenates the top-k nodes
- Example: `global_sort_pool(x, batch, k=30)`

**GlobalAttention** - Global attention pooling
- Learnable attention weights for aggregation (also exposed as `AttentionalAggregation` in newer releases)
- Example: `GlobalAttention(gate_nn)`

**Set2Set** - Set2Set pooling
- LSTM-based attention mechanism
- Example: `Set2Set(in_channels, processing_steps=3)`

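All global pooling functions reduce node embeddings to one vector per graph using the `batch` assignment vector; a minimal sketch:

```python
import torch
from torch_geometric.nn import global_mean_pool

x = torch.randn(7, 16)                       # 7 nodes across 2 graphs
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # graph id of each node
out = global_mean_pool(x, batch)             # -> [2, 16]
```
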
### Hierarchical Pooling

**TopKPooling** - Top-k pooling
- Keeps top-k nodes based on projection scores, as sketched below
- Example: `TopKPooling(in_channels, ratio=0.5)`

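Hierarchical pooling layers return a reduced graph rather than a single vector; a minimal TopKPooling sketch:

```python
import torch
from torch_geometric.nn import TopKPooling

pool = TopKPooling(16, ratio=0.5)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 30))
x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index)
# x now holds roughly half the nodes; edge_index is re-indexed to match
```
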
**SAGPooling** - Self-Attention Graph Pooling
- Uses self-attention for node selection
- Example: `SAGPooling(in_channels, ratio=0.5)`

**ASAPooling** - Adaptive Structure Aware Pooling
- Structure-aware node selection
- Example: `ASAPooling(in_channels, ratio=0.5)`

**PANPooling** - Path Integral Based Pooling
- Pools using scores derived from the PANConv MET matrix
- Example: `PANPooling(in_channels, ratio=0.5)`

**EdgePooling** - Edge contraction pooling
- Pools by contracting edges
- Example: `EdgePooling(in_channels)`

**MemPooling** - Memory-based pooling
- Learnable cluster assignments
- Example: `MemPooling(in_channels, out_channels, heads=4, num_clusters=10)`

**avg_pool** / **max_pool** - Average/Max pool with clustering
- Pools nodes within pre-computed clusters
- Example: `avg_pool(cluster, data)`

## Normalization Layers

**BatchNorm** - Batch normalization
- Normalizes features across the batch
- Example: `BatchNorm(in_channels)`

**LayerNorm** - Layer normalization
- Normalizes features per sample
- Example: `LayerNorm(in_channels)`

**InstanceNorm** - Instance normalization
- Normalizes features per graph
- Example: `InstanceNorm(in_channels)`

**GraphNorm** - Graph normalization
- Per-graph normalization with a learnable mean scale
- Example: `GraphNorm(in_channels)`

**PairNorm** - Pair normalization
- Prevents oversmoothing in deep GNNs
- Example: `PairNorm(scale_individually=False)`

**MessageNorm** - Message normalization
- Normalizes messages during message passing
- Example: `MessageNorm(learn_scale=True)`

**DiffGroupNorm** - Differentiable Group Normalization
- Learnable grouping for normalization
- Example: `DiffGroupNorm(in_channels, groups=10)`

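Graph-aware norms accept the `batch` vector so statistics are computed per graph; a minimal sketch:

```python
import torch
from torch_geometric.nn import GraphNorm

norm = GraphNorm(32)
x = torch.randn(7, 32)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # normalize each graph separately
out = norm(x, batch)                         # -> [7, 32]
```
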
## Model Architectures

### Pre-Built Models

**GCN** - Complete Graph Convolutional Network
- Multi-layer GCN with dropout
- Example: `GCN(in_channels, hidden_channels, num_layers, out_channels)`

**GraphSAGE** - Complete GraphSAGE model
- Multi-layer SAGE with dropout
- Example: `GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)`

**GIN** - Complete Graph Isomorphism Network
- Multi-layer GIN for graph classification
- Example: `GIN(in_channels, hidden_channels, num_layers, out_channels)`

**GAT** - Complete Graph Attention Network
- Multi-layer GAT with attention
- Example: `GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)`

**PNA** - Principal Neighbourhood Aggregation
- Combines multiple aggregators and degree-based scalers
- Example: `PNA(in_channels, hidden_channels, num_layers, out_channels, aggregators=['mean', 'max'], scalers=['identity'], deg=deg)` (requires a degree histogram `deg`)

**EdgeCNN** - Edge Convolution CNN
- Dynamic graph CNN for point clouds
- Example: `EdgeCNN(in_channels, hidden_channels, num_layers, out_channels)`

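The pre-built models bundle the conv stack, activations, and dropout behind one constructor; a minimal sketch:

```python
import torch
from torch_geometric.nn import GCN

model = GCN(in_channels=16, hidden_channels=64, num_layers=3,
            out_channels=7, dropout=0.5)
x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 400))
out = model(x, edge_index)  # -> [100, 7]
```
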
### Auto-Encoders

**GAE** - Graph Auto-Encoder
- Encodes graphs into latent space
- Example: `GAE(encoder)`

**VGAE** - Variational Graph Auto-Encoder
- Probabilistic graph encoding
- Example: `VGAE(encoder)`

**ARGA** - Adversarially Regularized Graph Auto-Encoder
- GAE with adversarial regularization
- Example: `ARGA(encoder, discriminator)`

**ARGVA** - Adversarially Regularized Variational Graph Auto-Encoder
- VGAE with adversarial regularization
- Example: `ARGVA(encoder, discriminator)`

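The auto-encoders wrap a user-supplied encoder and provide `encode()` plus reconstruction losses; a minimal GAE sketch (the encoder is illustrative):

```python
import torch
from torch_geometric.nn import GAE, GCNConv

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(16, 32)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)

model = GAE(Encoder())
x = torch.randn(50, 16)
edge_index = torch.randint(0, 50, (2, 150))
z = model.encode(x, edge_index)         # latent node embeddings
loss = model.recon_loss(z, edge_index)  # reconstruction loss on positive edges
```
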
### Knowledge Graph Embeddings

**TransE** - Translating embeddings
- Learns entity and relation embeddings where `head + relation ≈ tail`
- Example: `TransE(num_nodes, num_relations, hidden_channels)`

**RotatE** - Rotational embeddings
- Models relations as rotations in complex space
- Example: `RotatE(num_nodes, num_relations, hidden_channels)`

**ComplEx** - Complex embeddings
- Complex-valued bilinear scoring
- Example: `ComplEx(num_nodes, num_relations, hidden_channels)`

**DistMult** - Bilinear diagonal model
- Simplified bilinear model with diagonal relation matrices
- Example: `DistMult(num_nodes, num_relations, hidden_channels)`

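These models score `(head, relation, tail)` triples and expose a built-in ranking loss; a minimal sketch with random triples:

```python
import torch
from torch_geometric.nn import TransE

model = TransE(num_nodes=100, num_relations=10, hidden_channels=50)
head = torch.randint(0, 100, (32,))
rel = torch.randint(0, 10, (32,))
tail = torch.randint(0, 100, (32,))

score = model(head, rel, tail)      # plausibility score per triple
loss = model.loss(head, rel, tail)  # ranking loss with sampled negatives
```
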
## Utility Layers

**Sequential** - Sequential container
- Chains multiple layers with explicit data-flow signatures
- Example: `Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])`

**JumpingKnowledge** - Jumping knowledge connections
- Combines representations from all layers
- Modes: 'cat', 'max', 'lstm'
- Example: `JumpingKnowledge(mode='cat')`

**DeepGCNLayer** - Deep GCN layer wrapper
- Enables very deep GNNs with skip connections
- Example: `DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)`

**MLP** - Multi-layer perceptron
- Standard feedforward network
- Example: `MLP([in_channels, 64, 64, out_channels], dropout=0.5)`

**Linear** - Linear layer with lazy support
- Linear transformation; pass `in_channels=-1` for lazy initialization
- Example: `Linear(in_channels, out_channels, bias=True)`

## Dense Layers

For dense (non-sparse) graph representations:

**DenseGCNConv** - Dense GCN layer
**DenseSAGEConv** - Dense SAGE layer
**DenseGINConv** - Dense GIN layer
**DenseGraphConv** - Dense graph convolution

These are useful when working with small, fully connected, or densely represented graphs.

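Dense variants operate on `[batch_size, num_nodes, num_nodes]` adjacency matrices instead of `edge_index`; a minimal sketch:

```python
import torch
from torch_geometric.nn import DenseGCNConv

conv = DenseGCNConv(16, 32)
x = torch.randn(8, 20, 16)           # batch of 8 graphs, 20 nodes each
adj = torch.rand(8, 20, 20).round()  # dense 0/1 adjacency matrices
out = conv(x, adj)                   # -> [8, 20, 32]
```
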
## Usage Tips

1. **Start simple**: Begin with GCNConv or GATConv for most tasks
2. **Consider data type**: Use molecular layers (SchNet, DimeNet) for 3D structures
3. **Check capabilities**: Match layer capabilities to your data (edge features, bipartite, etc.)
4. **Deep networks**: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs
5. **Large graphs**: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling
6. **Heterogeneous**: Use RGCNConv, HGTConv, or `to_hetero()` conversion (see the sketch after this list)
7. **Lazy initialization**: Use lazy layers when input dimensions vary or are unknown

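For tip 6, `to_hetero()` clones a homogeneous model for every node and edge type; a minimal sketch (the two node types and relations are hypothetical, and `metadata` would normally come from `data.metadata()`):

```python
import torch
from torch_geometric.nn import SAGEConv, to_hetero

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = SAGEConv((-1, -1), 32)  # lazy bipartite input sizes

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)

metadata = (['paper', 'author'],
            [('author', 'writes', 'paper'), ('paper', 'rev_writes', 'author')])
model = to_hetero(Net(), metadata)
# model(x_dict, edge_index_dict) now returns one output tensor per node type
```
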
## Common Patterns

### Basic GNN
```python
import torch
from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)
```

### Deep GNN with Normalization
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, LayerNorm, JumpingKnowledge, global_mean_pool

class DeepGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(LayerNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(LayerNorm(hidden_channels))

        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.jk = JumpingKnowledge(mode='cat')

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            xs.append(x)

        x = self.convs[-1](x, edge_index)
        xs.append(x)

        x = self.jk(xs)  # concatenate representations from every layer
        return global_mean_pool(x, batch)
```