Core Concepts and Technical Details

Overview

This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.

Architecture Philosophy

Modular Design

TorchDrug separates concerns into distinct modules:

  1. Representation Models (models.py): Encode graphs into embeddings
  2. Task Definitions (tasks.py): Define learning objectives and evaluation
  3. Data Handling (data.py, datasets.py): Graph structures and datasets
  4. Core Components (core.py): Base classes and utilities

Benefits:

  • Reuse representations across tasks
  • Mix and match components
  • Easy experimentation and prototyping
  • Clear separation of concerns

Configurable System

All components inherit from core.Configurable:

  • Serialize to configuration dictionaries
  • Reconstruct from configurations
  • Save and load complete pipelines
  • Reproducible experiments

Core Components

core.Configurable

Base class for all TorchDrug components.

Key Methods:

  • config_dict(): Serialize to dictionary
  • load_config_dict(config): Load from dictionary
  • save(file): Save to file
  • load(file): Load from file

Example:

from torchdrug import core, models

model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}

# Reconstruct model
model2 = core.Configurable.load_config_dict(config)

core.Registry

Decorator for registering models, tasks, and datasets.

Usage:

import torch.nn as nn

from torchdrug import core
from torchdrug.core import Registry as R

@R.register("models.CustomModel")
class CustomModel(nn.Module, core.Configurable):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Return node representations under the standard interface
        return {"node_feature": self.linear(input)}

Benefits:

  • Models automatically serializable
  • String-based model specification
  • Easy model lookup and instantiation

Data Structures

Graph

Core data structure representing molecular or protein graphs.

Attributes:

  • num_node: Number of nodes
  • num_edge: Number of edges
  • node_feature: Node feature tensor [num_node, feature_dim]
  • edge_feature: Edge feature tensor [num_edge, feature_dim]
  • edge_list: Edge connectivity [num_edge, 2 or 3]
  • num_relation: Number of edge types (for multi-relational)

Methods:

  • node_mask(mask): Select subset of nodes
  • edge_mask(mask): Select subset of edges
  • undirected(): Make graph undirected
  • directed(): Make graph directed

Batching:

  • Graphs batched into single disconnected graph
  • Automatic batching in DataLoader
  • Preserves node/edge indices per graph
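
A minimal sketch of this batching behavior, using only the documented data.Graph constructor and pack():

from torchdrug import data

# Two toy graphs with explicit edge lists and node counts
g1 = data.Graph(edge_list=[[0, 1], [1, 2]], num_node=3)
g2 = data.Graph(edge_list=[[0, 1]], num_node=2)

# Pack into a single disconnected graph; node indices are offset automatically
batch = data.Graph.pack([g1, g2])
print(batch.num_nodes)  # per-graph node counts: tensor([3, 2])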

Molecule (extends Graph)

Specialized graph for molecules.

Additional Attributes:

  • atom_type: Atomic numbers
  • bond_type: Bond types (single, double, triple, aromatic)
  • formal_charge: Atomic formal charges
  • explicit_hs: Explicit hydrogen counts

Methods:

  • from_smiles(smiles): Create from SMILES string
  • from_molecule(mol): Create from RDKit molecule
  • to_smiles(): Convert to SMILES
  • to_molecule(): Convert to RDKit molecule
  • ion_to_molecule(): Neutralize charges

Example:

from torchdrug import data

# From SMILES
mol = data.Molecule.from_smiles("CCO")

# Atom features
print(mol.atom_type)  # tensor([6, 6, 8]) -> C, C, O
print(mol.bond_type)  # one entry per directed edge (each bond is stored twice)

Protein (extends Graph)

Specialized graph for proteins.

Additional Attributes:

  • residue_type: Amino acid types
  • atom_name: Atom names (CA, CB, etc.)
  • atom_type: Atomic numbers
  • residue_number: Residue numbering
  • chain_id: Chain identifiers

Methods:

  • from_pdb(pdb_file): Load from PDB file
  • from_sequence(sequence): Create from sequence
  • to_pdb(pdb_file): Save to PDB file

Graph Construction:

  • Nodes typically represent residues (not atoms)
  • Edges can be sequential, spatial (KNN), or contact-based
  • Configurable edge construction strategies

Example:

from torchdrug import data, layers
from torchdrug.layers import geometry

# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")

# Build a residue-level graph with sequential and spatial edges
graph_construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],              # one node per residue (C-alpha)
    edge_layers=[geometry.SequentialEdge(max_distance=2),  # sequence neighbors
                 geometry.SpatialEdge(radius=10.0)]        # residues within 10 angstroms
)
graph = graph_construction(data.Protein.pack([protein]))

PackedGraph

Efficient batching structure for heterogeneous graphs.

Purpose:

  • Batch graphs of different sizes
  • Single GPU memory allocation
  • Efficient parallel processing

Attributes:

  • num_nodes: Per-graph node counts
  • num_edges: Per-graph edge counts
  • node2graph: Graph index for each node
  • edge2graph: Graph index for each edge

Use Cases:

  • Automatic in DataLoader
  • Custom batching strategies
  • Multi-graph operations
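
As a sketch of a multi-graph operation (assuming the torch_scatter dependency that TorchDrug itself builds on), node2graph lets you pool node features per graph:

import torch
from torch_scatter import scatter_mean
from torchdrug import data

batch = data.Graph.pack([data.Graph(edge_list=[[0, 1]], num_node=2),
                         data.Graph(edge_list=[[0, 1], [1, 2]], num_node=3)])
node_feature = torch.randn(int(batch.num_node), 16)

# Mean-pool node features into one vector per graph
graph_feature = scatter_mean(node_feature, batch.node2graph,
                             dim=0, dim_size=len(batch.num_nodes))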

Model Interface

Forward Function Signature

All TorchDrug models follow a standardized interface:

def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation (self.layers and graph_pooling are schematic placeholders)
    output = self.layers(graph, input)

    return {
        "node_feature": output,
        "graph_feature": graph_pooling(output)  # per-graph readout over nodes
    }

Key Points:

  • graph: Batched graph structure
  • input: Node features [num_node, input_dim]
  • all_loss: Accumulated loss (for multi-task)
  • metric: Shared metric dictionary
  • Returns dict with representation types
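
A call sketch under these conventions (the float() cast assumes integer-coded input features, as in molecule datasets):

output = model(graph, graph.node_feature.float())
node_repr = output["node_feature"]    # per-node embeddings
graph_repr = output["graph_feature"]  # per-graph embeddings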

Essential Attributes

All models must define:

  • input_dim: Expected input feature dimension
  • output_dim: Output representation dimension

Purpose:

  • Automatic dimension checking
  • Compose models in pipelines
  • Error checking and validation

Example:

class CustomModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...

Task Interface

Core Task Methods

All tasks implement these methods:

from torchdrug import metrics, tasks

class CustomTask(tasks.Task):
    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature.float())
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        result = {
            "auroc": metrics.area_under_roc(pred.flatten(), target.flatten()),
            "auprc": metrics.area_under_prc(pred.flatten(), target.flatten()),
        }
        return result

Task Components

Typical Task Structure:

  1. Representation Model: Encodes graph to embeddings
  2. Readout/Prediction Head: Maps embeddings to predictions
  3. Loss Function: Training objective
  4. Metrics: Evaluation measures

Example:

from torchdrug import tasks, models

# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Task wraps model with prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # Multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)

Training Workflow

Standard Training Loop

import torch
from torchdrug import core, data, datasets, models, tasks

# 1. Load dataset
dataset = datasets.BBBP("~/datasets/")
train_set, valid_set, test_set = dataset.split()

# 2. Create data loaders (torchdrug's DataLoader knows how to batch graphs)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                 criterion="bce", metric=["auroc", "auprc"])

# 4. Setup optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss, metric = task(batch)  # torchdrug tasks return (loss, metric dict)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")

PyTorch Lightning Integration

TorchDrug tasks can be wrapped in a thin PyTorch Lightning module (the validation_epoch_end hook below exists in Lightning 1.x; 2.x replaced it with on_validation_epoch_end):

import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):
    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss, metric = self.task(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Loss Functions

Built-in Criteria

Classification:

  • "bce": Binary cross-entropy
  • "ce": Cross-entropy (multi-class)

Regression:

  • "mse": Mean squared error
  • "mae": Mean absolute error

Knowledge Graph:

  • "bce": Binary classification of triples
  • "ce": Cross-entropy ranking loss
  • "margin": Margin-based ranking

Custom Loss

import torch.nn.functional as F

class CustomTask(tasks.Task):
    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)

        # Example custom loss: class-weighted BCE
        # (self.pos_weight is a hypothetical tensor set in __init__)
        loss = F.binary_cross_entropy_with_logits(pred, target,
                                                  pos_weight=self.pos_weight)

        return loss

Metrics

Common Metrics

Classification:

  • AUROC: Area under ROC curve
  • AUPRC: Area under precision-recall curve
  • Accuracy: Overall accuracy
  • F1: Harmonic mean of precision and recall

Regression:

  • MAE: Mean absolute error
  • RMSE: Root mean squared error
  • R²: Coefficient of determination
  • Pearson: Pearson correlation

Ranking (Knowledge Graph):

  • MR: Mean rank
  • MRR: Mean reciprocal rank
  • Hits@K: Percentage in top K

Multi-Task Metrics

For multi-label or multi-task:

  • Metrics computed per task
  • Macro-average across tasks
  • Can weight by task importance
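
A minimal sketch of the (optionally weighted) macro-average, with hypothetical per-task scores:

import torch

per_task_auroc = torch.tensor([0.91, 0.84, 0.78])  # hypothetical values
weights = torch.tensor([1.0, 1.0, 2.0])

macro = per_task_auroc.mean()
weighted = (per_task_auroc * weights).sum() / weights.sum()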

Data Transforms

Molecule Transforms

from torchdrug import transforms

# Add virtual node connected to all atoms
transform1 = transforms.VirtualNode()

# Add virtual edges
transform2 = transforms.VirtualEdge()

# Compose transforms
transform = transforms.Compose([transform1, transform2])

dataset = datasets.BBBP("~/datasets/", transform=transform)

Protein Transforms

# Truncate long proteins to a fixed maximum number of residues
transform = transforms.TruncateProtein(max_length=500)

dataset = datasets.Fold("~/datasets/", transform=transform)

Best Practices

Memory Efficiency

  1. Gradient Accumulation: For large models (see the sketch after this list)
  2. Mixed Precision: FP16 training with torch.cuda.amp
  3. Batch Size Tuning: Balance speed and memory
  4. Data Loading: Multiple workers for I/O
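
A minimal sketch combining items 1 and 2, assuming a CUDA device and the task/optimizer from the training loop above:

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # hypothetical: accumulate 4 mini-batches per optimizer step

for i, batch in enumerate(train_loader):
    with torch.cuda.amp.autocast():
        loss, metric = task(batch)
        loss = loss / accum_steps   # scale so gradients average over the steps
    scaler.scale(loss).backward()
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()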

Reproducibility

  1. Set Seeds: PyTorch, NumPy, Python random (helper below)
  2. Deterministic Operations: torch.use_deterministic_algorithms(True)
  3. Save Configurations: Use core.Configurable
  4. Version Control: Track TorchDrug version
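
A small helper covering items 1 and 2:

import random
import numpy as np
import torch

def set_seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)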

Debugging

  1. Check Dimensions: Verify input_dim and output_dim
  2. Validate Batching: Print batch statistics
  3. Monitor Gradients: Watch for vanishing/exploding
  4. Overfit Small Batch: Ensure model capacity (sketched below)
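
A sanity-check sketch for item 4, reusing the task and optimizer from the training loop: if the loss on one fixed batch does not approach zero, something upstream is broken.

batch = next(iter(train_loader))
for step in range(200):
    loss, metric = task(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()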

Performance Optimization

  1. GPU Utilization: Monitor with nvidia-smi
  2. Profile Code: Use PyTorch profiler
  3. Optimize Data Loading: Prefetch, pin memory (example below)
  4. Compile Models: Use TorchScript if possible
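
torchdrug's data.DataLoader subclasses PyTorch's, so the standard loader knobs apply:

from torchdrug import data

train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True,
                               num_workers=4, pin_memory=True)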

Advanced Topics

Multi-Task Learning

Train single model on multiple related tasks:

task = tasks.PropertyPrediction(
    model,
    task={"task1": 1.0, "task2": 1.0, "task3": 2.0},  # weight task 3 more
    criterion="bce",
    metric=["auroc"]
)

Transfer Learning

  1. Pre-train on large dataset
  2. Fine-tune on target dataset
  3. Optionally freeze early layers (see the sketch below)
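
A minimal sketch of steps 2-3, assuming a hypothetical checkpoint file and a model whose GNN layers live in model.layers (true of TorchDrug's GIN):

import torch

task.model.load_state_dict(torch.load("pretrained_gin.pth"))  # hypothetical path
for param in task.model.layers[:2].parameters():  # freeze the first two layers
    param.requires_grad = False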

Self-Supervised Pre-training

Use pre-training tasks:

  • AttributeMasking: Mask node features
  • EdgePrediction: Predict edge existence
  • ContextPrediction: Contrastive learning
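
For example, wrapping a model in the attribute-masking objective:

from torchdrug import tasks

# Randomly mask 15% of node attributes and train the model to recover them
pretrain_task = tasks.AttributeMasking(model, mask_rate=0.15)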

Custom Layers

Extend TorchDrug with custom GNN layers:

from torch_scatter import scatter_add
from torchdrug import layers

class CustomConv(layers.MessagePassingBase):
    def message(self, graph, input):
        # Message = feature of the source node of each edge
        node_in = graph.edge_list[:, 0]
        return input[node_in]

    def aggregate(self, graph, message):
        # Sum incoming messages at each destination node
        node_out = graph.edge_list[:, 1]
        return scatter_add(message, node_out, dim=0, dim_size=graph.num_node)

    def combine(self, input, update):
        # Combine old node features with the aggregated update
        return input + update

Common Pitfalls

  1. Forgetting input_dim and output_dim: Models won't compose
  2. Not Batching Properly: Use PackedGraph for variable-sized graphs
  3. Data Leakage: Be careful with scaffold splits and pre-training
  4. Ignoring Edge Features: Bonds/spatial info can be critical
  5. Wrong Evaluation Metrics: Match metrics to the task (accuracy misleads on imbalanced labels; prefer AUROC/AUPRC)
  6. Insufficient Regularization: Use dropout, weight decay, early stopping
  7. Not Validating Chemistry: Generated molecules must be valid
  8. Overfitting Small Datasets: Use pre-training or simpler models