# Core Concepts and Technical Details

## Overview

This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.

## Architecture Philosophy

### Modular Design

TorchDrug separates concerns into distinct modules:

1. **Representation Models** (`torchdrug.models`): Encode graphs into embeddings
2. **Task Definitions** (`torchdrug.tasks`): Define learning objectives and evaluation
3. **Data Handling** (`torchdrug.data`, `torchdrug.datasets`): Graph structures and datasets
4. **Core Components** (`torchdrug.core`): Base classes and utilities

**Benefits:**

- Reuse representations across tasks
- Mix and match components
- Easy experimentation and prototyping
- Clear separation of concerns

### Configurable System

All components inherit from `core.Configurable`, which lets them:

- Serialize to configuration dictionaries
- Reconstruct from configurations
- Save and load complete pipelines
- Reproduce experiments exactly

## Core Components

### core.Configurable

Base class for all TorchDrug components.

**Key Methods:**

- `config_dict()`: Serialize to a dictionary
- `load_config_dict(config)`: Load from a dictionary
- `save(file)`: Save to a file
- `load(file)`: Load from a file

**Example:**

```python
from torchdrug import core, models

model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Save the configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}

# Reconstruct the model from its configuration
model2 = core.Configurable.load_config_dict(config)
```

### core.Registry

Registry for models, tasks, and datasets; its `register` decorator makes classes discoverable by name.

**Usage:**

```python
import torch.nn as nn

from torchdrug import core
from torchdrug.core import Registry as R

@R.register("models.CustomModel")
class CustomModel(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Minimal implementation: a per-node linear projection
        return {"node_feature": self.linear(input)}
```

**Benefits:**

- Models are automatically serializable
- String-based model specification
- Easy model lookup and instantiation

## Data Structures

### Graph

Core data structure representing molecular or protein graphs.

**Attributes:**

- `num_node`: Number of nodes
- `num_edge`: Number of edges
- `node_feature`: Node feature tensor `[num_node, feature_dim]`
- `edge_feature`: Edge feature tensor `[num_edge, feature_dim]`
- `edge_list`: Edge connectivity `[num_edge, 2 or 3]`
- `num_relation`: Number of edge types (for multi-relational graphs)

**Methods:**

- `node_mask(mask)`: Select a subset of nodes
- `edge_mask(mask)`: Select a subset of edges
- `undirected()`: Make the graph undirected
- `directed()`: Make the graph directed

**Batching:**

- Graphs are batched into a single disconnected graph
- Batching happens automatically in the DataLoader
- Node and edge indices are preserved per graph
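For illustration, here is a minimal sketch of building a small `Graph` by hand (the edge list and mask indices are made up for the example):

```python
from torchdrug import data

# A 4-node cycle, given as (source, target) pairs
graph = data.Graph([[0, 1], [1, 2], [2, 3], [3, 0]], num_node=4)

print(graph.num_node)  # 4
print(graph.num_edge)  # 4

# Keep only the first two edges
subgraph = graph.edge_mask([0, 1])

# Add reverse edges so information can flow in both directions
bidirectional = graph.undirected()
```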
### Molecule (extends Graph)

Specialized graph for molecules.

**Additional Attributes:**

- `atom_type`: Atomic numbers
- `bond_type`: Bond types (single, double, triple, aromatic)
- `formal_charge`: Atomic formal charges
- `explicit_hs`: Explicit hydrogen counts

**Methods:**

- `from_smiles(smiles)`: Create from a SMILES string
- `from_molecule(mol)`: Create from an RDKit molecule
- `to_smiles()`: Convert to SMILES
- `to_molecule()`: Convert to an RDKit molecule
- `ion_to_molecule()`: Neutralize charges

**Example:**

```python
from torchdrug import data

# From SMILES
mol = data.Molecule.from_smiles("CCO")

# Atom features
print(mol.atom_type)  # tensor([6, 6, 8]): atomic numbers of C, C, O
print(mol.bond_type)  # bond type of each edge (each bond is stored in both directions)
```

### Protein (extends Graph)

Specialized graph for proteins.

**Additional Attributes:**

- `residue_type`: Amino acid types
- `atom_name`: Atom names (CA, CB, etc.)
- `atom_type`: Atomic numbers
- `residue_number`: Residue numbering
- `chain_id`: Chain identifiers

**Methods:**

- `from_pdb(pdb_file)`: Load from a PDB file
- `from_sequence(sequence)`: Create from a sequence
- `to_pdb(pdb_file)`: Save to a PDB file

**Graph Construction:**

- Nodes typically represent residues (not atoms)
- Edges can be sequential, spatial (radius/KNN), or contact-based
- Edge construction strategies are configurable

**Example:**

```python
from torchdrug import data, layers
from torchdrug.layers import geometry

# Load a protein from a PDB file
protein = data.Protein.from_pdb("1a3x.pdb")

# Build a residue-level graph with sequential and spatial edges
construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],  # one node per C-alpha atom
    edge_layers=[
        geometry.SequentialEdge(max_distance=2),  # sequence neighbors
        geometry.SpatialEdge(radius=10.0),        # spatial neighbors
    ],
)
graph = construction(data.Protein.pack([protein]))
```

### PackedGraph

Efficient batching structure for heterogeneous graphs.

**Purpose:**

- Batch graphs of different sizes
- Single GPU memory allocation
- Efficient parallel processing

**Attributes:**

- `num_nodes`: Node counts per graph
- `num_edges`: Edge counts per graph
- `node2graph`: Graph index for each node

**Use Cases:**

- Automatic in the DataLoader
- Custom batching strategies
- Multi-graph operations
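Batching normally happens inside the DataLoader, but graphs can also be packed manually. A minimal sketch using the `pack` classmethod that TorchDrug exposes on its graph classes:

```python
from torchdrug import data

mol1 = data.Molecule.from_smiles("CCO")
mol2 = data.Molecule.from_smiles("c1ccccc1")

# Pack molecules of different sizes into one disconnected batch graph
batch = data.Molecule.pack([mol1, mol2])

print(batch.batch_size)  # 2
print(batch.num_nodes)   # node counts per graph
```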
## Model Interface

### Forward Function Signature

All TorchDrug models follow a standardized interface:

```python
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation
    output = self.layers(graph, input)
    return {
        "node_feature": output,
        "graph_feature": self.readout(graph, output)  # e.g. a layers.SumReadout module
    }
```

**Key Points:**

- `graph`: Batched graph structure
- `input`: Node features `[num_node, input_dim]`
- `all_loss`: Accumulated loss (for multi-task training)
- `metric`: Shared metric dictionary
- Returns a dict keyed by representation type

### Essential Attributes

**All models must define:**

- `input_dim`: Expected input feature dimension
- `output_dim`: Output representation dimension

**Purpose:**

- Automatic dimension checking
- Composing models in pipelines
- Error checking and validation

**Example:**

```python
import torch.nn as nn

class CustomModel(nn.Module):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...
```

## Task Interface

### Core Task Methods

All tasks implement these methods:

```python
from torchdrug import tasks, metrics

class CustomTask(tasks.Task):

    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground-truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute the training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        metric = {
            "auroc": metrics.area_under_roc(pred, target),
            "auprc": metrics.area_under_prc(pred, target),
        }
        return metric
```

### Task Components

**Typical Task Structure:**

1. **Representation Model**: Encodes graphs into embeddings
2. **Readout/Prediction Head**: Maps embeddings to predictions
3. **Loss Function**: Training objective
4. **Metrics**: Evaluation measures

**Example:**

```python
from torchdrug import tasks, models

# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Task wraps the model with a prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
```

## Training Workflow

### Standard Training Loop

```python
import torch
from torchdrug import data, models, tasks, datasets

# 1. Load the dataset
dataset = datasets.BBBP("~/datasets/")
train_set, valid_set, test_set = dataset.split()

# 2. Create data loaders (torchdrug's DataLoader batches graphs into PackedGraphs)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define the model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])

# 4. Set up the optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss = task(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
```

### PyTorch Lightning Integration

TorchDrug tasks can be wrapped for PyTorch Lightning (this example uses the pre-2.0 `validation_epoch_end` hook):

```python
import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):

    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss = self.task(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

## Loss Functions

### Built-in Criteria

**Classification:**

- `"bce"`: Binary cross-entropy
- `"ce"`: Cross-entropy (multi-class)

**Regression:**

- `"mse"`: Mean squared error
- `"mae"`: Mean absolute error

**Knowledge Graph:**

- `"bce"`: Binary classification of triples
- `"ce"`: Cross-entropy ranking loss
- `"margin"`: Margin-based ranking

### Custom Loss

```python
class CustomTask(tasks.Task):

    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)
        # Custom loss computation (custom_loss_function is a placeholder)
        loss = custom_loss_function(pred, target)
        return loss
```

## Metrics

### Common Metrics

**Classification:**

- **AUROC**: Area under the ROC curve
- **AUPRC**: Area under the precision-recall curve
- **Accuracy**: Overall accuracy
- **F1**: Harmonic mean of precision and recall

**Regression:**

- **MAE**: Mean absolute error
- **RMSE**: Root mean squared error
- **R²**: Coefficient of determination
- **Pearson**: Pearson correlation

**Ranking (Knowledge Graph):**

- **MR**: Mean rank
- **MRR**: Mean reciprocal rank
- **Hits@K**: Fraction of correct entities ranked in the top K

### Multi-Task Metrics

For multi-label or multi-task settings:

- Metrics are computed per task
- Macro-averaged across tasks
- Tasks can be weighted by importance

## Data Transforms

### Molecule Transforms

```python
from torchdrug import datasets, transforms

# Add a virtual node connected to all atoms
transform1 = transforms.VirtualNode()

# Add virtual edges
transform2 = transforms.VirtualEdge()

# Compose transforms
transform = transforms.Compose([transform1, transform2])

dataset = datasets.BBBP("~/datasets/", transform=transform)
```

### Protein Transforms

```python
from torchdrug import datasets, transforms

# Truncate long proteins to a maximum number of residues
transform = transforms.TruncateProtein(max_length=500)

dataset = datasets.Fold("~/datasets/", transform=transform)
```

## Best Practices

### Memory Efficiency

1. **Gradient Accumulation**: For large models
2. **Mixed Precision**: FP16 training
3. **Batch Size Tuning**: Balance speed and memory
4. **Data Loading**: Use multiple workers for I/O

### Reproducibility

1. **Set Seeds**: PyTorch, NumPy, Python `random`
2. **Deterministic Operations**: `torch.use_deterministic_algorithms(True)`
3. **Save Configurations**: Use `core.Configurable`
4. **Version Control**: Track the TorchDrug version

### Debugging

1. **Check Dimensions**: Verify `input_dim` and `output_dim`
2. **Validate Batching**: Print batch statistics
3. **Monitor Gradients**: Watch for vanishing/exploding gradients
4. **Overfit a Small Batch**: Confirm the model has enough capacity

### Performance Optimization

1. **GPU Utilization**: Monitor with `nvidia-smi`
2. **Profile Code**: Use the PyTorch profiler
3. **Optimize Data Loading**: Prefetching, pinned memory
4. **Compile Models**: Use TorchScript where possible

## Advanced Topics

### Multi-Task Learning

Train a single model on multiple related tasks:

```python
task = tasks.PropertyPrediction(
    model,
    task=["task1", "task2", "task3"],
    criterion="bce",
    metric=["auroc"],
    task_weight=[1.0, 1.0, 2.0]  # weight task 3 more heavily
)
```

### Transfer Learning

1. Pre-train on a large dataset
2. Fine-tune on the target dataset
3. Optionally freeze early layers (see the sketch below)
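A minimal sketch of step 3, assuming the pre-trained weights live in a hypothetical checkpoint `pretrained_gin.pth` and that the model exposes its convolution blocks as `model.layers` (as TorchDrug's GIN does):

```python
import torch

# Load pre-trained weights (checkpoint path is hypothetical)
model.load_state_dict(torch.load("pretrained_gin.pth"))

# Freeze the first two message-passing layers; fine-tune the rest
for layer in model.layers[:2]:
    for param in layer.parameters():
        param.requires_grad = False

# Only optimize parameters that still require gradients
optimizer = torch.optim.Adam(
    (p for p in task.parameters() if p.requires_grad), lr=1e-4
)
```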
### Self-Supervised Pre-training

Use pre-training tasks:

- `AttributeMasking`: Mask node features
- `EdgePrediction`: Predict edge existence
- `ContextPrediction`: Contrastive learning

### Custom Layers

Extend TorchDrug with custom GNN layers by overriding the `message`, `aggregate`, and `combine` hooks of `layers.MessagePassingBase`. A minimal sum-aggregation convolution:

```python
import torch
from torch_scatter import scatter_add
from torchdrug import layers

class CustomConv(layers.MessagePassingBase):

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def message(self, graph, input):
        # Message: features of the source node of each edge
        node_in = graph.edge_list[:, 0]
        return input[node_in]

    def aggregate(self, graph, message):
        # Aggregation: sum incoming messages at each target node
        node_out = graph.edge_list[:, 1]
        return scatter_add(message, node_out, dim=0, dim_size=graph.num_node)

    def combine(self, input, update):
        # Combination: merge self features with aggregated messages
        return self.linear(input + update)
```

## Common Pitfalls

1. **Forgetting `input_dim` and `output_dim`**: Models won't compose
2. **Not Batching Properly**: Use `PackedGraph` for variable-sized graphs
3. **Data Leakage**: Be careful with scaffold splits and pre-training
4. **Ignoring Edge Features**: Bond and spatial information can be critical
5. **Wrong Evaluation Metrics**: Match the metric to the task (e.g., AUROC for imbalanced classification)
6. **Insufficient Regularization**: Use dropout, weight decay, and early stopping
7. **Not Validating Chemistry**: Generated molecules must be chemically valid (see the sketch below)
8. **Overfitting Small Datasets**: Use pre-training or simpler models
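To illustrate pitfall 7, a quick RDKit validity screen (the generated SMILES list is made up):

```python
from rdkit import Chem

def is_valid(smiles):
    # RDKit returns None when a SMILES string cannot be parsed and sanitized
    return Chem.MolFromSmiles(smiles) is not None

generated = ["CCO", "c1ccccc1", "C((C"]  # the last one is malformed
valid = [s for s in generated if is_valid(s)]
print(f"{len(valid)}/{len(generated)} valid molecules")
```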