# Models and Architectures

## Overview

TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs the available models with their characteristics, use cases, and implementation details.

## Graph Neural Networks
### GCN (Graph Convolutional Network)

**Type:** Spatial message passing

**Paper:** Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)

**Characteristics:**
- Simple and efficient aggregation
- Convolution with the normalized adjacency matrix
- Works well for homophilic graphs
- Good baseline for many tasks

**Best For:**
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure

**Parameters:**
- `input_dim`: Node feature dimension
- `hidden_dims`: List of hidden layer dimensions
- `edge_input_dim`: Edge feature dimension (optional)
- `batch_norm`: Apply batch normalization
- `activation`: Activation function (relu, elu, etc.)
- `dropout`: Dropout rate

**Use Cases:**
- Molecular property prediction
- Citation network classification
- Social network analysis
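
For reference, a minimal TorchDrug instantiation might look like this (the feature dimension is illustrative; in practice take it from `dataset.node_feature_dim`):

```python
from torchdrug import models

# Minimal GCN encoder sketch; 69 is an illustrative atom-feature dimension.
# With a real TorchDrug dataset, use dataset.node_feature_dim instead.
model = models.GCN(
    input_dim=69,
    hidden_dims=[256, 256, 256],  # three GCN layers
    batch_norm=True,
    readout="mean",               # graph-level pooling for property prediction
)
```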
### GAT (Graph Attention Network)

**Type:** Attention-based message passing

**Paper:** Graph Attention Networks (Veličković et al., 2018)

**Characteristics:**
- Learns attention weights over neighbors
- Assigns different importance to different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally

**Best For:**
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `num_heads`: Number of attention heads
- `negative_slope`: LeakyReLU slope
- `concat`: Concatenate or average multi-head outputs

**Use Cases:**
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance
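
A minimal sketch (recent TorchDrug releases spell the head-count argument `num_head`; verify against your installed version):

```python
from torchdrug import models

# GAT encoder sketch; dimensions are illustrative.
model = models.GAT(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    num_head=4,                   # multi-head attention
    negative_slope=0.2,           # LeakyReLU slope for attention logits
    readout="mean",
)
```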
### GIN (Graph Isomorphism Network)

**Type:** Maximally powerful message passing

**Paper:** How Powerful are Graph Neural Networks? (Xu et al., 2019)

**Characteristics:**
- As expressive as any message-passing GNN can be (matches the 1-WL test)
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often the best performer on molecular tasks

**Best For:**
- Molecular property prediction (state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `edge_input_dim`: Include edge features
- `batch_norm`: Typically set to true
- `readout`: Graph pooling ("sum", "mean", "max")
- `eps`: Learnable or fixed epsilon

**Use Cases:**
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
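
An end-to-end sketch in the style of TorchDrug's tutorials; the dataset choice (BBBP) and all hyperparameters are illustrative placeholders:

```python
import torch
from torchdrug import core, datasets, models, tasks

# Illustrative property-prediction pipeline, not a tuned recipe.
dataset = datasets.BBBP("~/molecule-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256, 256, 256],
    edge_input_dim=dataset.edge_feature_dim,  # include bond features
    batch_norm=True,
    readout="mean",
)
task = tasks.PropertyPrediction(
    model, task=dataset.tasks, criterion="bce", metric=("auprc", "auroc"),
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=128)
solver.train(num_epoch=10)
solver.evaluate("valid")
```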
### RGCN (Relational Graph Convolutional Network)

**Type:** Multi-relational message passing

**Paper:** Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)

**Characteristics:**
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs

**Best For:**
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data

**Parameters:**
- `num_relation`: Number of relation types
- `hidden_dims`: Layer dimensions
- `num_bases`: Basis decomposition (reduces parameters)

**Use Cases:**
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks
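
A minimal sketch for molecular graphs, treating each bond type as a relation (dimensions are illustrative; take them from the dataset in practice):

```python
from torchdrug import models

# RGCN over molecular graphs: each bond type is a relation.
# Use dataset.node_feature_dim and dataset.num_bond_type with a real dataset.
model = models.RGCN(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    num_relation=4,     # e.g., single / double / triple / aromatic bonds
    batch_norm=True,
)
```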
### MPNN (Message Passing Neural Network)

**Type:** General message passing framework

**Paper:** Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)

**Characteristics:**
- Flexible message and update functions
- Uses edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for the graph representation

**Best For:**
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations

**Parameters:**
- `input_dim`, `hidden_dim`: Feature dimensions
- `edge_input_dim`: Edge feature dimension
- `num_layer`: Message passing iterations
- `num_mlp_layer`: MLP layers in the message function

**Use Cases:**
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks
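
A minimal sketch; note that MPNN takes a single `hidden_dim` shared across message-passing iterations rather than a list (values are illustrative):

```python
from torchdrug import models

# MPNN sketch; edge features are required by the message function.
model = models.MPNN(
    input_dim=21,         # illustrative atom feature dimension
    hidden_dim=256,
    edge_input_dim=12,    # illustrative edge feature dimension
    num_layer=3,          # message passing iterations
    num_mlp_layer=2,
)
```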
### SchNet (Continuous-Filter Convolutional Network)

**Type:** 3D geometry-aware convolution

**Paper:** SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., 2017)

**Characteristics:**
- Operates on 3D atomic coordinates
- Continuous-filter convolutions
- Invariant to rotation and translation
- Excellent for quantum chemistry

**Best For:**
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Layer dimensions
- `num_gaussian`: Number of RBF basis functions for distances
- `cutoff`: Interaction cutoff distance

**Use Cases:**
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction
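
A minimal sketch, assuming input graphs that carry 3D atom coordinates (e.g., QM9 conformers); values are illustrative:

```python
from torchdrug import models

# SchNet sketch; interatomic distances are expanded in a Gaussian RBF basis.
model = models.SchNet(
    input_dim=21,
    hidden_dims=[128, 128, 128],
    num_gaussian=100,   # RBF expansion of distances
    cutoff=5.0,         # interaction cutoff in angstroms
)
```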
### ChebNet (Chebyshev Spectral CNN)

**Type:** Spectral convolution

**Paper:** Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)

**Characteristics:**
- Spectral graph convolution
- Chebyshev polynomial approximation of spectral filters
- A polynomial of order K aggregates K-hop neighborhoods
- Computationally efficient

**Best For:**
- Tasks requiring multi-hop information
- When the graph Laplacian is informative
- Theoretical analysis

**Parameters:**
- `input_dim`, `hidden_dims`: Dimensions
- `num_cheb`: Order of the Chebyshev polynomial

**Use Cases:**
- Citation network classification
- Brain network analysis
- Signal processing on graphs
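
A minimal sketch using the parameters documented above (argument names may differ slightly across TorchDrug versions; values are illustrative):

```python
from torchdrug import models

# ChebNet sketch; a polynomial of order num_cheb mixes information
# from up to num_cheb-hop neighborhoods in each layer.
model = models.ChebNet(
    input_dim=69,
    hidden_dims=[256, 256],
    num_cheb=3,
)
```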
### NFP (Neural Fingerprint)

**Type:** Molecular fingerprint learning

**Paper:** Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)

**Characteristics:**
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions analogous to ECFP
- Interpretable learned features

**Best For:**
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important

**Parameters:**
- `input_dim`, `output_dim`: Feature dimensions
- `hidden_dims`: Layer dimensions
- `num_layer`: Circular convolution depth

**Use Cases:**
- Virtual screening
- Molecular similarity search
- QSAR modeling
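
A minimal sketch, assuming the `NeuralFP` alias registered in `torchdrug.models`; values are illustrative:

```python
from torchdrug import models

# Neural fingerprint sketch; output_dim is the learned fingerprint
# length (analogous to the bit length of an ECFP fingerprint).
model = models.NeuralFP(
    input_dim=69,
    output_dim=1024,
    hidden_dims=[256, 256],
)
```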
## Protein-Specific Models

### GearNet (Geometry-Aware Relational Graph Network)

**Type:** Protein structure encoder

**Paper:** Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)

**Characteristics:**
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks

**Best For:**
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures

**Parameters:**
- `input_dim`: Residue features
- `hidden_dims`: Layer dimensions
- `num_relation`: Edge types (sequence, radius, KNN)
- `edge_input_dim`: Geometric features (distances, angles)
- `batch_norm`: Typically true

**Use Cases:**
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification
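
A sketch of the usual pairing in TorchDrug: a geometric graph-construction step that adds sequential, spatial, and KNN edges, followed by the GearNet encoder (cutoffs and dimensions are illustrative, loosely following the paper's configs):

```python
from torchdrug import layers, models
from torchdrug.layers import geometry

# Residue-level graph with three families of edges, as in the GearNet
# tutorials; radius/k values are illustrative.
graph_construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],
    edge_layers=[
        geometry.SequentialEdge(max_distance=2),
        geometry.SpatialEdge(radius=10.0, min_distance=5),
        geometry.KNNEdge(k=10, min_distance=5),
    ],
)

model = models.GearNet(
    input_dim=21,                  # one-hot residue types
    hidden_dims=[512, 512, 512],
    num_relation=7,                # 5 sequential offsets + spatial + KNN
    batch_norm=True,
    readout="sum",
)
```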
### ESM (Evolutionary Scale Modeling)

**Type:** Protein language model (transformer)

**Paper:** Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)

**Characteristics:**
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks

**Best For:**
- Any sequence-based protein task
- When no structure is available
- Transfer learning with limited data

**Variants:**
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)

**Use Cases:**
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)
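
A loading sketch; weights are downloaded to `path` on first use, and the `model` string selects the variant (newer TorchDrug releases also register ESM-2 sizes):

```python
from torchdrug import models

# Pre-trained ESM encoder; the path is an illustrative cache directory.
model = models.ESM(path="~/protein-model-weights/", model="ESM-1b")
```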
### ProteinBERT

**Type:** Masked language model for proteins

**Characteristics:**
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks

**Use Cases:**
- Function annotation
- Subcellular localization
- Stability prediction
### ProteinCNN / ProteinResNet

**Type:** Convolutional networks for sequences

**Characteristics:**
- 1D convolutions over sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection

**Use Cases:**
- Binding site prediction
- Secondary structure prediction
- Domain identification
### ProteinLSTM

**Type:** Recurrent network for sequences

**Characteristics:**
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks

**Use Cases:**
- Order prediction
- Sequential annotation
- Time-series protein data
## Knowledge Graph Models

### TransE (Translation Embedding)

**Type:** Translation-based embedding

**Paper:** Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)

**Characteristics:**
- Models triples as translations: h + r ≈ t (head + relation ≈ tail), scored by -‖h + r - t‖
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient

**Best For:**
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Embedding dimension (typically 50-500)
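
An end-to-end knowledge graph completion sketch in TorchDrug's style; the dataset (FB15k-237) and all hyperparameters are illustrative placeholders:

```python
import torch
from torchdrug import core, datasets, models, tasks

# Illustrative link-prediction pipeline with TransE embeddings.
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

model = models.TransE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=512,
)
task = tasks.KnowledgeGraphCompletion(model, num_negative=128)
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=200)
solver.evaluate("valid")
```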
### RotatE (Rotation Embedding)

**Type:** Rotation in complex space

**Paper:** RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)

**Characteristics:**
- Models relations as rotations in complex space
- Handles symmetric, antisymmetric, inverse, and composition patterns
- State-of-the-art on many benchmarks

**Best For:**
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Must be even (complex embeddings)
- `max_score`: Score clipping value
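
A minimal sketch; the entity and relation counts shown are illustrative and would normally come from the dataset:

```python
from torchdrug import models

# RotatE sketch; embedding_dim must be even because each entity embedding
# is interpreted as embedding_dim / 2 complex numbers.
model = models.RotatE(
    num_entity=14541,      # illustrative; use dataset.num_entity
    num_relation=237,      # illustrative; use dataset.num_relation
    embedding_dim=1024,    # even, as required
    max_score=9,
)
```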
### DistMult

**Type:** Bilinear model

**Characteristics:**
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations

**Best For:**
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs
### ComplEx

**Type:** Complex-valued embeddings

**Characteristics:**
- Handles both asymmetric and symmetric relations
- Outperforms DistMult on most graphs
- Good balance of expressiveness and efficiency

**Best For:**
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex
### SimplE

**Type:** Enhanced embedding model

**Characteristics:**
- Two embeddings per entity (canonical + inverse)
- Fully expressive
- Slightly more parameters than ComplEx

**Best For:**
- When full expressiveness is needed
- When inverse relations are important
## Generative Models

### GraphAutoregressiveFlow

**Type:** Normalizing flow for molecules

**Characteristics:**
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial objective)
- Supports conditional generation

**Best For:**
- Molecular generation
- Density estimation
- Interpolation between molecules

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Coupling layers
- `num_flow`: Number of flow transformations

**Use Cases:**
- De novo drug design
- Chemical space exploration
- Property-targeted generation
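
A sketch using only the parameters documented above; depending on your TorchDrug version the constructor may additionally expect an encoder model and a prior distribution, so treat this as an outline and check the generation tutorial for your release:

```python
from torchdrug import models

# Outline only -- constructor arguments vary across TorchDrug versions.
flow = models.GraphAutoregressiveFlow(
    input_dim=43,                # illustrative atom feature dimension
    hidden_dims=[256, 256],      # coupling-layer widths
    num_flow=12,                 # number of invertible flow transformations
)
```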
## Pre-training Models

### InfoGraph

**Type:** Contrastive learning

**Characteristics:**
- Maximizes mutual information
- Contrasts graph-level and node-level representations
- Unsupervised pre-training
- Good for small datasets

**Use Cases:**
- Pre-train molecular encoders
- Few-shot learning
- Transfer learning
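
A pre-training sketch: InfoGraph wraps an existing encoder (here a GIN) and trains it by contrasting graph-level and node-level representations (argument names may vary by version; dimensions are illustrative):

```python
from torchdrug import models

# InfoGraph wrapper around a GIN encoder for unsupervised pre-training.
gin = models.GIN(input_dim=69, hidden_dims=[256, 256, 256], batch_norm=True)
model = models.InfoGraph(gin, separate_model=False)
```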
### MultiviewContrast

**Type:** Multi-view contrastive learning for proteins

**Characteristics:**
- Contrasts different views of the same protein
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models

**Use Cases:**
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios
## Model Selection Guide

### By Task Type

**Molecular Property Prediction:**
1. GIN (first choice)
2. GAT (interpretability)
3. SchNet (3D coordinates available)

**Protein Tasks:**
1. ESM (sequence only)
2. GearNet (structure available)
3. ProteinBERT (sequence, lighter than ESM)

**Knowledge Graphs:**
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)

**Molecular Generation:**
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)

**Retrosynthesis:**
1. GIN (synthon completion)
2. RGCN (center identification with bond types)
### By Dataset Size

**Small (< 1k samples):**
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization

**Medium (1k-100k):**
- GIN for molecules
- GAT for interpretability
- Standard training

**Large (> 100k):**
- Any model works
- Deeper architectures
- Can train from scratch
### By Computational Budget

**Low:**
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM

**Medium:**
- GIN
- GAT
- ComplEx

**High:**
- ESM (large variants)
- SchNet (3D)
- RotatE with a high embedding dimension
## Implementation Tips

1. **Start Simple**: Begin with a GCN or GIN baseline
2. **Use Pre-trained**: ESM for proteins, InfoGraph for molecules
3. **Tune Depth**: 3-5 layers are typically sufficient
4. **Batch Normalization**: Usually helps (except for KG embeddings)
5. **Residual Connections**: Important for deep networks
6. **Readout Function**: "mean" usually works well
7. **Edge Features**: Include them when available (bonds, distances)
8. **Regularization**: Dropout, weight decay, early stopping
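
As a closing sketch, here is how several of these tips map onto a single TorchDrug encoder (the residual-connection flag is spelled `short_cut` in `torchdrug.models`; dimensions are illustrative):

```python
from torchdrug import models

# Tips applied together: moderate depth, batch norm, residual
# connections, mean readout, and edge features when available.
model = models.GIN(
    input_dim=69,            # illustrative; use dataset.node_feature_dim
    hidden_dims=[256] * 4,   # 3-5 layers is typically sufficient
    edge_input_dim=12,       # illustrative; use dataset.edge_feature_dim
    short_cut=True,          # residual connections
    batch_norm=True,
    readout="mean",
)
```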