Models and Architectures
Overview
TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.
Graph Neural Networks
GCN (Graph Convolutional Network)
Type: Spatial message passing
Paper: Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)
Characteristics:
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks
Best For:
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure
Parameters:
- input_dim: Node feature dimension
- hidden_dims: List of hidden layer dimensions
- edge_input_dim: Edge feature dimension (optional)
- batch_norm: Apply batch normalization
- activation: Activation function (relu, elu, etc.)
- dropout: Dropout rate
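A minimal instantiation and smoke-test sketch, assuming TorchDrug's `models.GCN`; the caffeine SMILES and layer sizes are arbitrary placeholders, and `dropout` is omitted in case your installed release does not expose it as a constructor argument:

```python
from torchdrug import data, models

# Build a one-molecule batch to exercise the forward pass (caffeine)
mol = data.Molecule.from_smiles("Cn1cnc2c1c(=O)n(C)c(=O)n2C")
batch = data.Molecule.pack([mol])

model = models.GCN(input_dim=mol.node_feature.shape[-1],
                   hidden_dims=[256, 256, 256],
                   batch_norm=True, activation="relu")
output = model(batch, batch.node_feature.float())
print(output["graph_feature"].shape)  # (1, 256)
print(output["node_feature"].shape)   # (num_atoms, 256)
```

The same `model(graph, graph.node_feature.float())` calling convention applies to the other GNN encoders below.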
Use Cases:
- Molecular property prediction
- Citation network classification
- Social network analysis
GAT (Graph Attention Network)
Type: Attention-based message passing
Paper: Graph Attention Networks (Veličković et al., 2018)
Characteristics:
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally
Best For:
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions
Parameters:
- input_dim, hidden_dims: Standard dimensions
- num_heads: Number of attention heads
- negative_slope: LeakyReLU slope
- concat: Concatenate or average multi-head outputs
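A minimal sketch with placeholder dimensions; note that in the TorchDrug releases I am aware of, the multi-head count is passed as `num_head` (singular), so verify the keyword against your installed version:

```python
from torchdrug import models

# Placeholder input_dim; in practice use dataset.node_feature_dim
model = models.GAT(input_dim=69, hidden_dims=[256, 256],
                   num_head=4, negative_slope=0.2, batch_norm=True)
```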
Use Cases:
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance
GIN (Graph Isomorphism Network)
Type: Maximally powerful message passing
Paper: How Powerful are Graph Neural Networks? (Xu et al., 2019)
Characteristics:
- As expressive as the 1-WL test, the theoretical upper bound for message-passing GNNs
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often best performance on molecular tasks
Best For:
- Molecular property prediction (state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification
Parameters:
- input_dim, hidden_dims: Standard dimensions
- edge_input_dim: Include edge features
- batch_norm: Typically set to true
- readout: Graph pooling ("sum", "mean", "max")
- eps: Learnable or fixed epsilon
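A typical molecular-property configuration, with placeholder feature dimensions (take them from `dataset.node_feature_dim` and `dataset.edge_feature_dim` in practice):

```python
from torchdrug import models

model = models.GIN(input_dim=69, edge_input_dim=18,
                   hidden_dims=[256, 256, 256, 256],
                   batch_norm=True, readout="sum", learn_eps=True)
```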
Use Cases:
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
RGCN (Relational Graph Convolutional Network)
Type: Multi-relational message passing
Paper: Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)
Characteristics:
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs
Best For:
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data
Parameters:
- num_relation: Number of relation types
- hidden_dims: Layer dimensions
- num_bases: Basis decomposition (reduces parameter count)
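A minimal sketch sticking to the core arguments; `num_relation` should come from the data (e.g. the number of bond types for molecules), and `num_bases` is left out in case your installed release does not expose it:

```python
from torchdrug import models

# Placeholder dims; num_relation=4 would cover single/double/triple/aromatic bonds
model = models.RGCN(input_dim=69, hidden_dims=[256, 256, 256],
                    num_relation=4, batch_norm=True)
```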
Use Cases:
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks
MPNN (Message Passing Neural Network)
Type: General message passing framework
Paper: Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)
Characteristics:
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for graph representation
Best For:
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations
Parameters:
- input_dim, hidden_dim: Feature dimensions
- edge_input_dim: Edge feature dimension
- num_layer: Message passing iterations
- num_mlp_layer: MLP layers in the message function
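A minimal sketch with placeholder dimensions; unlike the list-based encoders above, MPNN keeps a single hidden size across its message passing iterations:

```python
from torchdrug import models

model = models.MPNN(input_dim=69, hidden_dim=256, edge_input_dim=18,
                    num_layer=3, num_mlp_layer=2)
```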
Use Cases:
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks
SchNet (Continuous-Filter Convolutional Network)
Type: 3D geometry-aware convolution
Paper: SchNet: A Continuous-Filter Convolutional Neural Network for Modeling Quantum Interactions (Schütt et al., 2017)
Characteristics:
- Operates on 3D atomic coordinates
- Continuous filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry
Best For:
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction
Parameters:
- input_dim: Atom features
- hidden_dims: Layer dimensions
- num_gaussian: RBF basis functions for distances
- cutoff: Interaction cutoff distance
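A minimal sketch with placeholder dimensions; SchNet reads 3D coordinates from the graph, so the input molecules must carry conformers:

```python
from torchdrug import models

# cutoff is in the same unit as the atomic coordinates (typically angstroms)
model = models.SchNet(input_dim=69, hidden_dims=[256, 256, 256],
                      num_gaussian=100, cutoff=5.0)
```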
Use Cases:
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction
ChebNet (Chebyshev Spectral CNN)
Type: Spectral convolution
Paper: Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)
Characteristics:
- Spectral graph convolution
- Chebyshev polynomial approximation
- Captures global graph structure
- Computationally efficient
Best For:
- Tasks requiring global information
- When graph Laplacian is informative
- Theoretical analysis
Parameters:
- input_dim, hidden_dims: Dimensions
- num_cheb: Order of Chebyshev polynomial
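A minimal sketch with placeholder dimensions; the polynomial order is listed above as num_cheb, but in the TorchDrug release I am aware of the keyword is simply `k`, so check your installed version:

```python
from torchdrug import models

model = models.ChebNet(input_dim=69, hidden_dims=[256, 256], k=3)
```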
Use Cases:
- Citation network classification
- Brain network analysis
- Signal processing on graphs
NFP (Neural Fingerprint)
Type: Molecular fingerprint learning
Paper: Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)
Characteristics:
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features
Best For:
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important
Parameters:
- input_dim, output_dim: Feature dimensions
- hidden_dims: Layer dimensions
- num_layer: Circular convolution depth
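A minimal sketch with placeholder dimensions; `output_dim` sets the fingerprint length (2048 mirrors common ECFP settings):

```python
from torchdrug import models

model = models.NeuralFingerprint(input_dim=69, output_dim=2048,
                                 hidden_dims=[256, 256])
```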
Use Cases:
- Virtual screening
- Molecular similarity search
- QSAR modeling
Protein-Specific Models
GearNet (Geometry-Aware Relational Graph Network)
Type: Protein structure encoder
Paper: Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)
Characteristics:
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks
Best For:
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures
Parameters:
- input_dim: Residue features
- hidden_dims: Layer dimensions
- num_relation: Edge types (sequence, radius, KNN)
- edge_input_dim: Geometric features (distances, angles)
- batch_norm: Typically true
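A sketch mirroring the GearNet-Edge configuration from the TorchDrug tutorial (21 residue types, 7 relation types); the input protein graphs must be built with a matching graph-construction step that adds the sequential, radius, and KNN edges:

```python
from torchdrug import models

model = models.GearNet(input_dim=21,
                       hidden_dims=[512, 512, 512, 512, 512, 512],
                       num_relation=7, edge_input_dim=59, num_angle_bin=8,
                       batch_norm=True, short_cut=True, readout="sum")
```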
Use Cases:
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification
ESM (Evolutionary Scale Modeling)
Type: Protein language model (transformer)
Paper: Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences (Rives et al., 2021)
Characteristics:
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks
Best For:
- Any sequence-based protein task
- When no structure available
- Transfer learning with limited data
Variants:
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)
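A minimal loading sketch; the weight directory is a placeholder, and TorchDrug downloads the pre-trained checkpoint into it on first use:

```python
from torchdrug import models

model = models.ESM(path="~/esm-model-weights/", model="ESM-1b")
```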
Use Cases:
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)
ProteinBERT
Type: Masked language model for proteins
Characteristics:
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks
Use Cases:
- Function annotation
- Subcellular localization
- Stability prediction
ProteinCNN / ProteinResNet
Type: Convolutional networks for sequences
Characteristics:
- 1D convolutions on sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection
Use Cases:
- Binding site prediction
- Secondary structure prediction
- Domain identification
ProteinLSTM
Type: Recurrent network for sequences
Characteristics:
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks
Use Cases:
- Order prediction
- Sequential annotation
- Time-series protein data
Knowledge Graph Models
TransE (Translation Embedding)
Type: Translation-based embedding
Paper: Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)
Characteristics:
- h + r ≈ t (head + relation ≈ tail)
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient
Best For:
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings
Parameters:
- num_entity, num_relation: Graph size
- embedding_dim: Embedding dimension (typically 50-500)
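A minimal sketch pairing TransE with a bundled benchmark; `tasks.KnowledgeGraphCompletion` wires up negative sampling and ranking metrics:

```python
from torchdrug import datasets, models, tasks

dataset = datasets.FB15k237("~/kg-datasets/")
model = models.TransE(num_entity=dataset.num_entity,
                      num_relation=dataset.num_relation,
                      embedding_dim=256)
task = tasks.KnowledgeGraphCompletion(model, num_negative=128)
```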
RotatE (Rotation Embedding)
Type: Rotation in complex space
Paper: RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)
Characteristics:
- Relations as rotations in complex space
- Handles symmetric, antisymmetric, inverse, composition
- State-of-the-art on many benchmarks
Best For:
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical
Parameters:
- num_entity, num_relation: Graph size
- embedding_dim: Must be even (complex embeddings)
- max_score: Score clipping value
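A minimal sketch; the entity and relation counts below match FB15k-237, and the even `embedding_dim` is required because half the dimensions hold real parts and half hold imaginary parts:

```python
from torchdrug import models

model = models.RotatE(num_entity=14541, num_relation=237,
                      embedding_dim=1024, max_score=9)
```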
DistMult
Type: Bilinear model
Characteristics:
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations
Best For:
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs
ComplEx
Type: Complex-valued embeddings
Characteristics:
- Handles asymmetric and symmetric relations
- Better than DistMult for most graphs
- Good balance of expressiveness and efficiency
Best For:
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex
SimplE
Type: Enhanced embedding model
Characteristics:
- Two embeddings per entity (canonical + inverse)
- Fully expressive
- Slightly more parameters than ComplEx
Best For:
- When full expressiveness needed
- Inverse relations are important
Generative Models
GraphAutoregressiveFlow
Type: Normalizing flow for molecules
Characteristics:
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial objective)
- Conditional generation support
Best For:
- Molecular generation
- Density estimation
- Interpolation between molecules
Parameters:
- input_dim: Atom features
- hidden_dims: Coupling layer dimensions
- num_flow: Number of flow transformations
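A condensed sketch following the TorchDrug molecule-generation tutorial: an RGCN backbone over bond types, Gaussian priors over the atom and bond vocabularies, and separate node and edge flows (the vocabulary sizes are placeholders for a kekulized dataset such as ZINC250k):

```python
import torch
from torchdrug import layers, models

num_atom_type = 9        # placeholder atom vocabulary size
num_bond_type = 3 + 1    # bond types plus one "no edge" class

backbone = models.RGCN(input_dim=num_atom_type, num_relation=3,
                       hidden_dims=[256, 256, 256], batch_norm=True)
node_prior = layers.distribution.IndependentGaussian(
    torch.zeros(num_atom_type), torch.ones(num_atom_type))
edge_prior = layers.distribution.IndependentGaussian(
    torch.zeros(num_bond_type), torch.ones(num_bond_type))

node_flow = models.GraphAF(backbone, node_prior, num_layer=12)
edge_flow = models.GraphAF(backbone, edge_prior, use_edge=True, num_layer=12)
```

The two flows are then combined in `tasks.AutoregressiveGeneration` for training and sampling.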
Use Cases:
- De novo drug design
- Chemical space exploration
- Property-targeted generation
Pre-training Models
InfoGraph
Type: Contrastive learning
Characteristics:
- Maximizes mutual information
- Graph-level and node-level contrast
- Unsupervised pre-training
- Good for small datasets
Use Cases:
- Pre-train molecular encoders
- Few-shot learning
- Transfer learning
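A minimal pre-training sketch with placeholder dimensions: InfoGraph wraps an existing encoder, and `tasks.Unsupervised` trains it without labels so the wrapped GIN can be reused downstream:

```python
from torchdrug import models, tasks

gin = models.GIN(input_dim=69, hidden_dims=[256, 256, 256, 256],
                 batch_norm=True, readout="mean")
model = models.InfoGraph(gin, separate_model=False)
task = tasks.Unsupervised(model)
```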
MultiviewContrast
Type: Multi-view contrastive learning for proteins
Characteristics:
- Contrasts different views of proteins
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models
Use Cases:
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios
Model Selection Guide
By Task Type
Molecular Property Prediction:
- GIN (first choice)
- GAT (interpretability)
- SchNet (3D available)
Protein Tasks:
- ESM (sequence only)
- GearNet (structure available)
- ProteinBERT (sequence, lighter than ESM)
Knowledge Graphs:
- RotatE (best performance)
- ComplEx (good balance)
- TransE (large graphs, efficiency)
Molecular Generation:
- GraphAutoregressiveFlow (exact likelihood)
- GCPN with GIN backbone (property optimization)
Retrosynthesis:
- GIN (synthon completion)
- RGCN (center identification with bond types)
By Dataset Size
Small (< 1k samples):
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization
Medium (1k-100k samples):
- GIN for molecules
- GAT for interpretability
- Standard training
Large (> 100k samples):
- Any model works
- Deeper architectures
- Can train from scratch
By Computational Budget
Low:
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM
Medium:
- GIN
- GAT
- ComplEx
High:
- ESM (large)
- SchNet (3D)
- RotatE with high dim
Implementation Tips
- Start Simple: Begin with GCN or GIN baseline
- Use Pre-trained: ESM for proteins, InfoGraph for molecules
- Tune Depth: 3-5 layers typically sufficient
- Batch Normalization: Usually helps (except KG embeddings)
- Residual Connections: Important for deep networks
- Readout Function: "mean" usually works well
- Edge Features: Include when available (bonds, distances)
- Regularization: Dropout, weight decay, early stopping
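As a concrete starting point, a configuration that reflects these tips: moderate depth, batch normalization, residual connections, mean readout, and edge features (the dimensions are placeholders; take them from `dataset.node_feature_dim` and `dataset.edge_feature_dim`):

```python
from torchdrug import models

model = models.GIN(input_dim=69, edge_input_dim=18,
                   hidden_dims=[256, 256, 256, 256],  # 4 layers
                   batch_norm=True,    # usually helps for GNN encoders
                   short_cut=True,     # residual connections
                   readout="mean")     # robust default readout
```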