PyTorch Geometric Neural Network Layers Reference

This document provides a comprehensive reference of all neural network layers available in torch_geometric.nn.

Layer Capability Flags

When selecting layers, consider these capability flags:

  • SparseTensor: Supports torch_sparse.SparseTensor format for efficient sparse operations
  • edge_weight: Handles one-dimensional edge weight data
  • edge_attr: Processes multi-dimensional edge feature information
  • Bipartite: Works with bipartite graphs (different source/target node dimensions)
  • Static: Operates on static graphs with batched node features
  • Lazy: Enables initialization without specifying input channel dimensions
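
For example, the Lazy flag means a layer such as GCNConv can be constructed with in_channels=-1, and the input dimension is inferred on the first forward pass. A minimal sketch (tensor shapes are illustrative):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(-1, 64)                              # in_channels inferred lazily
x = torch.randn(10, 32)                             # 10 nodes with 32 features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])   # [2, num_edges]
out = conv(x, edge_index)                           # weights materialized here; out has shape [10, 64]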

Convolutional Layers

Standard Graph Convolutions

GCNConv - Graph Convolutional Network layer

  • Implements spectral graph convolution with symmetric normalization
  • Supports: SparseTensor, edge_weight, Lazy (not bipartite, since it requires a square normalized adjacency)
  • Use for: Citation networks, social networks, general graph learning
  • Example: GCNConv(in_channels, out_channels, improved=False, cached=True)

SAGEConv - GraphSAGE layer

  • Inductive learning via neighborhood sampling and aggregation
  • Supports: SparseTensor, Bipartite, Lazy
  • Use for: Large graphs, inductive learning, heterogeneous features
  • Example: SAGEConv(in_channels, out_channels, aggr='mean')

GATConv - Graph Attention Network layer

  • Multi-head attention mechanism for adaptive neighbor weighting
  • Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
  • Use for: Tasks requiring variable neighbor importance
  • Example: GATConv(in_channels, out_channels, heads=8, dropout=0.6)

GraphConv - Simple graph convolution (Morris et al.)

  • Basic message passing with optional edge weights
  • Supports: SparseTensor, edge_weight, Bipartite, Lazy
  • Use for: Baseline models, simple graph structures
  • Example: GraphConv(in_channels, out_channels, aggr='add')

GINConv - Graph Isomorphism Network layer

  • Message passing provably as expressive as the 1-WL graph isomorphism test
  • Supports: Bipartite
  • Use for: Graph classification, molecular property prediction
  • Example: GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))

TransformerConv - Graph Transformer layer

  • Combines graph structure with transformer attention
  • Supports: SparseTensor, Bipartite, Lazy
  • Use for: Long-range dependencies, complex graphs
  • Example: TransformerConv(in_channels, out_channels, heads=8, beta=True)

ChebConv - Chebyshev spectral graph convolution

  • Uses Chebyshev polynomials for efficient spectral filtering
  • Supports: SparseTensor, edge_weight, Bipartite, Lazy
  • Use for: Spectral graph learning, efficient convolutions
  • Example: ChebConv(in_channels, out_channels, K=3)

SGConv - Simplified Graph Convolution

  • Pre-computes fixed number of propagation steps
  • Supports: SparseTensor, edge_weight, Bipartite, Lazy
  • Use for: Fast training, shallow models
  • Example: SGConv(in_channels, out_channels, K=2)

APPNP - Approximate Personalized Propagation of Neural Predictions

  • Separates feature transformation from propagation
  • Supports: SparseTensor, edge_weight, Lazy
  • Use for: Deep propagation without oversmoothing
  • Example: APPNP(K=10, alpha=0.1)

ARMAConv - ARMA graph convolution

  • Uses ARMA filters for graph filtering
  • Supports: SparseTensor, edge_weight, Bipartite, Lazy
  • Use for: Advanced spectral methods
  • Example: ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)

GATv2Conv - Improved Graph Attention Network

  • Fixes static attention computation issue in GAT
  • Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
  • Use for: Better attention learning than original GAT
  • Example: GATv2Conv(in_channels, out_channels, heads=8)

SuperGATConv - Self-supervised Graph Attention

  • Adds self-supervised attention mechanism
  • Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
  • Use for: Self-supervised learning, limited labels
  • Example: SuperGATConv(in_channels, out_channels, heads=8)

GMMConv - Gaussian Mixture Model Convolution

  • Uses Gaussian kernels in pseudo-coordinate space
  • Supports: Bipartite
  • Use for: Point clouds, spatial data
  • Example: GMMConv(in_channels, out_channels, dim=3, kernel_size=5)

SplineConv - Spline-based convolution

  • B-spline basis functions for spatial filtering
  • Supports: Bipartite
  • Use for: Irregular grids, continuous spaces
  • Example: SplineConv(in_channels, out_channels, dim=2, kernel_size=5)

NNConv - Neural Network Convolution

  • Edge features processed by neural networks
  • Supports: edge_attr, Bipartite
  • Use for: Rich edge features, molecular graphs
  • Example: NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')

CGConv - Crystal Graph Convolution

  • Designed for crystalline materials
  • Supports: Bipartite
  • Use for: Materials science, crystal structures
  • Example: CGConv(in_channels, dim=3, batch_norm=True)

EdgeConv - Edge Convolution (Dynamic Graph CNN)

  • Dynamically computes edges based on feature space
  • Supports: Static
  • Use for: Point clouds, dynamic graphs
  • Example: EdgeConv(nn=edge_nn, aggr='max')
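
The edge_nn modules referenced in the NNConv and EdgeConv examples above are user-defined networks, and their required shapes differ. A short sketch, assuming illustrative channel sizes:

import torch.nn as nn
from torch_geometric.nn import NNConv, EdgeConv

in_channels, out_channels, edge_dim = 16, 32, 4   # example sizes

# NNConv: maps edge features to a flattened weight matrix of size in_channels * out_channels
nn_conv = NNConv(
    in_channels, out_channels,
    nn=nn.Sequential(nn.Linear(edge_dim, 128), nn.ReLU(),
                     nn.Linear(128, in_channels * out_channels)),
    aggr='mean')

# EdgeConv: maps concatenated node-pair features (2 * in_channels) to out_channels
edge_conv = EdgeConv(
    nn=nn.Sequential(nn.Linear(2 * in_channels, out_channels), nn.ReLU()),
    aggr='max')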

PointNetConv - PointNet++ convolution

  • Local and global feature learning for point clouds
  • Use for: 3D point cloud processing
  • Example: PointNetConv(local_nn, global_nn)

ResGatedGraphConv - Residual Gated Graph Convolution

  • Gating mechanism with residual connections
  • Supports: edge_attr, Bipartite, Lazy
  • Use for: Deep GNNs, complex features
  • Example: ResGatedGraphConv(in_channels, out_channels)

GENConv - Generalized Graph Convolution

  • Generalizes multiple GNN variants
  • Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy
  • Use for: Flexible architecture exploration
  • Example: GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)

FiLMConv - Feature-wise Linear Modulation (GNN-FiLM)

  • Applies feature-wise affine modulation to messages, conditioned on the target node and relation type
  • Supports: Bipartite, Lazy
  • Use for: Multi-relational graphs, relation-dependent message passing
  • Example: FiLMConv(in_channels, out_channels, num_relations=5)

PANConv - Path Integral Based Convolution (PAN)

  • Aggregates information along multi-hop paths via a maximal entropy transition matrix
  • Supports: SparseTensor, Lazy
  • Use for: Complex connectivity patterns
  • Example: PANConv(in_channels, out_channels, filter_size=3)

ClusterGCNConv - Cluster-GCN convolution

  • Efficient training via graph clustering
  • Supports: edge_attr, Lazy
  • Use for: Very large graphs
  • Example: ClusterGCNConv(in_channels, out_channels)

MFConv - Molecular Fingerprint Convolution

  • Neural fingerprint operator with degree-specific weight matrices (Duvenaud et al.)
  • Supports: SparseTensor, Lazy
  • Use for: Molecular fingerprints, molecular property prediction
  • Example: MFConv(in_channels, out_channels)

RGCNConv - Relational Graph Convolution

  • Handles multiple edge types
  • Supports: SparseTensor, edge_weight, Lazy
  • Use for: Knowledge graphs, heterogeneous graphs
  • Example: RGCNConv(in_channels, out_channels, num_relations=10)

FAConv - Frequency Adaptive Convolution

  • Adaptive filtering in spectral domain
  • Supports: SparseTensor, Lazy
  • Use for: Spectral graph learning
  • Example: FAConv(in_channels, eps=0.1, dropout=0.5)

Molecular and 3D Convolutions

SchNet - Continuous-filter convolutional layer

  • Designed for molecular dynamics
  • Use for: Molecular property prediction, 3D molecules
  • Example: SchNet(hidden_channels=128, num_filters=64, num_interactions=6)

DimeNet - Directional Message Passing

  • Uses directional information and angles
  • Use for: 3D molecular structures, chemical properties
  • Example: DimeNet(hidden_channels=128, out_channels=1, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6)

PointTransformerConv - Point cloud transformer

  • Transformer for 3D point clouds
  • Use for: 3D vision, point cloud segmentation
  • Example: PointTransformerConv(in_channels, out_channels)

Hypergraph Convolutions

HypergraphConv - Hypergraph convolution

  • Operates on hyperedges (edges connecting multiple nodes)
  • Supports: Lazy
  • Use for: Multi-way relationships, chemical reactions
  • Example: HypergraphConv(in_channels, out_channels)

HGTConv - Heterogeneous Graph Transformer

  • Transformer for heterogeneous graphs with multiple types
  • Supports: Lazy
  • Use for: Heterogeneous networks, knowledge graphs
  • Example: HGTConv(in_channels, out_channels, metadata, heads=8)

Aggregation Operators

Aggregation - Base aggregation class

  • Base class (torch_geometric.nn.aggr.Aggregation) for permutation-invariant aggregation across nodes

SumAggregation - Sum aggregation

  • Example: SumAggregation()

MeanAggregation - Mean aggregation

  • Example: MeanAggregation()

MaxAggregation - Max aggregation

  • Example: MaxAggregation()

SoftmaxAggregation - Softmax-weighted aggregation

  • Learnable attention weights
  • Example: SoftmaxAggregation(learn=True)

PowerMeanAggregation - Power mean aggregation

  • Learnable power parameter
  • Example: PowerMeanAggregation(learn=True)

LSTMAggregation - LSTM-based aggregation

  • Sequential processing of neighbors
  • Example: LSTMAggregation(in_channels, out_channels)

SetTransformerAggregation - Set Transformer aggregation

  • Transformer for permutation-invariant aggregation
  • Example: SetTransformerAggregation(in_channels, out_channels)

MultiAggregation - Multiple aggregations

  • Combines multiple aggregation methods
  • Example: MultiAggregation(['mean', 'max', 'std'])
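
In recent PyG versions these aggregation modules can also be passed directly to many convolution layers via the aggr argument. A minimal sketch (channel sizes are illustrative):

from torch_geometric.nn import SAGEConv, SoftmaxAggregation, MultiAggregation

# Learnable softmax aggregation inside a SAGEConv layer
conv1 = SAGEConv(64, 128, aggr=SoftmaxAggregation(learn=True))

# Several aggregations combined; their outputs are concatenated
conv2 = SAGEConv(64, 128, aggr=MultiAggregation(['mean', 'max', 'std']))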

Pooling Layers

Global Pooling

global_mean_pool - Global mean pooling

  • Averages node features per graph
  • Example: global_mean_pool(x, batch)

global_max_pool - Global max pooling

  • Max over node features per graph
  • Example: global_max_pool(x, batch)

global_add_pool - Global sum pooling

  • Sums node features per graph
  • Example: global_add_pool(x, batch)

global_sort_pool - Global sort pooling

  • Sorts and concatenates top-k nodes
  • Example: global_sort_pool(x, batch, k=30)

GlobalAttention - Global attention pooling

  • Learnable attention weights for aggregation
  • Example: GlobalAttention(gate_nn)
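
The gate_nn above is a user-defined network that maps each node's features to a scalar attention score. A minimal readout sketch (hidden size is illustrative; newer PyG versions expose the same operator as AttentionalAggregation):

import torch.nn as nn
from torch_geometric.nn import GlobalAttention

hidden_channels = 64
gate_nn = nn.Sequential(nn.Linear(hidden_channels, 1))   # per-node attention score
readout = GlobalAttention(gate_nn)

# x: [num_nodes, hidden_channels], batch: [num_nodes] graph assignment vector
# graph_emb = readout(x, batch)   # -> [num_graphs, hidden_channels]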

Set2Set - Set2Set pooling

  • LSTM-based attention mechanism
  • Example: Set2Set(in_channels, processing_steps=3)

Hierarchical Pooling

TopKPooling - Top-k pooling

  • Keeps top-k nodes based on projection scores
  • Example: TopKPooling(in_channels, ratio=0.5)

SAGPooling - Self-Attention Graph Pooling

  • Uses self-attention for node selection
  • Example: SAGPooling(in_channels, ratio=0.5)

ASAPooling - Adaptive Structure Aware Pooling

  • Structure-aware node selection
  • Example: ASAPooling(in_channels, ratio=0.5)

PANPooling - Path integral based pooling (PAN)

  • Pooling operator from the PAN paper; ranks nodes using the path-integral transition matrix together with node features
  • Example: PANPooling(in_channels, ratio=0.5)

EdgePooling - Edge contraction pooling

  • Pools by contracting edges
  • Example: EdgePooling(in_channels)

MemPooling - Memory-based pooling

  • Learnable cluster assignments
  • Example: MemPooling(in_channels, out_channels, heads=4, num_clusters=10)

avg_pool / max_pool - Average/Max pool with clustering

  • Pools nodes within clusters
  • Example: avg_pool(cluster, data)
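
Unlike global pooling, hierarchical pooling layers return a coarsened graph, so x, edge_index, and batch must all be carried forward. A minimal sketch with TopKPooling (channel sizes are illustrative):

import torch
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool

conv = GCNConv(32, 64)
pool = TopKPooling(64, ratio=0.5)

def forward(x, edge_index, batch):
    x = conv(x, edge_index).relu()
    # TopKPooling returns the coarsened graph plus the kept node indices and their scores
    x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index, batch=batch)
    return global_mean_pool(x, batch)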

Normalization Layers

BatchNorm - Batch normalization

  • Normalizes features across batch
  • Example: BatchNorm(in_channels)

LayerNorm - Layer normalization

  • Normalizes features per sample
  • Example: LayerNorm(in_channels)

InstanceNorm - Instance normalization

  • Normalizes per sample and graph
  • Example: InstanceNorm(in_channels)

GraphNorm - Graph normalization

  • Graph-specific normalization
  • Example: GraphNorm(in_channels)

PairNorm - Pair normalization

  • Prevents oversmoothing in deep GNNs
  • Example: PairNorm(scale_individually=False)

MessageNorm - Message normalization

  • Normalizes messages during passing
  • Example: MessageNorm(learn_scale=True)

DiffGroupNorm - Differentiable Group Normalization

  • Learnable grouping for normalization
  • Example: DiffGroupNorm(in_channels, groups=10)

Model Architectures

Pre-Built Models

GCN - Complete Graph Convolutional Network

  • Multi-layer GCN with dropout
  • Example: GCN(in_channels, hidden_channels, num_layers, out_channels)

GraphSAGE - Complete GraphSAGE model

  • Multi-layer SAGE with dropout
  • Example: GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)

GIN - Complete Graph Isomorphism Network

  • Multi-layer GIN for graph classification
  • Example: GIN(in_channels, hidden_channels, num_layers, out_channels)

GAT - Complete Graph Attention Network

  • Multi-layer GAT with attention
  • Example: GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)

PNA - Principal Neighbourhood Aggregation

  • Combines multiple aggregators and scalers
  • Example: PNA(in_channels, hidden_channels, num_layers, out_channels)
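
PNA additionally expects the lists of aggregators and scalers plus an in-degree histogram of the training data. A hedged sketch, assuming a train_dataset of Data objects and illustrative channel sizes:

import torch
from torch_geometric.nn import PNA
from torch_geometric.utils import degree

# Compute the in-degree histogram over the training graphs
max_degree = 0
for data in train_dataset:
    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
    max_degree = max(max_degree, int(d.max()))

deg = torch.zeros(max_degree + 1, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())

model = PNA(in_channels=16, hidden_channels=64, num_layers=4, out_channels=1,
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg)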

EdgeCNN - Edge Convolution CNN

  • Dynamic graph CNN for point clouds
  • Example: EdgeCNN(in_channels, hidden_channels, num_layers, out_channels)

Auto-Encoders

GAE - Graph Auto-Encoder

  • Encodes graphs into latent space
  • Example: GAE(encoder)

VGAE - Variational Graph Auto-Encoder

  • Probabilistic graph encoding
  • Example: VGAE(encoder)

ARGA - Adversarially Regularized Graph Auto-Encoder

  • GAE with adversarial regularization
  • Example: ARGA(encoder, discriminator)

ARGVA - Adversarially Regularized Variational Graph Auto-Encoder

  • VGAE with adversarial regularization
  • Example: ARGVA(encoder, discriminator)
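
These auto-encoders wrap a user-defined encoder and provide encode(), decode(), and recon_loss() helpers. A minimal GAE sketch (the encoder architecture here is an assumption, not a prescribed design):

import torch
from torch_geometric.nn import GCNConv, GAE

class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

model = GAE(GCNEncoder(in_channels=32, out_channels=16))
# z = model.encode(x, edge_index)          # node embeddings
# loss = model.recon_loss(z, edge_index)   # reconstruct observed edges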

Knowledge Graph Embeddings

TransE - Translating embeddings

  • Learns entity and relation embeddings
  • Example: TransE(num_nodes, num_relations, hidden_channels)

RotatE - Rotational embeddings

  • Embeddings in complex space
  • Example: RotatE(num_nodes, num_relations, hidden_channels)

ComplEx - Complex embeddings

  • Complex-valued embeddings
  • Example: ComplEx(num_nodes, num_relations, hidden_channels)

DistMult - Bilinear diagonal model

  • Simplified bilinear model
  • Example: DistMult(num_nodes, num_relations, hidden_channels)
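
These embedding models share a common interface: they embed (head, relation, tail) index triples and expose a ranking loss. A minimal sketch with TransE (the triples are illustrative):

import torch
from torch_geometric.nn import TransE

model = TransE(num_nodes=1000, num_relations=10, hidden_channels=50)

# Example triples: head entity index, relation type, tail entity index
head_index = torch.tensor([0, 5, 42])
rel_type = torch.tensor([1, 0, 3])
tail_index = torch.tensor([7, 9, 100])

loss = model.loss(head_index, rel_type, tail_index)   # margin-based ranking loss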

Utility Layers

Sequential - Sequential container

  • Chains multiple layers
  • Example: Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])

JumpingKnowledge - Jumping knowledge connections

  • Combines representations from all layers
  • Modes: 'cat', 'max', 'lstm'
  • Example: JumpingKnowledge(mode='cat')

DeepGCNLayer - Deep GCN layer wrapper

  • Enables very deep GNNs with skip connections
  • Example: DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)
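
The conv, norm, and act arguments above are built by the user. A sketch following the common DeepGCN recipe (channel size illustrative):

import torch
from torch_geometric.nn import GENConv, DeepGCNLayer, LayerNorm

hidden_channels = 64
conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
               t=1.0, learn_t=True, num_layers=2, norm='layer')
norm = LayerNorm(hidden_channels)
act = torch.nn.ReLU(inplace=True)

layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)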

MLP - Multi-layer perceptron

  • Standard feedforward network
  • Example: MLP([in_channels, 64, 64, out_channels], dropout=0.5)

Linear - Lazy linear layer

  • Linear transformation with lazy initialization
  • Example: Linear(in_channels, out_channels, bias=True)

Dense Layers

For dense (non-sparse) graph representations:

DenseGCNConv - Dense GCN layer
DenseSAGEConv - Dense SAGE layer
DenseGINConv - Dense GIN layer
DenseGraphConv - Dense graph convolution

These are useful when working with small, fully-connected, or densely represented graphs.
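
Dense layers take a dense adjacency matrix instead of edge_index. A minimal sketch, converting with to_dense_adj (shapes are illustrative):

import torch
from torch_geometric.nn import DenseGCNConv
from torch_geometric.utils import to_dense_adj

x = torch.randn(1, 10, 16)                           # [batch, num_nodes, in_channels]
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = to_dense_adj(edge_index, max_num_nodes=10)     # [batch, num_nodes, num_nodes]

conv = DenseGCNConv(16, 32)
out = conv(x, adj)                                   # [batch, num_nodes, 32]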

Usage Tips

  1. Start simple: Begin with GCNConv or GATConv for most tasks
  2. Consider data type: Use molecular layers (SchNet, DimeNet) for 3D structures
  3. Check capabilities: Match layer capabilities to your data (edge features, bipartite, etc.)
  4. Deep networks: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs
  5. Large graphs: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling
  6. Heterogeneous: Use RGCNConv, HGTConv, or to_hetero() conversion (see the sketch after this list)
  7. Lazy initialization: Use lazy layers (in_channels=-1) when input dimensions vary or are unknown
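
A minimal to_hetero() sketch for tip 6, using lazy input channels from tip 7 (assumes a HeteroData object named data):

import torch
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Lazy in_channels (-1) lets each node type keep its own feature size
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

model = GNN(hidden_channels=64, out_channels=7)
# model = to_hetero(model, data.metadata(), aggr='sum')
# out = model(data.x_dict, data.edge_index_dict)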

Common Patterns

Basic GNN

import torch
from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)

Deep GNN with Normalization

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, LayerNorm, JumpingKnowledge, global_mean_pool

# Deep GCN with per-layer LayerNorm and a JumpingKnowledge ('cat') readout
class DeepGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(LayerNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(LayerNorm(hidden_channels))

        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.jk = JumpingKnowledge(mode='cat')

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            xs.append(x)

        x = self.convs[-1](x, edge_index)
        xs.append(x)

        x = self.jk(xs)
        return global_mean_pool(x, batch)