# Models and Architectures

## Overview

TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.

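Most models in this reference plug into the same dataset / model / task / engine pipeline. The sketch below shows that pattern end to end, following TorchDrug's quick-start recipe; the dataset path, layer widths, and training hyperparameters are illustrative, not prescriptive.

```python
import torch
from torchdrug import core, datasets, models, tasks

# Small molecular toxicity dataset; downloads on first use.
dataset = datasets.ClinTox("~/molecule-datasets/")
train_set, valid_set, test_set = dataset.split()

# Any graph encoder from this reference can be swapped in here.
model = models.GCN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256],
                   batch_norm=True)

# The task wraps the encoder with a loss and evaluation metrics.
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=256)
solver.train(num_epoch=10)
solver.evaluate("valid")
```
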
## Graph Neural Networks

### GCN (Graph Convolutional Network)

**Type:** Spatial message passing
**Paper:** Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)

**Characteristics:**
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks

**Best For:**
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure

**Parameters:**
- `input_dim`: Node feature dimension
- `hidden_dims`: List of hidden layer dimensions
- `edge_input_dim`: Edge feature dimension (optional)
- `batch_norm`: Apply batch normalization
- `activation`: Activation function (relu, elu, etc.)
- `dropout`: Dropout rate

**Use Cases:**
- Molecular property prediction
- Citation network classification
- Social network analysis

### GAT (Graph Attention Network)

**Type:** Attention-based message passing
**Paper:** Graph Attention Networks (Veličković et al., 2018)

**Characteristics:**
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally

**Best For:**
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `num_heads`: Number of attention heads
- `negative_slope`: LeakyReLU slope
- `concat`: Concatenate or average multi-head outputs

**Use Cases:**
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance

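A minimal instantiation sketch. Dimensions are placeholders, and note one assumption: the TorchDrug release we have in mind exposes the head count as `num_head` rather than the `num_heads` listed above, so verify against your installed version.

```python
from torchdrug import models

# Two GAT layers with 4 attention heads each; dims are illustrative.
gat = models.GAT(input_dim=67,
                 hidden_dims=[256, 256],
                 num_head=4,            # may be `num_heads` in other versions
                 negative_slope=0.2,
                 batch_norm=True)
```
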
### GIN (Graph Isomorphism Network)

**Type:** Maximally powerful message passing
**Paper:** How Powerful are Graph Neural Networks? (Xu et al., 2019)

**Characteristics:**
- Theoretically the most expressive message-passing GNN (matches the 1-WL isomorphism test)
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often best performance on molecular tasks

**Best For:**
- Molecular property prediction (state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `edge_input_dim`: Include edge features
- `batch_norm`: Usually set to true
- `readout`: Graph pooling ("sum", "mean", "max")
- `eps`: Learnable or fixed epsilon

**Use Cases:**
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction

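A sketch with edge features enabled. The feature dimensions are placeholders; in practice they usually come from `dataset.node_feature_dim` and `dataset.edge_feature_dim`.

```python
from torchdrug import models

# Four GIN layers; learn_eps makes the epsilon in the update learnable.
gin = models.GIN(input_dim=67,
                 hidden_dims=[256, 256, 256, 256],
                 edge_input_dim=19,     # bond features, if available
                 learn_eps=True,
                 batch_norm=True,
                 readout="sum")
```
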
### RGCN (Relational Graph Convolutional Network)

**Type:** Multi-relational message passing
**Paper:** Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)

**Characteristics:**
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs

**Best For:**
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data

**Parameters:**
- `num_relation`: Number of relation types
- `hidden_dims`: Layer dimensions
- `num_bases`: Basis decomposition (reduce parameters)

**Use Cases:**
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks

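A minimal sketch for a molecular graph with four bond types. Dimensions are illustrative, and one caveat: basis decomposition is exposed differently across TorchDrug versions, so the `num_bases` option is omitted here.

```python
from torchdrug import models

# One relation-specific transform per bond type (4 relations here).
rgcn = models.RGCN(input_dim=67,
                   hidden_dims=[256, 256, 256],
                   num_relation=4,
                   batch_norm=True)
```
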
### MPNN (Message Passing Neural Network)

**Type:** General message passing framework
**Paper:** Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)

**Characteristics:**
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for graph representation

**Best For:**
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations

**Parameters:**
- `input_dim`, `hidden_dim`: Feature dimensions
- `edge_input_dim`: Edge feature dimension
- `num_layer`: Message passing iterations
- `num_mlp_layer`: MLP layers in message function

**Use Cases:**
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks

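A minimal sketch, assuming TorchDrug's `models.MPNN` takes a single `hidden_dim` (node states are updated in place by a GRU across iterations) and requires edge features; all dimensions are placeholders.

```python
from torchdrug import models

# 3 rounds of message passing over 19-dim edge features (placeholders).
mpnn = models.MPNN(input_dim=67,
                   hidden_dim=256,
                   edge_input_dim=19,
                   num_layer=3,
                   num_mlp_layer=2)
```
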
### SchNet (Continuous-Filter Convolutional Network)

**Type:** 3D geometry-aware convolution
**Paper:** SchNet: A Continuous-Filter Convolutional Neural Network for Modeling Quantum Interactions (Schütt et al., 2017)

**Characteristics:**
- Operates on 3D atomic coordinates
- Continuous filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry

**Best For:**
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Layer dimensions
- `num_gaussian`: RBF basis functions for distances
- `cutoff`: Interaction cutoff distance

**Use Cases:**
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction

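A minimal sketch; the atom feature dimension and cutoff are illustrative. SchNet reads 3D coordinates from the input graphs, so the dataset must carry atom positions.

```python
from torchdrug import models

# Interatomic distances are expanded in 100 Gaussian RBFs up to a
# 5-angstrom cutoff; dims are placeholders.
schnet = models.SchNet(input_dim=18,
                       hidden_dims=[256, 256, 256],
                       num_gaussian=100,
                       cutoff=5.0)
```
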
### ChebNet (Chebyshev Spectral CNN)

**Type:** Spectral convolution
**Paper:** Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)

**Characteristics:**
- Spectral graph convolution
- Chebyshev polynomial approximation
- Captures global graph structure
- Computationally efficient

**Best For:**
- Tasks requiring global information
- When graph Laplacian is informative
- Theoretical analysis

**Parameters:**
- `input_dim`, `hidden_dims`: Dimensions
- `num_cheb`: Order of Chebyshev polynomial

**Use Cases:**
- Citation network classification
- Brain network analysis
- Signal processing on graphs

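A minimal sketch, with two assumptions to verify locally: in the TorchDrug releases we are aware of, the class is `ChebyshevConvolutionalNetwork`, and the polynomial order argument may be `k` rather than the `num_cheb` listed above.

```python
from torchdrug import models

# Second-order Chebyshev filters over 67-dim node features (placeholders).
chebnet = models.ChebyshevConvolutionalNetwork(input_dim=67,
                                               hidden_dims=[256, 256],
                                               k=2)
```
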
### NFP (Neural Fingerprint)

**Type:** Molecular fingerprint learning
**Paper:** Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)

**Characteristics:**
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features

**Best For:**
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important

**Parameters:**
- `input_dim`, `output_dim`: Feature dimensions
- `hidden_dims`: Layer dimensions
- `num_layer`: Circular convolution depth

**Use Cases:**
- Virtual screening
- Molecular similarity search
- QSAR modeling

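A minimal sketch, with `output_dim` playing the role of the fingerprint length; all values are illustrative.

```python
from torchdrug import models

# Learned 1024-dim fingerprint from two circular convolution layers.
nfp = models.NeuralFingerprint(input_dim=67,
                               output_dim=1024,
                               hidden_dims=[256, 256])
```
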
## Protein-Specific Models

### GearNet (Geometry-Aware Relational Graph Network)

**Type:** Protein structure encoder
**Paper:** Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)

**Characteristics:**
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks

**Best For:**
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures

**Parameters:**
- `input_dim`: Residue features
- `hidden_dims`: Layer dimensions
- `num_relation`: Edge types (sequence, radius, KNN)
- `edge_input_dim`: Geometric features (distances, angles)
- `batch_norm`: Typically true

**Use Cases:**
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification

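A sketch following the configuration published with the paper (21-dim residue features, 7 relation types, 59-dim geometric edge features); treat these as assumptions to check against your setup. Residue graphs are typically built beforehand with TorchDrug's graph-construction layers (sequential, spatial, and KNN edges).

```python
from torchdrug import models

# 6-layer GearNet as used for geometric structure pretraining.
gearnet = models.GearNet(input_dim=21,
                         hidden_dims=[512, 512, 512, 512, 512, 512],
                         num_relation=7,
                         edge_input_dim=59,
                         batch_norm=True,
                         short_cut=True,
                         concat_hidden=True,
                         readout="sum")
```
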
### ESM (Evolutionary Scale Modeling)

**Type:** Protein language model (transformer)
**Paper:** Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)

**Characteristics:**
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks

**Best For:**
- Any sequence-based protein task
- When no structure is available
- Transfer learning with limited data

**Variants:**
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)

**Use Cases:**
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)

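A minimal sketch, assuming TorchDrug's `models.ESM` wrapper; the pre-trained weights download into `path` on first use, and the encoder can be frozen or fine-tuned downstream.

```python
from torchdrug import models

# Pre-trained ESM-1b encoder with mean pooling over residues.
esm = models.ESM(path="~/protein-model-weights/",
                 model="ESM-1b",
                 readout="mean")
```
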
### ProteinBERT

**Type:** Masked language model for proteins

**Characteristics:**
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks

**Use Cases:**
- Function annotation
- Subcellular localization
- Stability prediction

### ProteinCNN / ProteinResNet

**Type:** Convolutional networks for sequences

**Characteristics:**
- 1D convolutions on sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection

**Use Cases:**
- Binding site prediction
- Secondary structure prediction
- Domain identification

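A minimal sketch, assuming `models.ProteinCNN` with one-hot residue inputs (21 amino-acid types, hence `input_dim=21`); kernel size and layer widths are illustrative.

```python
from torchdrug import models

# Two 1D convolution layers over the residue sequence.
cnn = models.ProteinCNN(input_dim=21,
                        hidden_dims=[1024, 1024],
                        kernel_size=5,
                        padding=2)
```
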
### ProteinLSTM

**Type:** Recurrent network for sequences

**Characteristics:**
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks

**Use Cases:**
- Order prediction
- Sequential annotation
- Time-series protein data

## Knowledge Graph Models

### TransE (Translation Embedding)

**Type:** Translation-based embedding
**Paper:** Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)

**Characteristics:**
- h + r ≈ t (head + relation ≈ tail)
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient

**Best For:**
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Embedding dimensions (typically 50-500)

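An end-to-end link-prediction sketch on FB15k-237, following the standard TorchDrug knowledge-graph recipe; the learning rate, batch size, and negative-sampling settings are illustrative.

```python
import torch
from torchdrug import core, datasets, models, tasks

# FB15k-237 link-prediction benchmark; downloads on first use.
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

model = models.TransE(num_entity=dataset.num_entity,
                      num_relation=dataset.num_relation,
                      embedding_dim=512)

# Negative sampling with self-adversarial weighting.
task = tasks.KnowledgeGraphCompletion(model, num_negative=128,
                                      adversarial_temperature=2)
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=5)
solver.evaluate("valid")
```
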
### RotatE (Rotation Embedding)

**Type:** Rotation in complex space
**Paper:** RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)

**Characteristics:**
- Relations as rotations in complex space
- Handles symmetric, antisymmetric, inverse, and composition patterns
- State-of-the-art on many benchmarks

**Best For:**
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Must be even (complex embeddings)
- `max_score`: Score clipping value

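Swapping RotatE into the TransE pipeline above only changes the model line; the even `embedding_dim` is split into real and imaginary halves. A sketch, reusing `dataset` from that example:

```python
from torchdrug import models

# Complex-valued embeddings: embedding_dim must be even.
rotate = models.RotatE(num_entity=dataset.num_entity,
                       num_relation=dataset.num_relation,
                       embedding_dim=1024,
                       max_score=9)
```
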
### DistMult

**Type:** Bilinear model

**Characteristics:**
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations

**Best For:**
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs

### ComplEx

**Type:** Complex-valued embeddings

**Characteristics:**
- Handles asymmetric and symmetric relations
- Better than DistMult for most graphs
- Good balance of expressiveness and efficiency

**Best For:**
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex

### SimplE

**Type:** Enhanced embedding model

**Characteristics:**
- Two embeddings per entity (canonical + inverse)
- Fully expressive
- Slightly more parameters than ComplEx

**Best For:**
- When full expressiveness is needed
- When inverse relations are important

## Generative Models

### GraphAutoregressiveFlow

**Type:** Normalizing flow for molecules

**Characteristics:**
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial objective)
- Conditional generation support

**Best For:**
- Molecular generation
- Density estimation
- Interpolation between molecules

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Coupling layers
- `num_flow`: Number of flow transformations

**Use Cases:**
- De novo drug design
- Chemical space exploration
- Property-targeted generation

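A sketch in the spirit of the TorchDrug molecule-generation tutorial: an RGCN encoder wrapped by node- and edge-level flows, trained with an autoregressive generation task. Argument names are assumptions to verify locally (e.g., `num_flow_layer`, and `atom_feature` was `node_feature` in older releases).

```python
from torchdrug import datasets, models, tasks

# ZINC250k with kekulized molecules and symbol-only atom features.
dataset = datasets.ZINC250k("~/molecule-datasets/",
                            kekulize=True, atom_feature="symbol")

rgcn = models.RGCN(input_dim=dataset.num_atom_type,
                   num_relation=dataset.num_bond_type,
                   hidden_dims=[256, 256, 256], batch_norm=True)

# Separate flows generate atoms and bonds autoregressively.
node_flow = models.GraphAF(rgcn, num_flow_layer=12)
edge_flow = models.GraphAF(rgcn, num_flow_layer=12, use_edge=True)
task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                      max_node=38, max_edge_unroll=12,
                                      criterion="nll")
```
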
## Pre-training Models

### InfoGraph

**Type:** Contrastive learning

**Characteristics:**
- Maximizes mutual information
- Graph-level and node-level contrast
- Unsupervised pre-training
- Good for small datasets

**Use Cases:**
- Pre-train molecular encoders
- Few-shot learning
- Transfer learning

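A sketch of unsupervised pre-training with InfoGraph wrapping a GIN encoder; the feature dimension is a placeholder, and `tasks.Unsupervised` drives the contrastive objective over an unlabeled molecule dataset.

```python
from torchdrug import models, tasks

# Backbone encoder to be pre-trained.
gin = models.GIN(input_dim=67,
                 hidden_dims=[256, 256, 256, 256],
                 batch_norm=True, readout="sum")

# InfoGraph adds the node/graph mutual-information objective.
model = models.InfoGraph(gin, separate_model=False)
task = tasks.Unsupervised(model)
```
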
### MultiviewContrast

**Type:** Multi-view contrastive learning for proteins

**Characteristics:**
- Contrasts different views of proteins
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models

**Use Cases:**
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios

## Model Selection Guide

### By Task Type

**Molecular Property Prediction:**
1. GIN (first choice)
2. GAT (interpretability)
3. SchNet (3D available)

**Protein Tasks:**
1. ESM (sequence only)
2. GearNet (structure available)
3. ProteinBERT (sequence, lighter than ESM)

**Knowledge Graphs:**
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)

**Molecular Generation:**
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)

**Retrosynthesis:**
1. GIN (synthon completion)
2. RGCN (center identification with bond types)

### By Dataset Size

**Small (< 1k):**
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization

**Medium (1k-100k):**
- GIN for molecules
- GAT for interpretability
- Standard training

**Large (> 100k):**
- Any model works
- Deeper architectures
- Can train from scratch

### By Computational Budget

**Low:**
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM

**Medium:**
- GIN
- GAT
- ComplEx

**High:**
- ESM (large)
- SchNet (3D)
- RotatE with high dim

## Implementation Tips

1. **Start Simple**: Begin with a GCN or GIN baseline
2. **Use Pre-trained**: ESM for proteins, InfoGraph for molecules
3. **Tune Depth**: 3-5 layers typically sufficient
4. **Batch Normalization**: Usually helps (except KG embeddings)
5. **Residual Connections**: Important for deep networks
6. **Readout Function**: "mean" usually works well
7. **Edge Features**: Include when available (bonds, distances)
8. **Regularization**: Dropout, weight decay, early stopping

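Putting several of these tips together, a reasonable first baseline for a molecular task might look like the following sketch (the input dimension is a placeholder):

```python
from torchdrug import models

# Tips 1, 3, 4, 6: a simple GIN baseline, 4 layers, batch norm, mean readout.
baseline = models.GIN(input_dim=67,
                      hidden_dims=[256, 256, 256, 256],
                      batch_norm=True,
                      readout="mean")
```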