Core Concepts and Technical Details
Overview
This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.
Architecture Philosophy
Modular Design
TorchDrug separates concerns into distinct modules:
- Representation Models (models.py): Encode graphs into embeddings
- Task Definitions (tasks.py): Define learning objectives and evaluation
- Data Handling (data.py, datasets.py): Graph structures and datasets
- Core Components (core.py): Base classes and utilities
Benefits:
- Reuse representations across tasks
- Mix and match components
- Easy experimentation and prototyping
- Clear separation of concerns
Configurable System
All components inherit from core.Configurable:
- Serialize to configuration dictionaries
- Reconstruct from configurations
- Save and load complete pipelines
- Reproducible experiments
Core Components
core.Configurable
Base class for all TorchDrug components.
Key Methods:
- config_dict(): Serialize to a dictionary
- load_config_dict(config): Load from a dictionary
- save(file): Save to a file
- load(file): Load from a file
Example:
from torchdrug import core, models
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}
# Reconstruct model
model2 = core.Configurable.load_config_dict(config)
core.Registry
Decorator for registering models, tasks, and datasets.
Usage:
import torch.nn as nn
from torchdrug import core as core_td
from torchdrug.core import Registry as R

@R.register("models.CustomModel")
class CustomModel(nn.Module, core_td.Configurable):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Model implementation
        pass
Benefits:
- Models automatically serializable
- String-based model specification
- Easy model lookup and instantiation
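Because the registered name is stored in the configuration, a component can be rebuilt purely from a dictionary. A minimal sketch, assuming the CustomModel registered above (the exact "class" string follows the registry key):

# Reconstruct a registered model from a plain config dictionary
config = {"class": "models.CustomModel", "input_dim": 10, "hidden_dim": 64}
model = core_td.Configurable.load_config_dict(config)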
Data Structures
Graph
Core data structure representing molecular or protein graphs.
Attributes:
- num_node: Number of nodes
- num_edge: Number of edges
- node_feature: Node feature tensor [num_node, feature_dim]
- edge_feature: Edge feature tensor [num_edge, feature_dim]
- edge_list: Edge connectivity [num_edge, 2 or 3]
- num_relation: Number of edge types (for multi-relational graphs)
Methods:
- node_mask(mask): Select a subset of nodes
- edge_mask(mask): Select a subset of edges
- undirected(): Make the graph undirected
- directed(): Make the graph directed
Batching:
- Graphs batched into single disconnected graph
- Automatic batching in DataLoader
- Preserves node/edge indices per graph
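A minimal sketch of batching, using two hypothetical toy graphs:

from torchdrug import data

# Two toy graphs defined by edge lists
g1 = data.Graph(edge_list=[[0, 1], [1, 2]], num_node=3)
g2 = data.Graph(edge_list=[[0, 1]], num_node=2)

# Batching merges them into one disconnected graph (a PackedGraph)
batch = data.Graph.pack([g1, g2])
print(batch.num_nodes)  # node counts per graph: [3, 2]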
Molecule (extends Graph)
Specialized graph for molecules.
Additional Attributes:
- atom_type: Atomic numbers
- bond_type: Bond types (single, double, triple, aromatic)
- formal_charge: Atomic formal charges
- explicit_hs: Explicit hydrogen counts
Methods:
- from_smiles(smiles): Create from a SMILES string
- from_molecule(mol): Create from an RDKit molecule
- to_smiles(): Convert to SMILES
- to_molecule(): Convert to an RDKit molecule
- ion_to_molecule(): Neutralize charges
Example:
from torchdrug import data
# From SMILES
mol = data.Molecule.from_smiles("CCO")
# Atom features
print(mol.atom_type) # [6, 6, 8] (C, C, O)
print(mol.bond_type) # single-bond types (each bond is stored as two directed edges)
Protein (extends Graph)
Specialized graph for proteins.
Additional Attributes:
- residue_type: Amino acid types
- atom_name: Atom names (CA, CB, etc.)
- atom_type: Atomic numbers
- residue_number: Residue numbering
- chain_id: Chain identifiers
Methods:
- from_pdb(pdb_file): Load from a PDB file
- from_sequence(sequence): Create from an amino acid sequence
- to_pdb(pdb_file): Save to a PDB file
Graph Construction:
- Nodes typically represent residues (not atoms)
- Edges can be sequential, spatial (KNN), or contact-based
- Configurable edge construction strategies
Example:
from torchdrug import data
# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")
# Build a residue-level graph with sequential and spatial edges
from torchdrug import layers
from torchdrug.layers import geometry

graph_construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],  # one node per residue, at the Cα position
    edge_layers=[geometry.SequentialEdge(max_distance=2),           # sequence neighbors
                 geometry.SpatialEdge(radius=10.0, min_distance=5)] # spatial neighbors
)
graph = graph_construction(data.Protein.pack([protein]))
PackedGraph
Efficient batching structure for heterogeneous graphs.
Purpose:
- Batch graphs of different sizes
- Single GPU memory allocation
- Efficient parallel processing
Attributes:
- num_nodes: Node counts per graph
- num_edges: Edge counts per graph
- node2graph: Graph index for each node
Use Cases:
- Automatic in DataLoader
- Custom batching strategies
- Multi-graph operations
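A short sketch, reusing the toy graphs from the Graph example above:

packed = data.Graph.pack([g1, g2])
print(packed.node2graph)  # graph index for each node: [0, 0, 0, 1, 1]
graphs = packed.unpack()  # recover the individual graphs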
Model Interface
Forward Function Signature
All TorchDrug models follow a standardized interface:
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation
    output = self.layers(graph, input)
    return {
        "node_feature": output,
        "graph_feature": self.readout(graph, output)  # e.g., layers.MeanReadout()
    }
Key Points:
- graph: Batched graph structure
- input: Node features [num_node, input_dim]
- all_loss: Accumulated loss (for multi-task training)
- metric: Shared metric dictionary
- Returns a dict keyed by representation type
Essential Attributes
All models must define:
- input_dim: Expected input feature dimension
- output_dim: Output representation dimension
Purpose:
- Automatic dimension checking
- Compose models in pipelines
- Error checking and validation
Example:
class CustomModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...
Task Interface
Core Task Methods
All tasks implement these methods:
from torchdrug import tasks, metrics

class CustomTask(tasks.Task):
    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        result = {}
        result["auroc"] = metrics.area_under_roc(pred, target)
        result["auprc"] = metrics.area_under_prc(pred, target)
        return result
Task Components
Typical Task Structure:
- Representation Model: Encodes graph to embeddings
- Readout/Prediction Head: Maps embeddings to predictions
- Loss Function: Training objective
- Metrics: Evaluation measures
Example:
from torchdrug import tasks, models
# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Task wraps model with prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # Multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
Training Workflow
Standard Training Loop
import torch
from torchdrug import core, data, models, tasks, datasets

# 1. Load dataset
dataset = datasets.BBBP("~/datasets/")
train_set, valid_set, test_set = dataset.split()

# 2. Create data loaders (TorchDrug's DataLoader batches graphs into a PackedGraph)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)
# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])
# 4. Setup optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss = task(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
PyTorch Lightning Integration
TorchDrug tasks are compatible with PyTorch Lightning:
import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):
    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss = self.task(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
Loss Functions
Built-in Criteria
Classification:
"bce": Binary cross-entropy"ce": Cross-entropy (multi-class)
Regression:
"mse": Mean squared error"mae": Mean absolute error
Knowledge Graph:
"bce": Binary classification of triples"ce": Cross-entropy ranking loss"margin": Margin-based ranking
Custom Loss
class CustomTask(tasks.Task):
    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)
        # Custom loss computation (custom_loss_function is a placeholder)
        loss = custom_loss_function(pred, target)
        return loss
Metrics
Common Metrics
Classification:
- AUROC: Area under ROC curve
- AUPRC: Area under precision-recall curve
- Accuracy: Overall accuracy
- F1: Harmonic mean of precision and recall
Regression:
- MAE: Mean absolute error
- RMSE: Root mean squared error
- R²: Coefficient of determination
- Pearson: Pearson correlation
Ranking (Knowledge Graph):
- MR: Mean rank
- MRR: Mean reciprocal rank
- Hits@K: Percentage in top K
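Many of these are available as functions in torchdrug.metrics. A minimal classification sketch:

import torch
from torchdrug import metrics

pred = torch.tensor([0.9, 0.2, 0.7, 0.4])   # hypothetical predicted scores
target = torch.tensor([1, 0, 1, 0])          # binary ground-truth labels
print(metrics.area_under_roc(pred, target))  # 1.0: every positive outranks every negative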
Multi-Task Metrics
For multi-label or multi-task:
- Metrics computed per task
- Macro-average across tasks
- Can weight by task importance
Data Transforms
Molecule Transforms
from torchdrug import transforms
# Add virtual node connected to all atoms
transform1 = transforms.VirtualNode()
# Add virtual edges
transform2 = transforms.VirtualEdge()
# Compose transforms
transform = transforms.Compose([transform1, transform2])
dataset = datasets.BBBP("~/datasets/", transform=transform)
Protein Transforms
# Truncate long proteins to a maximum number of residues
transform = transforms.TruncateProtein(max_length=500)
dataset = datasets.Fold("~/datasets/", transform=transform)
Best Practices
Memory Efficiency
- Gradient Accumulation: For large models
- Mixed Precision: FP16 training
- Batch Size Tuning: Balance speed and memory
- Data Loading: Multiple workers for I/O
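For instance, a minimal gradient-accumulation pattern (generic PyTorch, assuming the task, optimizer, and train_loader from the training loop above):

accum_steps = 4  # hypothetical accumulation factor
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    loss = task(batch) / accum_steps  # scale so summed gradients match one large batch
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()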
Reproducibility
- Set Seeds: PyTorch, NumPy, Python random (see the sketch after this list)
- Deterministic Operations: torch.use_deterministic_algorithms(True)
- Save Configurations: Use core.Configurable
- Version Control: Track the TorchDrug version
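A common seed-setting helper (a generic PyTorch recipe, not TorchDrug-specific):

import random
import numpy as np
import torch

def set_seed(seed=0):
    # Seed every RNG source that affects training
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)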
Debugging
- Check Dimensions: Verify input_dim and output_dim
- Validate Batching: Print batch statistics
- Monitor Gradients: Watch for vanishing/exploding
- Overfit Small Batch: Ensure model capacity
Performance Optimization
- GPU Utilization: Monitor with nvidia-smi
- Profile Code: Use the PyTorch profiler
- Optimize Data Loading: Prefetch, pin memory
- Compile Models: Use TorchScript if possible
Advanced Topics
Multi-Task Learning
Train single model on multiple related tasks:
task = tasks.PropertyPrediction(
    model,
    task=["task1", "task2", "task3"],
    criterion="bce",
    metric=["auroc"],
    task_weight=[1.0, 1.0, 2.0]  # Weight task 3 more heavily
)
Transfer Learning
- Pre-train on large dataset
- Fine-tune on target dataset
- Optionally freeze early layers
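A minimal fine-tuning sketch, assuming the model and dataset from the training workflow above and a hypothetical checkpoint file:

import torch
from torchdrug import tasks

# Load pre-trained encoder weights ("pretrained_model.pth" is hypothetical)
model.load_state_dict(torch.load("pretrained_model.pth"))

# Optionally freeze the encoder so only the task head is trained
for param in model.parameters():
    param.requires_grad = False

task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc"])
optimizer = torch.optim.Adam(
    (p for p in task.parameters() if p.requires_grad), lr=1e-4)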
Self-Supervised Pre-training
Use pre-training tasks:
- AttributeMasking: Mask node features
- EdgePrediction: Predict edge existence
- ContextPrediction: Contrastive context learning
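For example, attribute masking wraps an existing representation model; a sketch with a typical mask rate:

from torchdrug import tasks

# Mask 15% of node attributes and train the model to recover them
pretrain_task = tasks.AttributeMasking(model, mask_rate=0.15)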
Custom Layers
Extend TorchDrug with custom GNN layers:
from torchdrug import layers
class CustomConv(layers.MessagePassingBase):
    def message(self, graph, input):
        # Compute a message for each edge from its source node
        pass

    def aggregate(self, graph, message):
        # Aggregate incoming messages for each node (e.g., sum or mean)
        pass

    def combine(self, input, update):
        # Combine each node's previous representation with its aggregated update
        pass
Common Pitfalls
- Forgetting input_dim and output_dim: Models won't compose
- Not Batching Properly: Use PackedGraph for variable-sized graphs
- Data Leakage: Be careful with scaffold splits and pre-training
- Ignoring Edge Features: Bonds/spatial info can be critical
- Wrong Evaluation Metrics: Match metrics to task (AUROC for imbalanced)
- Insufficient Regularization: Use dropout, weight decay, early stopping
- Not Validating Chemistry: Generated molecules must be valid (see the sketch after this list)
- Overfitting Small Datasets: Use pre-training or simpler models
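On the chemistry-validation point, a quick validity check with RDKit (smiles stands in for a hypothetical generated string):

from rdkit import Chem

smiles = "CCO"  # hypothetical generated molecule
mol = Chem.MolFromSmiles(smiles)
print(mol is not None)  # True only if the SMILES parses into a valid molecule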