# Models and Architectures
## Overview
TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.
## Graph Neural Networks
### GCN (Graph Convolutional Network)
**Type:** Spatial message passing
**Paper:** Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)
**Characteristics:**
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks
**Best For:**
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure
**Parameters:**
- `input_dim`: Node feature dimension
- `hidden_dims`: List of hidden layer dimensions
- `edge_input_dim`: Edge feature dimension (optional)
- `batch_norm`: Apply batch normalization
- `activation`: Activation function (relu, elu, etc.)
- `dropout`: Dropout rate
**Use Cases:**
- Molecular property prediction
- Citation network classification
- Social network analysis
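A minimal construction sketch, assuming TorchDrug's `models.GCN` constructor; the dimensions are placeholders (in practice, take `input_dim` from `dataset.node_feature_dim`):
```python
from torchdrug import models

gcn = models.GCN(
    input_dim=69,                 # placeholder node feature dimension
    hidden_dims=[256, 256, 256],  # three convolution layers
    batch_norm=True,
    activation="relu",
)
# forward(graph, node_feature) returns a dict with "node_feature"
# (per-node embeddings) and "graph_feature" (pooled readout).
```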
### GAT (Graph Attention Network)
**Type:** Attention-based message passing
**Paper:** Graph Attention Networks (Veličković et al., 2018)
**Characteristics:**
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally
**Best For:**
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions
**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `num_head`: Number of attention heads
- `negative_slope`: LeakyReLU slope
- `concat`: Concatenate or average multi-head outputs
**Use Cases:**
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance
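A hedged sketch assuming TorchDrug's `models.GAT`, where the head count is passed as `num_head` (dimensions are placeholders):
```python
from torchdrug import models

gat = models.GAT(
    input_dim=69,            # placeholder node feature dimension
    hidden_dims=[256, 256],
    num_head=4,              # multi-head attention for robustness
    negative_slope=0.2,      # LeakyReLU slope in attention scores
)
```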
### GIN (Graph Isomorphism Network)
**Type:** Maximally powerful message passing
**Paper:** How Powerful are Graph Neural Networks? (Xu et al., 2019)
**Characteristics:**
- As expressive as the 1-WL isomorphism test, the theoretical upper bound for message-passing GNNs
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often best performance on molecular tasks
**Best For:**
- Molecular property prediction (strong first choice)
- Tasks requiring structural discrimination
- Graph classification
**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `edge_input_dim`: Include edge features
- `batch_norm`: Typically set to `True`
- `readout`: Graph pooling ("sum", "mean", "max")
- `eps`: Learnable or fixed epsilon
**Use Cases:**
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
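A minimal sketch assuming TorchDrug's `models.GIN`; the atom and bond feature dimensions are placeholders:
```python
from torchdrug import models

gin = models.GIN(
    input_dim=69,                      # placeholder atom feature dimension
    hidden_dims=[256, 256, 256, 256],
    edge_input_dim=18,                 # placeholder bond feature dimension
    batch_norm=True,
    readout="mean",                    # graph-level pooling
)
```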
### RGCN (Relational Graph Convolutional Network)
**Type:** Multi-relational message passing
**Paper:** Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)
**Characteristics:**
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs
**Best For:**
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data
**Parameters:**
- `num_relation`: Number of relation types
- `hidden_dims`: Layer dimensions
- `num_bases`: Number of bases for the basis decomposition (reduces parameter count)
**Use Cases:**
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks
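A minimal sketch assuming TorchDrug's `models.RGCN`; the four relation types here stand in for single/double/triple/aromatic bonds:
```python
from torchdrug import models

rgcn = models.RGCN(
    input_dim=69,                 # placeholder node feature dimension
    hidden_dims=[256, 256, 256],
    num_relation=4,               # one weight matrix per bond type
    batch_norm=True,
)
```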
### MPNN (Message Passing Neural Network)
**Type:** General message passing framework
**Paper:** Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)
**Characteristics:**
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for graph representation
**Best For:**
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations
**Parameters:**
- `input_dim`, `hidden_dim`: Feature dimensions
- `edge_input_dim`: Edge feature dimension
- `num_layer`: Message passing iterations
- `num_mlp_layer`: MLP layers in message function
**Use Cases:**
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks
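A sketch assuming TorchDrug's `models.MPNN`, where the edge feature dimension is required (all values are placeholders):
```python
from torchdrug import models

mpnn = models.MPNN(
    input_dim=69,        # placeholder atom feature dimension
    hidden_dim=256,
    edge_input_dim=18,   # placeholder bond feature dimension
    num_layer=3,         # message passing iterations
    num_mlp_layer=2,     # MLP depth in the message function
)
```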
### SchNet (Continuous-Filter Convolutional Network)
**Type:** 3D geometry-aware convolution
**Paper:** SchNet: A Continuous-Filter Convolutional Neural Network for Modeling Quantum Interactions (Schütt et al., 2017)
**Characteristics:**
- Operates on 3D atomic coordinates
- Continuous filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry
**Best For:**
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction
**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Layer dimensions
- `num_gaussian`: RBF basis functions for distances
- `cutoff`: Interaction cutoff distance
**Use Cases:**
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction
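A sketch assuming TorchDrug's `models.SchNet`; the cutoff is in angstroms and both values below are placeholders:
```python
from torchdrug import models

schnet = models.SchNet(
    input_dim=69,                 # placeholder atom feature dimension
    hidden_dims=[256, 256, 256],
    num_gaussian=100,             # RBF expansion of interatomic distances
    cutoff=5.0,                   # interaction cutoff in angstroms
)
```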
### ChebNet (Chebyshev Spectral CNN)
**Type:** Spectral convolution
**Paper:** Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)
**Characteristics:**
- Spectral graph convolution
- Chebyshev polynomial approximation
- Captures global graph structure
- Computationally efficient
**Best For:**
- Tasks requiring global information
- When graph Laplacian is informative
- Theoretical analysis
**Parameters:**
- `input_dim`, `hidden_dims`: Dimensions
- `num_cheb`: Order of Chebyshev polynomial
**Use Cases:**
- Citation network classification
- Brain network analysis
- Signal processing on graphs
### NFP (Neural Fingerprint)
**Type:** Molecular fingerprint learning
**Paper:** Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)
**Characteristics:**
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features
**Best For:**
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important
**Parameters:**
- `input_dim`, `output_dim`: Feature dimensions
- `hidden_dims`: Layer dimensions
- `num_layer`: Circular convolution depth
**Use Cases:**
- Virtual screening
- Molecular similarity search
- QSAR modeling
## Protein-Specific Models
### GearNet (Geometry-Aware Relational Graph Network)
**Type:** Protein structure encoder
**Paper:** Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)
**Characteristics:**
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks
**Best For:**
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures
**Parameters:**
- `input_dim`: Residue features
- `hidden_dims`: Layer dimensions
- `num_relation`: Edge types (sequence, radius, KNN)
- `edge_input_dim`: Geometric features (distances, angles)
- `batch_norm`: Typically set to `True`
**Use Cases:**
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification
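A sketch assuming TorchDrug's `models.GearNet`; `num_relation=7` follows the common GearNet edge-type setup (sequential offsets plus radius and KNN edges), and the input dimension is a placeholder:
```python
from torchdrug import models

gearnet = models.GearNet(
    input_dim=21,                  # placeholder residue feature dimension
    hidden_dims=[512, 512, 512],
    num_relation=7,                # sequential + spatial + KNN edge types
    batch_norm=True,
)
```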
### ESM (Evolutionary Scale Modeling)
**Type:** Protein language model (transformer)
**Paper:** Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)
**Characteristics:**
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks
**Best For:**
- Any sequence-based protein task
- When no structure available
- Transfer learning with limited data
**Variants:**
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)
**Use Cases:**
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)
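A sketch assuming TorchDrug's `models.ESM` wrapper, which downloads pre-trained weights into `path` on first use (the path and variant name below are placeholders):
```python
from torchdrug import models

# Loads pre-trained ESM-1b weights; freeze or fine-tune downstream.
esm = models.ESM(path="~/protein-models/", model="ESM-1b")
```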
### ProteinBERT
**Type:** Masked language model for proteins
**Characteristics:**
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks
**Use Cases:**
- Function annotation
- Subcellular localization
- Stability prediction
### ProteinCNN / ProteinResNet
**Type:** Convolutional networks for sequences
**Characteristics:**
- 1D convolutions on sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection
**Use Cases:**
- Binding site prediction
- Secondary structure prediction
- Domain identification
### ProteinLSTM
**Type:** Recurrent network for sequences
**Characteristics:**
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks
**Use Cases:**
- Order prediction
- Sequential annotation
- Time-series protein data
## Knowledge Graph Models
### TransE (Translation Embedding)
**Type:** Translation-based embedding
**Paper:** Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)
**Characteristics:**
- h + r ≈ t (head + relation ≈ tail)
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient
**Best For:**
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings
**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Embedding dimensions (typically 50-500)
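A sketch assuming TorchDrug's `models.TransE`; the entity and relation counts below match FB15k-237 but are placeholders for your own graph:
```python
from torchdrug import models

transe = models.TransE(
    num_entity=14541,     # placeholder: FB15k-237 entity count
    num_relation=237,     # placeholder: FB15k-237 relation count
    embedding_dim=256,
)
```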
### RotatE (Rotation Embedding)
**Type:** Rotation in complex space
**Paper:** RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)
**Characteristics:**
- Relations as rotations in complex space
- Models symmetric, antisymmetric, inverse, and composition relation patterns
- State-of-the-art on many benchmarks
**Best For:**
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical
**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Must be even (complex embeddings)
- `max_score`: Score clipping value
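A sketch under the same assumptions as the TransE example; note the even embedding dimension for the real and imaginary halves:
```python
from torchdrug import models

rotate = models.RotatE(
    num_entity=14541,     # placeholder entity count
    num_relation=237,     # placeholder relation count
    embedding_dim=256,    # even: split into real and imaginary parts
    max_score=12,         # score clipping
)
```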
### DistMult
**Type:** Bilinear model
**Characteristics:**
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations
**Best For:**
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs
### ComplEx
**Type:** Complex-valued embeddings
**Characteristics:**
- Handles asymmetric and symmetric relations
- Better than DistMult for most graphs
- Good balance of expressiveness and efficiency
**Best For:**
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex
### SimplE
**Type:** Enhanced embedding model
**Characteristics:**
- Two embeddings per entity (canonical + inverse)
- Fully expressive
- Slightly more parameters than ComplEx
**Best For:**
- When full expressiveness needed
- Inverse relations are important
## Generative Models
### GraphAutoregressiveFlow
**Type:** Normalizing flow for molecules
**Characteristics:**
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial)
- Conditional generation support
**Best For:**
- Molecular generation
- Density estimation
- Interpolation between molecules
**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Coupling layers
- `num_flow`: Number of flow transformations
**Use Cases:**
- De novo drug design
- Chemical space exploration
- Property-targeted generation
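A hedged training sketch in TorchDrug's style, where GraphAF is trained through `tasks.AutoregressiveGeneration` around an RGCN encoder; the dataset, layer sizes, and unroll limits follow the library's generation tutorial and should be treated as assumptions:
```python
from torchdrug import datasets, models, tasks

# Molecules are kekulized so bond types map to discrete relations.
dataset = datasets.ZINC250k("~/molecule-datasets/",
                            kekulize=True, atom_feature="symbol")
model = models.RGCN(input_dim=dataset.node_feature_dim,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256], batch_norm=True)
task = tasks.AutoregressiveGeneration(model, max_node=38,
                                      max_edge_unroll=12, criterion="nll")
```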
## Pre-training Models
### InfoGraph
**Type:** Contrastive learning
**Characteristics:**
- Maximizes mutual information
- Graph-level and node-level contrast
- Unsupervised pre-training
- Good for small datasets
**Use Cases:**
- Pre-train molecular encoders
- Few-shot learning
- Transfer learning
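A sketch assuming TorchDrug's `models.InfoGraph` wrapper around a GIN encoder (dimensions are placeholders):
```python
from torchdrug import models

gin = models.GIN(input_dim=69, hidden_dims=[256, 256, 256, 256],
                 batch_norm=True, readout="mean")
# Adds the mutual-information objective on top of the wrapped encoder.
infograph = models.InfoGraph(gin)
```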
### MultiviewContrast
**Type:** Multi-view contrastive learning for proteins
**Characteristics:**
- Contrasts different views of proteins
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models
**Use Cases:**
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios
## Model Selection Guide
### By Task Type
**Molecular Property Prediction:**
1. GIN (first choice)
2. GAT (interpretability)
3. SchNet (3D available)
**Protein Tasks:**
1. ESM (sequence only)
2. GearNet (structure available)
3. ProteinBERT (sequence, lighter than ESM)
**Knowledge Graphs:**
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)
**Molecular Generation:**
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)
**Retrosynthesis:**
1. GIN (synthon completion)
2. RGCN (center identification with bond types)
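Whichever encoder the guide points to, the wiring into a task and training engine is the same. A hedged sketch for property prediction, assuming TorchDrug's `tasks.PropertyPrediction` and `core.Engine` (dataset choice, split ratios, and hyperparameters are placeholders):
```python
import torch
from torchdrug import core, datasets, models, tasks

dataset = datasets.BBBP("~/molecule-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256, 256], batch_norm=True)
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=128)
solver.train(num_epoch=10)
solver.evaluate("valid")
```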
### By Dataset Size
**Small (< 1k):**
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization
**Medium (1k-100k):**
- GIN for molecules
- GAT for interpretability
- Standard training
**Large (> 100k):**
- Any model works
- Deeper architectures
- Can train from scratch
### By Computational Budget
**Low:**
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM
**Medium:**
- GIN
- GAT
- ComplEx
**High:**
- ESM (large)
- SchNet (3D)
- RotatE with high dim
## Implementation Tips
1. **Start Simple**: Begin with GCN or GIN baseline
2. **Use Pre-trained**: ESM for proteins, InfoGraph for molecules
3. **Tune Depth**: 3-5 layers typically sufficient
4. **Batch Normalization**: Usually helps (except KG embeddings)
5. **Residual Connections**: Important for deep networks
6. **Readout Function**: "mean" usually works well
7. **Edge Features**: Include when available (bonds, distances)
8. **Regularization**: Dropout, weight decay, early stopping
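Several of these tips map directly onto constructor flags; a sketch assuming TorchDrug's GNN constructors, where `short_cut` is the residual-connection switch and all dimensions are placeholders:
```python
from torchdrug import models

model = models.GIN(
    input_dim=69,                      # placeholder feature dimension
    hidden_dims=[256, 256, 256, 256],  # tip 3: 3-5 layers
    edge_input_dim=18,                 # tip 7: include bond features
    batch_norm=True,                   # tip 4
    short_cut=True,                    # tip 5: residual connections
    readout="mean",                    # tip 6
)
```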