# Models and Architectures

## Overview

TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.

## Graph Neural Networks

### GCN (Graph Convolutional Network)

**Type:** Spatial message passing

**Paper:** Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)

**Characteristics:**
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks

**Best For:**
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure

**Parameters:**
- `input_dim`: Node feature dimension
- `hidden_dims`: List of hidden layer dimensions
- `edge_input_dim`: Edge feature dimension (optional)
- `batch_norm`: Apply batch normalization
- `activation`: Activation function (relu, elu, etc.)
- `dropout`: Dropout rate

**Use Cases:**
- Molecular property prediction
- Citation network classification
- Social network analysis

### GAT (Graph Attention Network)

**Type:** Attention-based message passing

**Paper:** Graph Attention Networks (Veličković et al., 2018)

**Characteristics:**
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally

**Best For:**
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `num_heads`: Number of attention heads
- `negative_slope`: LeakyReLU slope
- `concat`: Concatenate or average multi-head outputs

**Use Cases:**
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance

### GIN (Graph Isomorphism Network)

**Type:** Maximally powerful message passing

**Paper:** How Powerful are Graph Neural Networks? (Xu et al., 2019)

**Characteristics:**
- As expressive as standard message passing allows (matches the 1-WL isomorphism test)
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often the strongest performer on molecular benchmarks

**Best For:**
- Molecular property prediction (strong default; often near state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `edge_input_dim`: Include edge features
- `batch_norm`: Typically set to true
- `readout`: Graph pooling ("sum", "mean", "max")
- `eps`: Learnable or fixed epsilon

**Use Cases:**
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
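The three encoders above share the same constructor pattern. Below is a minimal instantiation sketch, assuming the `torchdrug.models` interface; the feature and hidden dimensions are illustrative, and keyword names follow the parameter lists above (they may differ slightly across TorchDrug versions).

```python
from torchdrug import models

# Illustrative node feature size; in practice use dataset.node_feature_dim.
NODE_DIM = 69

# GCN: simple normalized-adjacency convolution stack.
gcn = models.GCN(input_dim=NODE_DIM, hidden_dims=[256, 256, 256],
                 batch_norm=True, activation="relu")

# GAT: attention-based aggregation over neighbors.
gat = models.GAT(input_dim=NODE_DIM, hidden_dims=[256, 256, 256])

# GIN: injective aggregation with a sum readout for graph-level features.
gin = models.GIN(input_dim=NODE_DIM, hidden_dims=[256, 256, 256, 256],
                 batch_norm=True, readout="sum")
```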
### RGCN (Relational Graph Convolutional Network)

**Type:** Multi-relational message passing

**Paper:** Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)

**Characteristics:**
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs

**Best For:**
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data

**Parameters:**
- `num_relation`: Number of relation types
- `hidden_dims`: Layer dimensions
- `num_bases`: Basis decomposition (reduces parameter count)

**Use Cases:**
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks

### MPNN (Message Passing Neural Network)

**Type:** General message passing framework

**Paper:** Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)

**Characteristics:**
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for graph representation

**Best For:**
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations

**Parameters:**
- `input_dim`, `hidden_dim`: Feature dimensions
- `edge_input_dim`: Edge feature dimension
- `num_layer`: Message passing iterations
- `num_mlp_layer`: MLP layers in message function

**Use Cases:**
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks

### SchNet (Continuous-Filter Convolutional Network)

**Type:** 3D geometry-aware convolution

**Paper:** SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., 2017)

**Characteristics:**
- Operates on 3D atomic coordinates
- Continuous-filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry

**Best For:**
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Layer dimensions
- `num_gaussian`: RBF basis functions for distances
- `cutoff`: Interaction cutoff distance

**Use Cases:**
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction

### ChebNet (Chebyshev Spectral CNN)

**Type:** Spectral convolution

**Paper:** Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)

**Characteristics:**
- Spectral graph convolution
- Chebyshev polynomial approximation
- Captures global graph structure
- Computationally efficient

**Best For:**
- Tasks requiring global information
- When the graph Laplacian is informative
- Theoretical analysis

**Parameters:**
- `input_dim`, `hidden_dims`: Dimensions
- `num_cheb`: Order of the Chebyshev polynomial

**Use Cases:**
- Citation network classification
- Brain network analysis
- Signal processing on graphs

### NFP (Neural Fingerprint)

**Type:** Molecular fingerprint learning

**Paper:** Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)

**Characteristics:**
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features

**Best For:**
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important

**Parameters:**
- `input_dim`, `output_dim`: Feature dimensions
- `hidden_dims`: Layer dimensions
- `num_layer`: Circular convolution depth

**Use Cases:**
- Virtual screening
- Molecular similarity search
- QSAR modeling
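Any of the encoders in this section plugs into the same training loop. The sketch below follows the standard TorchDrug pipeline (dataset, encoder, task wrapper, engine); the dataset choice and hyperparameters are illustrative, not a tuned recipe.

```python
import torch
from torchdrug import core, datasets, models, tasks

# Small binary-classification benchmark; other molecule datasets work the same way.
dataset = datasets.ClinTox("~/molecule-datasets/")
train_set, valid_set, test_set = dataset.split()

# GIN encoder sized from the dataset's node features.
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256, 256],
                   batch_norm=True, readout="mean")

# Wrap the encoder in a property-prediction task (one head per dataset target).
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=128)
solver.train(num_epoch=10)
solver.evaluate("valid")
```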
## Protein-Specific Models

### GearNet (Geometry-Aware Relational Graph Network)

**Type:** Protein structure encoder

**Paper:** Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)

**Characteristics:**
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks

**Best For:**
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures

**Parameters:**
- `input_dim`: Residue features
- `hidden_dims`: Layer dimensions
- `num_relation`: Edge types (sequence, radius, KNN)
- `edge_input_dim`: Geometric features (distances, angles)
- `batch_norm`: Typically true

**Use Cases:**
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification

### ESM (Evolutionary Scale Modeling)

**Type:** Protein language model (transformer)

**Paper:** Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)

**Characteristics:**
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks

**Best For:**
- Any sequence-based protein task
- When no structure is available
- Transfer learning with limited data

**Variants:**
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)

**Use Cases:**
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)

### ProteinBERT

**Type:** Masked language model for proteins

**Characteristics:**
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks

**Use Cases:**
- Function annotation
- Subcellular localization
- Stability prediction

### ProteinCNN / ProteinResNet

**Type:** Convolutional networks for sequences

**Characteristics:**
- 1D convolutions on sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection

**Use Cases:**
- Binding site prediction
- Secondary structure prediction
- Domain identification

### ProteinLSTM

**Type:** Recurrent network for sequences

**Characteristics:**
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks

**Use Cases:**
- Order prediction
- Sequential annotation
- Time-series protein data

## Knowledge Graph Models

### TransE (Translation Embedding)

**Type:** Translation-based embedding

**Paper:** Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)

**Characteristics:**
- h + r ≈ t (head + relation ≈ tail)
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient

**Best For:**
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Embedding dimension (typically 50-500)

### RotatE (Rotation Embedding)

**Type:** Rotation in complex space

**Paper:** RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)

**Characteristics:**
- Relations as rotations in complex space
- Handles symmetric, antisymmetric, inverse, and composition patterns
- State-of-the-art on many benchmarks

**Best For:**
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical

**Parameters:**
- `num_entity`, `num_relation`: Graph size
- `embedding_dim`: Must be even (complex embeddings)
- `max_score`: Score clipping value
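As a usage sketch, TransE and RotatE train through the same knowledge graph completion task. The example below assumes the bundled FB15k-237 dataset and the `tasks.KnowledgeGraphCompletion` wrapper; the embedding size, negative sampling settings, and learning rate are illustrative.

```python
import torch
from torchdrug import core, datasets, models, tasks

# Standard link-prediction benchmark bundled with TorchDrug.
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# RotatE embeddings; models.TransE takes the same two size arguments.
model = models.RotatE(num_entity=dataset.num_entity,
                      num_relation=dataset.num_relation,
                      embedding_dim=1024, max_score=9)

# Link-prediction objective with self-adversarial negative sampling.
task = tasks.KnowledgeGraphCompletion(model, num_negative=256,
                                      adversarial_temperature=1)

optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=100)
solver.evaluate("valid")
```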
### DistMult

**Type:** Bilinear model

**Characteristics:**
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations

**Best For:**
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs

### ComplEx

**Type:** Complex-valued embeddings

**Characteristics:**
- Handles asymmetric and symmetric relations
- Better than DistMult for most graphs
- Good balance of expressiveness and efficiency

**Best For:**
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex

### SimplE

**Type:** Enhanced embedding model

**Characteristics:**
- Two embeddings per entity (head and tail roles)
- Fully expressive
- Slightly more parameters than ComplEx

**Best For:**
- When full expressiveness is needed
- When inverse relations are important

## Generative Models

### GraphAutoregressiveFlow

**Type:** Normalizing flow for molecules

**Characteristics:**
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial objective)
- Conditional generation support

**Best For:**
- Molecular generation
- Density estimation
- Interpolation between molecules

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Coupling layers
- `num_flow`: Number of flow transformations

**Use Cases:**
- De novo drug design
- Chemical space exploration
- Property-targeted generation

## Pre-training Models

### InfoGraph

**Type:** Contrastive learning

**Characteristics:**
- Maximizes mutual information
- Graph-level and node-level contrast
- Unsupervised pre-training
- Good for small datasets

**Use Cases:**
- Pre-train molecular encoders
- Few-shot learning
- Transfer learning

### MultiviewContrast

**Type:** Multi-view contrastive learning for proteins

**Characteristics:**
- Contrasts different views of proteins
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models

**Use Cases:**
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios

## Model Selection Guide

### By Task Type

**Molecular Property Prediction:**
1. GIN (first choice)
2. GAT (interpretability)
3. SchNet (3D available)

**Protein Tasks:**
1. ESM (sequence only)
2. GearNet (structure available)
3. ProteinBERT (sequence, lighter than ESM)

**Knowledge Graphs:**
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)

**Molecular Generation:**
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)

**Retrosynthesis:**
1. GIN (synthon completion)
2. RGCN (center identification with bond types)

### By Dataset Size

**Small (< 1k):**
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization

**Medium (1k-100k):**
- GIN for molecules
- GAT for interpretability
- Standard training

**Large (> 100k):**
- Any model works
- Deeper architectures
- Can train from scratch

### By Computational Budget

**Low:**
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM

**Medium:**
- GIN
- GAT
- ComplEx

**High:**
- ESM (large)
- SchNet (3D)
- RotatE with high embedding dimension
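The selection guide above recommends GearNet whenever protein 3D structures are available. Below is a construction sketch; the keyword names follow the parameter list in the GearNet section, and the residue feature size, layer widths, and number of edge types are assumptions drawn from common configurations rather than fixed values.

```python
from torchdrug import models

# Residue-level structure encoder; all dimensions here are illustrative.
gearnet = models.GearNet(
    input_dim=21,                  # assumption: one-hot residue types
    hidden_dims=[512, 512, 512],   # three relational message-passing layers
    num_relation=7,                # assumption: sequential + spatial + KNN edge types
    batch_norm=True,
)
```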
## Implementation Tips

1. **Start Simple**: Begin with a GCN or GIN baseline
2. **Use Pre-trained**: ESM for proteins, InfoGraph for molecules
3. **Tune Depth**: 3-5 layers typically sufficient
4. **Batch Normalization**: Usually helps (except KG embeddings)
5. **Residual Connections**: Important for deep networks
6. **Readout Function**: "mean" usually works well
7. **Edge Features**: Include when available (bonds, distances)
8. **Regularization**: Dropout, weight decay, early stopping
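A short sketch of several of these tips applied to a single molecule encoder; the feature dimension and hyperparameters are placeholders, and early stopping is assumed to be handled in the training loop rather than in the model.

```python
import torch
from torchdrug import models

# Tips 3, 4, 6, 7: moderate depth, batch normalization, mean readout, optional bond features.
model = models.GIN(input_dim=69,                      # placeholder; use dataset.node_feature_dim
                   hidden_dims=[256, 256, 256, 256],  # 4 layers, within the 3-5 range
                   edge_input_dim=None,               # set when bond features are available
                   batch_norm=True,
                   readout="mean")

# Tip 8: weight decay as regularization; pair with dropout and early stopping on the validation split.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```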