Models and Architectures

Overview

TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.

Graph Neural Networks

GCN (Graph Convolutional Network)

Type: Spatial message passing
Paper: Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)

Characteristics:

  • Simple and efficient aggregation
  • Normalized adjacency matrix convolution
  • Works well for homophilic graphs
  • Good baseline for many tasks

Best For:

  • Initial experiments and baselines
  • When computational efficiency is important
  • Graphs with clear local structure

Parameters:

  • input_dim: Node feature dimension
  • hidden_dims: List of hidden layer dimensions
  • edge_input_dim: Edge feature dimension (optional)
  • batch_norm: Apply batch normalization
  • activation: Activation function (relu, elu, etc.)
  • dropout: Dropout rate
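
For reference, a minimal construction of the GCN encoder. This is a sketch: the ClinTox dataset and layer sizes below are illustrative choices, not requirements.

```python
from torchdrug import datasets, models

# Any molecule dataset provides node_feature_dim; ClinTox is one example.
dataset = datasets.ClinTox("~/molecule-datasets/")

model = models.GCN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256, 256],   # three GCN layers
    batch_norm=True,
    activation="relu",
    readout="mean",                # graph-level pooling
)
```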

Use Cases:

  • Molecular property prediction
  • Citation network classification
  • Social network analysis

GAT (Graph Attention Network)

Type: Attention-based message passing
Paper: Graph Attention Networks (Veličković et al., 2018)

Characteristics:

  • Learns attention weights for neighbors
  • Different importance for different neighbors
  • Multi-head attention for robustness
  • Handles varying node degrees naturally

Best For:

  • When neighbor importance varies
  • Heterogeneous graphs
  • Interpretable predictions

Parameters:

  • input_dim, hidden_dims: Standard dimensions
  • num_head: Number of attention heads
  • negative_slope: LeakyReLU slope
  • concat: Concatenate or average multi-head outputs
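
A hedged sketch of the corresponding constructor call; TorchDrug spells the head-count argument num_head, and the literal dimensions below are placeholders for your dataset's values.

```python
from torchdrug import models

model = models.GAT(
    input_dim=69,              # e.g. dataset.node_feature_dim
    hidden_dims=[256, 256],
    num_head=4,                # multi-head attention
    negative_slope=0.2,        # LeakyReLU slope in attention scores
)
```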

Use Cases:

  • Protein-protein interaction prediction
  • Molecule generation with attention to reactive sites
  • Knowledge graph reasoning with relation importance

GIN (Graph Isomorphism Network)

Type: Maximally powerful message passing
Paper: How Powerful are Graph Neural Networks? (Xu et al., 2019)

Characteristics:

  • As expressive as the 1-WL isomorphism test, the theoretical ceiling for message passing GNNs
  • Injective aggregation function
  • Can distinguish graph structures GCN cannot
  • Often best performance on molecular tasks

Best For:

  • Molecular property prediction (state-of-the-art)
  • Tasks requiring structural discrimination
  • Graph classification

Parameters:

  • input_dim, hidden_dims: Standard dimensions
  • edge_input_dim: Include edge features
  • batch_norm: Typically set to True
  • readout: Graph pooling ("sum", "mean", "max")
  • eps: Learnable or fixed epsilon
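
A typical molecular-property configuration might look like the sketch below; the literal dimensions are placeholders for your dataset's values.

```python
from torchdrug import models

model = models.GIN(
    input_dim=69,                    # e.g. dataset.node_feature_dim
    hidden_dims=[256, 256, 256, 256],
    edge_input_dim=13,               # e.g. dataset.edge_feature_dim (optional)
    batch_norm=True,
    readout="sum",                   # sum readout preserves expressiveness
)
```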

Use Cases:

  • Drug property prediction (BBBP, HIV, etc.)
  • Molecular generation
  • Reaction prediction

RGCN (Relational Graph Convolutional Network)

Type: Multi-relational message passing
Paper: Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)

Characteristics:

  • Handles multiple edge/relation types
  • Relation-specific weight matrices
  • Basis decomposition for parameter efficiency
  • Essential for knowledge graphs

Best For:

  • Knowledge graph reasoning
  • Heterogeneous molecular graphs
  • Multi-relational data

Parameters:

  • num_relation: Number of relation types
  • hidden_dims: Layer dimensions
  • num_bases: Basis decomposition (reduce parameters)
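
A minimal sketch showing the core arguments; num_relation must match the number of edge types in your graphs (e.g. dataset.num_bond_type for molecules).

```python
from torchdrug import models

model = models.RGCN(
    input_dim=69,                  # e.g. dataset.node_feature_dim
    hidden_dims=[256, 256, 256],
    num_relation=4,                # e.g. single/double/triple/aromatic bonds
    batch_norm=True,
)
```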

Use Cases:

  • Knowledge graph completion
  • Retrosynthesis (different bond types)
  • Protein interaction networks

MPNN (Message Passing Neural Network)

Type: General message passing framework
Paper: Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)

Characteristics:

  • Flexible message and update functions
  • Edge features in message computation
  • GRU updates for node hidden states
  • Set2Set readout for graph representation

Best For:

  • Quantum chemistry predictions
  • Tasks with important edge information
  • When node states evolve over multiple iterations

Parameters:

  • input_dim, hidden_dim: Feature dimensions
  • edge_input_dim: Edge feature dimension
  • num_layer: Message passing iterations
  • num_mlp_layer: MLP layers in message function
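
Unlike the models above, MPNN takes a single hidden_dim (the GRU state size) rather than a list; a sketch with placeholder dimensions:

```python
from torchdrug import models

model = models.MPNN(
    input_dim=69,          # e.g. dataset.node_feature_dim
    hidden_dim=256,        # shared GRU state size
    edge_input_dim=13,     # e.g. dataset.edge_feature_dim
    num_layer=6,           # message passing iterations
    num_mlp_layer=2,       # MLP depth in the message function
)
```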

Use Cases:

  • QM9 quantum property prediction
  • Molecular dynamics
  • 3D conformation-aware tasks

SchNet (Continuous-Filter Convolutional Network)

Type: 3D geometry-aware convolution
Paper: SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., 2017)

Characteristics:

  • Operates on 3D atomic coordinates
  • Continuous filter convolutions
  • Rotation and translation invariant
  • Excellent for quantum chemistry

Best For:

  • 3D molecular structure tasks
  • Quantum property prediction
  • Protein structure analysis
  • Energy and force prediction

Parameters:

  • input_dim: Atom features
  • hidden_dims: Layer dimensions
  • num_gaussian: RBF basis functions for distances
  • cutoff: Interaction cutoff distance
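
A sketch with near-default hyperparameters; the input graphs must carry 3D atomic coordinates, and cutoff is measured in angstroms.

```python
from torchdrug import models

model = models.SchNet(
    input_dim=69,                  # e.g. dataset.node_feature_dim
    hidden_dims=[256, 256, 256],
    num_gaussian=100,              # RBF expansion of pairwise distances
    cutoff=5.0,                    # interaction radius in angstroms
)
```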

Use Cases:

  • QM9 property prediction
  • Molecular dynamics simulations
  • Protein-ligand binding with structures
  • Crystal property prediction

ChebNet (Chebyshev Spectral CNN)

Type: Spectral convolution
Paper: Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)

Characteristics:

  • Spectral graph convolution
  • Chebyshev polynomial approximation
  • Captures global graph structure
  • Computationally efficient

Best For:

  • Tasks requiring global information
  • When graph Laplacian is informative
  • Theoretical analysis

Parameters:

  • input_dim, hidden_dims: Dimensions
  • num_cheb: Order of Chebyshev polynomial
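
A sketch assuming the models.ChebNet alias and a polynomial-order keyword k (an order-k filter aggregates information from k-hop neighborhoods); verify the keyword against your TorchDrug version.

```python
from torchdrug import models

model = models.ChebNet(
    input_dim=69,              # e.g. dataset.node_feature_dim
    hidden_dims=[256, 256],
    k=2,                       # Chebyshev polynomial order (assumed keyword)
)
```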

Use Cases:

  • Citation network classification
  • Brain network analysis
  • Signal processing on graphs

NFP (Neural Fingerprint)

Type: Molecular fingerprint learning
Paper: Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)

Characteristics:

  • Learns differentiable molecular fingerprints
  • Alternative to hand-crafted fingerprints (ECFP)
  • Circular convolutions like ECFP
  • Interpretable learned features

Best For:

  • Molecular similarity learning
  • Property prediction with limited data
  • When interpretability is important

Parameters:

  • input_dim, output_dim: Feature dimensions
  • hidden_dims: Layer dimensions
  • num_layer: Circular convolution depth
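
A sketch assuming the full class name models.NeuralFingerprint; output_dim sets the length of the learned fingerprint, and the remaining values are placeholders.

```python
from torchdrug import models

model = models.NeuralFingerprint(
    input_dim=69,              # e.g. dataset.node_feature_dim
    output_dim=128,            # fingerprint length
    hidden_dims=[256, 256],
)
```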

Use Cases:

  • Virtual screening
  • Molecular similarity search
  • QSAR modeling

Protein-Specific Models

GearNet (Geometry-Aware Relational Graph Network)

Type: Protein structure encoder
Paper: Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)

Characteristics:

  • Incorporates 3D geometric information
  • Multiple edge types (sequential, spatial, KNN)
  • Designed specifically for proteins
  • State-of-the-art on protein tasks

Best For:

  • Protein structure prediction
  • Protein function prediction
  • Protein-protein interaction
  • Any task with protein 3D structures

Parameters:

  • input_dim: Residue features
  • hidden_dims: Layer dimensions
  • num_relation: Edge types (sequence, radius, KNN)
  • edge_input_dim: Geometric features (distances, angles)
  • batch_norm: Typically true
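
A configuration close to the one used for geometric pretraining in the GearNet paper; the specific dimensions below are illustrative rather than required.

```python
from torchdrug import models

model = models.GearNet(
    input_dim=21,                  # one-hot residue types
    hidden_dims=[512] * 6,
    num_relation=7,                # sequential + radius + KNN edge types
    edge_input_dim=59,             # geometric edge features
    batch_norm=True,
    short_cut=True,
    concat_hidden=True,
    readout="sum",
)
```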

Use Cases:

  • Enzyme function prediction (EnzymeCommission)
  • Protein fold recognition
  • Contact prediction
  • Binding site identification

ESM (Evolutionary Scale Modeling)

Type: Protein language model (transformer)
Paper: Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)

Characteristics:

  • Pre-trained on 250M+ protein sequences
  • Captures evolutionary and structural information
  • Transformer architecture
  • Transfer learning for downstream tasks

Best For:

  • Any sequence-based protein task
  • When no structure available
  • Transfer learning with limited data

Variants:

  • ESM-1b: 650M parameters
  • ESM-2: Multiple sizes (8M to 15B parameters)
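
Loading a pre-trained checkpoint through TorchDrug's ESM wrapper is a one-liner sketch; weights are downloaded to the given path on first use, and "ESM-1b" selects the 650M-parameter checkpoint.

```python
from torchdrug import models

# path is a local cache directory of your choosing
model = models.ESM(path="~/esm-weights/", model="ESM-1b")
```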

Use Cases:

  • Protein function prediction
  • Variant effect prediction
  • Protein design
  • Structure prediction (ESMFold)

ProteinBERT

Type: Masked language model for proteins

Characteristics:

  • BERT-style pre-training
  • Masked amino acid prediction
  • Bidirectional context
  • Good for sequence-based tasks

Use Cases:

  • Function annotation
  • Subcellular localization
  • Stability prediction

ProteinCNN / ProteinResNet

Type: Convolutional networks for sequences

Characteristics:

  • 1D convolutions on sequences
  • Local pattern recognition
  • Faster than transformers
  • Good for motif detection

Use Cases:

  • Binding site prediction
  • Secondary structure prediction
  • Domain identification

ProteinLSTM

Type: Recurrent network for sequences

Characteristics:

  • Bidirectional LSTM
  • Captures long-range dependencies
  • Sequential processing
  • Good baseline for sequence tasks

Use Cases:

  • Order prediction
  • Sequential annotation
  • Time-series protein data

Knowledge Graph Models

TransE (Translation Embedding)

Type: Translation-based embedding
Paper: Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)

Characteristics:

  • h + r ≈ t (head + relation ≈ tail)
  • Simple and interpretable
  • Works well for 1-to-1 relations
  • Memory efficient

Best For:

  • Large knowledge graphs
  • Initial experiments
  • Interpretable embeddings

Parameters:

  • num_entity, num_relation: Graph size
  • embedding_dim: Embedding dimensions (typically 50-500)
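
A minimal sketch using FB15k-237 purely as an example source of entity and relation counts:

```python
from torchdrug import datasets, models

dataset = datasets.FB15k237("~/kg-datasets/")

model = models.TransE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=256,
)
```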

RotatE (Rotation Embedding)

Type: Rotation in complex space
Paper: RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)

Characteristics:

  • Relations as rotations in complex space
  • Handles symmetric, antisymmetric, inverse, composition
  • State-of-the-art on many benchmarks

Best For:

  • Most knowledge graph tasks
  • Complex relation patterns
  • When accuracy is critical

Parameters:

  • num_entity, num_relation: Graph size
  • embedding_dim: Must be even (complex embeddings)
  • max_score: Score clipping value
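
Same construction pattern as TransE; note the even embedding_dim, since entities are stored as embedding_dim/2 complex numbers.

```python
from torchdrug import datasets, models

dataset = datasets.FB15k237("~/kg-datasets/")

model = models.RotatE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=2048,            # must be even: 1024 complex pairs
    max_score=9,                   # score clipping
)
```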

DistMult

Type: Bilinear model

Characteristics:

  • Symmetric relation modeling
  • Fast and efficient
  • Cannot model antisymmetric relations

Best For:

  • Symmetric relations (e.g., "similar to")
  • When speed is critical
  • Large-scale graphs

ComplEx

Type: Complex-valued embeddings

Characteristics:

  • Handles asymmetric and symmetric relations
  • Better than DistMult for most graphs
  • Good balance of expressiveness and efficiency

Best For:

  • General knowledge graph completion
  • Mixed relation types
  • When RotatE is too complex

SimplE

Type: Enhanced embedding model

Characteristics:

  • Two embeddings per entity (head role and tail role), tied through inverse relations
  • Fully expressive
  • Slightly more parameters than ComplEx

Best For:

  • When full expressiveness needed
  • Inverse relations are important

Generative Models

GraphAutoregressiveFlow

Type: Normalizing flow for molecules

Characteristics:

  • Exact likelihood computation
  • Invertible transformations
  • Stable training (no adversarial objective)
  • Conditional generation support

Best For:

  • Molecular generation
  • Density estimation
  • Interpolation between molecules

Parameters:

  • input_dim: Atom features
  • hidden_dims: Coupling layers
  • num_flow: Number of flow transformations
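
A sketch loosely following the TorchDrug molecule-generation tutorial: an RGCN backbone is wrapped into separate node and edge flows, which together parameterize the autoregressive generation task. Treat the hyperparameters as starting points.

```python
from torchdrug import datasets, models, tasks

# ZINC250k with kekulized bonds, as in the generation tutorial
dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                            atom_feature="symbol")

backbone = models.RGCN(input_dim=dataset.node_feature_dim,
                       num_relation=dataset.num_bond_type,
                       hidden_dims=[256, 256, 256], batch_norm=True)

node_flow = models.GraphAF(backbone, num_flow_layer=12)
edge_flow = models.GraphAF(backbone, num_flow_layer=12, use_edge=True)
task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                      max_node=38, max_edge_unroll=12)
```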

Use Cases:

  • De novo drug design
  • Chemical space exploration
  • Property-targeted generation

Pre-training Models

InfoGraph

Type: Contrastive learning

Characteristics:

  • Maximizes mutual information
  • Graph-level and node-level contrast
  • Unsupervised pre-training
  • Good for small datasets

Use Cases:

  • Pre-train molecular encoders
  • Few-shot learning
  • Transfer learning
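
A sketch of unsupervised pre-training with InfoGraph wrapping a GIN encoder; ClinTox stands in here for any unlabeled molecule collection.

```python
from torchdrug import datasets, models, tasks

dataset = datasets.ClinTox("~/molecule-datasets/")

gin = models.GIN(input_dim=dataset.node_feature_dim,
                 hidden_dims=[256, 256, 256, 256],
                 batch_norm=True, readout="mean")
model = models.InfoGraph(gin, separate_model=False)
task = tasks.Unsupervised(model)
```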

MultiviewContrast

Type: Multi-view contrastive learning for proteins

Characteristics:

  • Contrasts different views of proteins
  • Geometric pre-training
  • Uses 3D structure information
  • Excellent for protein models

Use Cases:

  • Pre-train GearNet on protein structures
  • Transfer to property prediction
  • Limited labeled data scenarios

Model Selection Guide

By Task Type

Molecular Property Prediction:

  1. GIN (first choice)
  2. GAT (interpretability)
  3. SchNet (3D available)

Protein Tasks:

  1. ESM (sequence only)
  2. GearNet (structure available)
  3. ProteinBERT (sequence, lighter than ESM)

Knowledge Graphs:

  1. RotatE (best performance)
  2. ComplEx (good balance)
  3. TransE (large graphs, efficiency)

Molecular Generation:

  1. GraphAutoregressiveFlow (exact likelihood)
  2. GCPN with GIN backbone (property optimization)

Retrosynthesis:

  1. GIN (synthon completion)
  2. RGCN (center identification with bond types)

By Dataset Size

Small (< 1k):

  • Use pre-trained models (ESM for proteins)
  • Simpler architectures (GCN, ProteinCNN)
  • Heavy regularization

Medium (1k-100k):

  • GIN for molecules
  • GAT for interpretability
  • Standard training

Large (> 100k):

  • Any model works
  • Deeper architectures
  • Can train from scratch

By Computational Budget

Low:

  • GCN (simplest)
  • DistMult (KG)
  • ProteinLSTM

Medium:

  • GIN
  • GAT
  • ComplEx

High:

  • ESM (large)
  • SchNet (3D)
  • RotatE with high dim

Implementation Tips

  1. Start Simple: Begin with GCN or GIN baseline
  2. Use Pre-trained: ESM for proteins, InfoGraph for molecules
  3. Tune Depth: 3-5 layers typically sufficient
  4. Batch Normalization: Usually helps (except KG embeddings)
  5. Residual Connections: Important for deep networks
  6. Readout Function: "mean" usually works well
  7. Edge Features: Include when available (bonds, distances)
  8. Regularization: Dropout, weight decay, early stopping
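
Putting these tips together, a minimal end-to-end property-prediction run, mirroring the TorchDrug quickstart (hyperparameters are starting points, not tuned values):

```python
import torch
from torchdrug import core, datasets, models, tasks

# GIN baseline with mean readout and standard Adam training
dataset = datasets.ClinTox("~/molecule-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256, 256],
                   batch_norm=True, readout="mean")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=256)
solver.train(num_epoch=10)
solver.evaluate("valid")
```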