skills/torchdrug/references/core_concepts.md

# Core Concepts and Technical Details

## Overview

This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.

## Architecture Philosophy

### Modular Design

TorchDrug separates concerns into distinct modules:

1. **Representation Models** (`models.py`): Encode graphs into embeddings
2. **Task Definitions** (`tasks.py`): Define learning objectives and evaluation
3. **Data Handling** (`data.py`, `datasets.py`): Graph structures and datasets
4. **Core Components** (`core.py`): Base classes and utilities

**Benefits:**
- Reuse representations across tasks
- Mix and match components
- Easy experimentation and prototyping
- Clear separation of concerns

### Configurable System

All components inherit from `core.Configurable`:
- Serialize to configuration dictionaries
- Reconstruct from configurations
- Save and load complete pipelines
- Reproducible experiments

## Core Components

### core.Configurable

Base class for all TorchDrug components.

**Key Methods:**
- `config_dict()`: Serialize to dictionary
- `load_config_dict(config)`: Load from dictionary
- `save(file)`: Save to file
- `load(file)`: Load from file

**Example:**
```python
from torchdrug import core, models

model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}

# Reconstruct model
model2 = core.Configurable.load_config_dict(config)
```

### core.Registry

Decorator for registering models, tasks, and datasets.

**Usage:**
```python
import torch.nn as nn

from torchdrug import core
from torchdrug.core import Registry as R

@R.register("models.CustomModel")
class CustomModel(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Model implementation
        pass
```

**Benefits:**
- Models are automatically serializable
- String-based model specification
- Easy model lookup and instantiation

## Data Structures

### Graph

Core data structure representing molecular or protein graphs.

**Attributes:**
- `num_node`: Number of nodes
- `num_edge`: Number of edges
- `node_feature`: Node feature tensor `[num_node, feature_dim]`
- `edge_feature`: Edge feature tensor `[num_edge, feature_dim]`
- `edge_list`: Edge connectivity `[num_edge, 2 or 3]`
- `num_relation`: Number of edge types (for multi-relational graphs)

**Methods:**
- `node_mask(mask)`: Select a subset of nodes
- `edge_mask(mask)`: Select a subset of edges
- `undirected()`: Make the graph undirected
- `directed()`: Make the graph directed

**Batching:**
- Graphs are batched into a single disconnected graph
- Batching happens automatically in the DataLoader
- Per-graph node/edge indices are preserved
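
A rough sketch of how these pieces fit together, using the attribute and method names above (details such as the exact mask semantics are worth checking against the API docs):

```python
import torch
from torchdrug import data

# A directed 3-cycle: each row of edge_list is (source, target)
graph = data.Graph(edge_list=[[0, 1], [1, 2], [2, 0]], num_node=3)

# Add reverse edges, then keep only nodes 0 and 1
graph = graph.undirected()
subgraph = graph.node_mask(torch.tensor([True, True, False]))

print(graph.num_node, graph.num_edge)  # 3 6
```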
### Molecule (extends Graph)

Specialized graph for molecules.

**Additional Attributes:**
- `atom_type`: Atomic numbers
- `bond_type`: Bond types (single, double, triple, aromatic)
- `formal_charge`: Atomic formal charges
- `explicit_hs`: Explicit hydrogen counts

**Methods:**
- `from_smiles(smiles)`: Create from a SMILES string
- `from_molecule(mol)`: Create from an RDKit molecule
- `to_smiles()`: Convert to SMILES
- `to_molecule()`: Convert to an RDKit molecule
- `ion_to_molecule()`: Neutralize charges

**Example:**
```python
from torchdrug import data

# From SMILES
mol = data.Molecule.from_smiles("CCO")

# Atom features
print(mol.atom_type)  # [6, 6, 8] (C, C, O)
print(mol.bond_type)  # [1, 1] (single bonds)
```

### Protein (extends Graph)

Specialized graph for proteins.

**Additional Attributes:**
- `residue_type`: Amino acid types
- `atom_name`: Atom names (CA, CB, etc.)
- `atom_type`: Atomic numbers
- `residue_number`: Residue numbering
- `chain_id`: Chain identifiers

**Methods:**
- `from_pdb(pdb_file)`: Load from a PDB file
- `from_sequence(sequence)`: Create from a sequence
- `to_pdb(pdb_file)`: Save to a PDB file

**Graph Construction:**
- Nodes typically represent residues (not atoms)
- Edges can be sequential, spatial (radius/KNN), or contact-based
- Edge construction strategies are configurable

**Example:**
```python
from torchdrug import data, layers
from torchdrug.layers import geometry

# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")

# Build a residue-level graph with sequential and spatial edges
construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],  # one node per residue (Cα)
    edge_layers=[
        geometry.SequentialEdge(max_distance=2),            # sequence neighbors
        geometry.SpatialEdge(radius=10.0, min_distance=5),  # spatial edges
    ],
)
graph = construction(data.Protein.pack([protein]))
```
### PackedGraph

Efficient batching structure for graphs of varying size.

**Purpose:**
- Batch graphs of different sizes
- Single GPU memory allocation
- Efficient parallel processing

**Attributes:**
- `num_nodes`: Per-graph node counts
- `num_edges`: Per-graph edge counts
- `node2graph`: Graph index for each node

**Use Cases:**
- Automatic batching in the DataLoader
- Custom batching strategies
- Multi-graph operations
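
A short sketch of packing and unpacking; `pack`, `unpack`, `batch_size`, and `node2graph` follow the attribute names above:

```python
from torchdrug import data

g1 = data.Molecule.from_smiles("CCO")
g2 = data.Molecule.from_smiles("c1ccccc1")

# Pack into one disconnected graph for batched processing
batch = data.Molecule.pack([g1, g2])
print(batch.batch_size)   # 2
print(batch.num_nodes)    # per-graph node counts
print(batch.node2graph)   # graph index of every node

molecules = batch.unpack()  # split back into individual graphs
```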
## Model Interface

### Forward Function Signature

All TorchDrug models follow a standardized interface:

```python
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation
    output = self.layers(graph, input)

    return {
        "node_feature": output,
        "graph_feature": graph_pooling(output)
    }
```

**Key Points:**
- `graph`: Batched graph structure
- `input`: Node features `[num_node, input_dim]`
- `all_loss`: Accumulated loss (for multi-task training)
- `metric`: Shared metric dictionary
- Returns a dict keyed by representation type

### Essential Attributes

**All models must define:**
- `input_dim`: Expected input feature dimension
- `output_dim`: Output representation dimension

**Purpose:**
- Automatic dimension checking
- Composing models into pipelines
- Error checking and validation

**Example:**
```python
import torch.nn as nn

class CustomModel(nn.Module):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...
```
## Task Interface

### Core Task Methods

All tasks implement these methods:

```python
from torchdrug import metrics, tasks

class CustomTask(tasks.Task):

    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground-truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute the training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        return {
            "auroc": metrics.area_under_roc(pred, target),
            "auprc": metrics.area_under_prc(pred, target),
        }
```

### Task Components

**Typical Task Structure:**
1. **Representation Model**: Encodes the graph into embeddings
2. **Readout/Prediction Head**: Maps embeddings to predictions
3. **Loss Function**: Training objective
4. **Metrics**: Evaluation measures

**Example:**
```python
from torchdrug import tasks, models

# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Task wraps the model with a prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # Multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
```
## Training Workflow

### Standard Training Loop

```python
import torch
from torchdrug import data, datasets, models, tasks

# 1. Load dataset
dataset = datasets.BBBP("~/datasets/")
train_set, valid_set, test_set = dataset.split()

# 2. Create data loaders (data.DataLoader knows how to batch graphs)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])

# 4. Set up optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss, metric = task(batch)  # built-in tasks return (loss, metric)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
```
### PyTorch Lightning Integration

TorchDrug tasks are compatible with PyTorch Lightning:

```python
import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):

    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss, metric = self.task(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```
## Loss Functions

### Built-in Criteria

**Classification:**
- `"bce"`: Binary cross-entropy
- `"ce"`: Cross-entropy (multi-class)

**Regression:**
- `"mse"`: Mean squared error
- `"mae"`: Mean absolute error

**Knowledge Graph:**
- `"bce"`: Binary classification of triples
- `"ce"`: Cross-entropy ranking loss
- `"margin"`: Margin-based ranking

### Custom Loss

```python
from torchdrug import tasks

class CustomTask(tasks.Task):

    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)

        # Custom loss computation
        loss = custom_loss_function(pred, target)

        return loss
```
## Metrics

### Common Metrics

**Classification:**
- **AUROC**: Area under the ROC curve
- **AUPRC**: Area under the precision-recall curve
- **Accuracy**: Overall accuracy
- **F1**: Harmonic mean of precision and recall

**Regression:**
- **MAE**: Mean absolute error
- **RMSE**: Root mean squared error
- **R²**: Coefficient of determination
- **Pearson**: Pearson correlation

**Ranking (Knowledge Graph):**
- **MR**: Mean rank
- **MRR**: Mean reciprocal rank
- **Hits@K**: Percentage of correct answers in the top K

### Multi-Task Metrics

For multi-label or multi-task settings:
- Metrics are computed per task
- Macro-averaged across tasks
- Tasks can be weighted by importance
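
As a quick illustration, TorchDrug exposes functional versions of several of these metrics (e.g. `metrics.area_under_roc` and `metrics.area_under_prc`); the toy tensors below are just example data:

```python
import torch
from torchdrug import metrics

pred = torch.tensor([0.9, 0.2, 0.8, 0.4])
target = torch.tensor([1, 0, 1, 1])

print(metrics.area_under_roc(pred, target))
print(metrics.area_under_prc(pred, target))
```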
## Data Transforms

### Molecule Transforms

```python
from torchdrug import datasets, transforms

# Add a virtual node connected to all atoms
transform1 = transforms.VirtualNode()

# Add virtual edges
transform2 = transforms.VirtualEdge()

# Compose transforms
transform = transforms.Compose([transform1, transform2])

dataset = datasets.BBBP("~/datasets/", transform=transform)
```

### Protein Transforms

```python
from torchdrug import datasets, transforms

# Truncate long proteins to a maximum sequence length
transform = transforms.TruncateProtein(max_length=500)

dataset = datasets.Fold("~/datasets/", transform=transform)
```
## Best Practices

### Memory Efficiency

1. **Gradient Accumulation**: For large models
2. **Mixed Precision**: FP16 training
3. **Batch Size Tuning**: Balance speed and memory
4. **Data Loading**: Use multiple workers for I/O

### Reproducibility

1. **Set Seeds**: PyTorch, NumPy, Python `random` (see the sketch below)
2. **Deterministic Operations**: `torch.use_deterministic_algorithms(True)`
3. **Save Configurations**: Use `core.Configurable`
4. **Version Control**: Track the TorchDrug version
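
A minimal seeding helper, assuming full determinism is desired (deterministic algorithms can slow training, and some operations do not support them):

```python
import random

import numpy as np
import torch

def set_seed(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
```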
### Debugging

1. **Check Dimensions**: Verify `input_dim` and `output_dim`
2. **Validate Batching**: Print batch statistics
3. **Monitor Gradients**: Watch for vanishing/exploding gradients
4. **Overfit a Small Batch**: Sanity-check model capacity

### Performance Optimization

1. **GPU Utilization**: Monitor with `nvidia-smi`
2. **Profile Code**: Use the PyTorch profiler
3. **Optimize Data Loading**: Prefetch, pin memory
4. **Compile Models**: Use TorchScript where possible
## Advanced Topics

### Multi-Task Learning

Train a single model on multiple related tasks:
```python
task = tasks.PropertyPrediction(
    model,
    task=["task1", "task2", "task3"],
    criterion="bce",
    metric=["auroc"],
    task_weight=[1.0, 1.0, 2.0]  # Weight task 3 more heavily
)
```

### Transfer Learning

1. Pre-train on a large dataset
2. Fine-tune on the target dataset
3. Optionally freeze early layers (see the sketch below)
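
A hedged sketch of step 3, assuming the encoder stores its message-passing layers in a `model.layers` module list (as TorchDrug's GNNs generally do):

```python
import torch

# Freeze everything except the final message-passing layer
for layer in model.layers[:-1]:
    for param in layer.parameters():
        param.requires_grad = False

# Optimize only the remaining trainable parameters
optimizer = torch.optim.Adam(
    (p for p in task.parameters() if p.requires_grad), lr=1e-4
)
```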
### Self-Supervised Pre-training

Use pre-training tasks:
- `AttributeMasking`: Mask node features
- `EdgePrediction`: Predict edge existence
- `ContextPrediction`: Contrastive learning

### Custom Layers

Extend TorchDrug with custom GNN layers:
```python
from torchdrug import layers

class CustomConv(layers.MessagePassingBase):

    def message(self, graph, input):
        # Compute messages along edges
        pass

    def aggregate(self, graph, message):
        # Aggregate messages per node
        pass

    def combine(self, input, update):
        # Combine node state with the aggregated update
        pass
```

## Common Pitfalls

1. **Forgetting `input_dim` and `output_dim`**: Models won't compose
2. **Not Batching Properly**: Use PackedGraph for variable-sized graphs
3. **Data Leakage**: Be careful with scaffold splits and pre-training
4. **Ignoring Edge Features**: Bond/spatial information can be critical
5. **Wrong Evaluation Metrics**: Match metrics to the task (e.g., AUROC/AUPRC for imbalanced data)
6. **Insufficient Regularization**: Use dropout, weight decay, early stopping
7. **Not Validating Chemistry**: Generated molecules must be chemically valid
8. **Overfitting Small Datasets**: Use pre-training or simpler models

skills/torchdrug/references/datasets.md

# Datasets Reference

## Overview

TorchDrug provides 40+ curated datasets across multiple domains: molecular property prediction, protein modeling, knowledge graph reasoning, and retrosynthesis. All datasets support lazy loading, automatic downloading, and customizable feature extraction.

## Molecular Property Prediction Datasets

### Drug Discovery Classification

| Dataset | Size | Task | Classes | Description |
|---------|------|------|---------|-------------|
| **BACE** | 1,513 | Binary | 2 | β-secretase inhibition for Alzheimer's |
| **BBBP** | 2,039 | Binary | 2 | Blood-brain barrier penetration |
| **HIV** | 41,127 | Binary | 2 | Inhibition of HIV replication |
| **ClinTox** | 1,478 | Multi-label | 2 | Clinical trial toxicity |
| **SIDER** | 1,427 | Multi-label | 27 | Side effects by system organ class |
| **Tox21** | 7,831 | Multi-label | 12 | Toxicity across 12 targets |
| **ToxCast** | 8,576 | Multi-label | 617 | High-throughput toxicology |
| **MUV** | 93,087 | Multi-label | 17 | Unbiased validation for screening |

**Key Features:**
- All use scaffold splits for realistic evaluation
- Binary classification metrics: AUROC, AUPRC
- Multi-label datasets handle missing values

**Use Cases:**
- Drug safety prediction
- Virtual screening
- ADMET property prediction

### Drug Discovery Regression

| Dataset | Size | Property | Units | Description |
|---------|------|----------|-------|-------------|
| **ESOL** | 1,128 | Solubility | log(mol/L) | Water solubility |
| **FreeSolv** | 642 | Hydration | kcal/mol | Hydration free energy |
| **Lipophilicity** | 4,200 | LogD | - | Octanol/water distribution |
| **SAMPL** | 643 | Solvation | kcal/mol | Solvation free energies |

**Metrics:** MAE, RMSE, R²

**Use Cases:** ADME optimization, lead optimization

### Quantum Chemistry

| Dataset | Size | Properties | Description |
|---------|------|------------|-------------|
| **QM7** | 7,165 | 1 | Atomization energy |
| **QM8** | 21,786 | 12 | Electronic spectra, excited states |
| **QM9** | 133,885 | 12 | Geometric, energetic, electronic, thermodynamic |
| **PCQM4M** | 3.8M | 1 | Large-scale HOMO-LUMO gap |

**Properties (QM9):**
- Dipole moment
- Isotropic polarizability
- HOMO/LUMO energies
- Internal energy, enthalpy, free energy
- Heat capacity
- Electronic spatial extent

**Use Cases:**
- Quantum property prediction
- Method development benchmarking
- Pre-training molecular models

### Large Molecule Databases

| Dataset | Size | Description | Use Case |
|---------|------|-------------|----------|
| **ZINC250k** | 250,000 | Drug-like molecules | Generative model training |
| **ZINC2M** | 2,000,000 | Drug-like molecules | Large-scale pre-training |
| **ChEMBL** | Millions | Bioactive molecules | Property prediction, generation |
## Protein Datasets

### Function Prediction

| Dataset | Size | Task | Classes | Description |
|---------|------|------|---------|-------------|
| **EnzymeCommission** | 17,562 | Multi-class | 7 levels | EC number classification |
| **GeneOntology** | 46,796 | Multi-label | 489 | GO term prediction (BP/MF/CC) |
| **BetaLactamase** | 5,864 | Regression | - | Enzyme activity levels |
| **Fluorescence** | 54,025 | Regression | - | GFP fluorescence intensity |
| **Stability** | 53,614 | Regression | - | Thermostability (ΔΔG) |

**Features:**
- Sequence and/or structure input
- Evolutionary information available
- Multiple train/test splits

**Use Cases:**
- Protein engineering
- Function annotation
- Enzyme design

### Localization and Solubility

| Dataset | Size | Task | Classes | Description |
|---------|------|------|---------|-------------|
| **Solubility** | 62,478 | Binary | 2 | Protein solubility |
| **BinaryLocalization** | 22,168 | Binary | 2 | Membrane vs. soluble |
| **SubcellularLocalization** | 8,943 | Multi-class | 10 | Subcellular compartment |

**Use Cases:**
- Protein expression optimization
- Target identification
- Cell biology

### Structure Prediction

| Dataset | Size | Task | Description |
|---------|------|------|-------------|
| **Fold** | 16,712 | Multi-class (1,195) | Structural fold recognition |
| **SecondaryStructure** | 8,678 | Sequence labeling | 3-state or 8-state prediction |
| **ProteinNet** | Varied | Contact prediction | Residue-residue contacts |

**Use Cases:**
- Structure prediction pipelines
- Fold recognition
- Contact map generation

### Protein Interactions

| Dataset | Size | Positives | Negatives | Description |
|---------|------|-----------|-----------|-------------|
| **HumanPPI** | 1,412 proteins | 6,584 | - | Human protein interactions |
| **YeastPPI** | 2,018 proteins | 6,451 | - | Yeast protein interactions |
| **PPIAffinity** | 2,156 pairs | - | - | Binding affinity values |

**Use Cases:**
- PPI prediction
- Network biology
- Drug target identification

### Protein-Ligand Binding

| Dataset | Size | Type | Description |
|---------|------|------|-------------|
| **BindingDB** | ~1.5M | Affinity | Comprehensive binding data |
| **PDBBind** | 20,000+ | 3D complexes | Structure-based binding |
| - Refined Set | 5,316 | High quality | Curated crystal structures |
| - Core Set | 285 | Benchmark | Diverse test set |

**Use Cases:**
- Binding affinity prediction
- Structure-based drug design
- Scoring function development

### Large Protein Databases

| Dataset | Size | Description |
|---------|------|-------------|
| **AlphaFoldDB** | 200M+ | Predicted structures for most known proteins |
| **UniProt** | Integration | Sequence and annotation data |
## Knowledge Graph Datasets

### General Knowledge

| Dataset | Entities | Relations | Triples | Domain |
|---------|----------|-----------|---------|--------|
| **FB15k** | 14,951 | 1,345 | 592,213 | Freebase (general knowledge) |
| **FB15k-237** | 14,541 | 237 | 310,116 | Filtered Freebase |
| **WN18** | 40,943 | 18 | 151,442 | WordNet (lexical) |
| **WN18RR** | 40,943 | 11 | 93,003 | Filtered WordNet |

**Relation Types (FB15k-237):**
- `/people/person/nationality`
- `/film/film/genre`
- `/location/location/contains`
- `/business/company/founders`
- Many more...

**Use Cases:**
- Link prediction
- Relation extraction
- Knowledge base completion

### Biomedical Knowledge

| Dataset | Entities | Relations | Triples | Description |
|---------|----------|-----------|---------|-------------|
| **Hetionet** | 45,158 | 24 | 2,250,197 | Integrates 29 biomedical databases |

**Entity Types in Hetionet:**
- Genes (20,945)
- Compounds (1,552)
- Diseases (137)
- Anatomy (400)
- Pathways (1,822)
- Pharmacologic classes
- Side effects
- Symptoms
- Molecular functions
- Biological processes
- Cellular components

**Relation Types:**
- Compound-binds-Gene
- Gene-associates-Disease
- Disease-presents-Symptom
- Compound-treats-Disease
- Compound-causes-Side effect
- Gene-participates-Pathway
- And 18 more...

**Use Cases:**
- Drug repurposing
- Disease mechanism discovery
- Target identification
- Multi-hop reasoning in biomedicine

## Citation Network Datasets

| Dataset | Nodes | Edges | Classes | Description |
|---------|-------|-------|---------|-------------|
| **Cora** | 2,708 | 5,429 | 7 | Machine learning papers |
| **CiteSeer** | 3,327 | 4,732 | 6 | Computer science papers |
| **PubMed** | 19,717 | 44,338 | 3 | Biomedical papers |

**Use Cases:**
- Node classification
- GNN baseline comparisons
- Method development

## Retrosynthesis Datasets

| Dataset | Size | Description |
|---------|------|-------------|
| **USPTO-50k** | 50,017 | Curated patent reactions, single-step |

**Features:**
- Product → reactants mapping
- Atom mapping for reaction centers
- Canonicalized SMILES
- Balanced across reaction types

**Splits:**
- Train: ~40,000
- Validation: ~5,000
- Test: ~5,000

**Use Cases:**
- Retrosynthesis prediction
- Reaction type classification
- Synthetic route planning
## Dataset Usage Patterns

### Loading Datasets

```python
from torchdrug import datasets

# Basic loading
dataset = datasets.BBBP("~/molecule-datasets/")

# With transforms
from torchdrug import transforms
transform = transforms.VirtualNode()
dataset = datasets.BBBP("~/molecule-datasets/", transform=transform)

# Protein dataset
dataset = datasets.EnzymeCommission("~/protein-datasets/")

# Knowledge graph
dataset = datasets.FB15k237("~/kg-datasets/")
```

### Data Splitting

```python
# Random split
train, valid, test = dataset.split([0.8, 0.1, 0.1])

# Scaffold split (for molecules)
from torchdrug import utils
train, valid, test = dataset.split(
    utils.scaffold_split(dataset, [0.8, 0.1, 0.1])
)

# Predefined splits (some datasets)
train, valid, test = dataset.split()
```

### Feature Extraction

**Node Features (Molecules):**
- Atom type (one-hot or embedding)
- Formal charge
- Hybridization
- Aromaticity
- Number of hydrogens
- Chirality

**Edge Features (Molecules):**
- Bond type (single, double, triple, aromatic)
- Stereochemistry
- Conjugation
- Ring membership

**Node Features (Proteins):**
- Amino acid type (one-hot)
- Physicochemical properties
- Position in sequence
- Secondary structure
- Solvent accessibility

**Edge Features (Proteins):**
- Edge type (sequential, spatial, contact)
- Distance
- Angles and dihedrals
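
For molecules, feature choices are made at construction time; a small hedged example (the `atom_feature`/`bond_feature` keyword names follow TorchDrug's defaults):

```python
from torchdrug import data

# The "default" atom/bond featurizers cover the properties listed above
mol = data.Molecule.from_smiles(
    "CCO", atom_feature="default", bond_feature="default"
)
print(mol.node_feature.shape)  # [num_atoms, feature_dim]
print(mol.edge_feature.shape)  # [num_bonds, feature_dim]
```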
## Choosing Datasets

### By Task

**Molecular Property Prediction:**
- Start with BBBP or HIV (medium size, clear task)
- Use QM9 for quantum properties
- ESOL/FreeSolv for regression

**Protein Function:**
- EnzymeCommission (well-defined classes)
- GeneOntology (comprehensive annotations)

**Drug Safety:**
- Tox21 (standard benchmark)
- ClinTox (clinical relevance)

**Structure-Based:**
- PDBBind (protein-ligand)
- ProteinNet (structure prediction)

**Knowledge Graph:**
- FB15k-237 (standard benchmark)
- Hetionet (biomedical applications)

**Generation:**
- ZINC250k (training)
- QM9 (with properties)

**Retrosynthesis:**
- USPTO-50k (the only option here)

### By Size and Resources

**Small (<5k, for testing):**
- BACE, FreeSolv, ClinTox
- PDBBind core set

**Medium (5k-100k):**
- BBBP, HIV, ESOL, Tox21
- EnzymeCommission, Fold
- FB15k-237, WN18RR

**Large (>100k):**
- QM9, MUV, PCQM4M
- GeneOntology, AlphaFoldDB
- ZINC2M, BindingDB

### By Domain

**Drug Discovery:** BBBP, HIV, Tox21, ESOL, ZINC

**Quantum Chemistry:** QM7, QM8, QM9, PCQM4M

**Protein Engineering:** Fluorescence, Stability, Solubility

**Structural Biology:** Fold, PDBBind, ProteinNet, AlphaFoldDB

**Biomedical:** Hetionet, GeneOntology, EnzymeCommission

**Retrosynthesis:** USPTO-50k

## Best Practices

1. **Start Small**: Test on small datasets before scaling
2. **Scaffold Split**: Use for realistic drug discovery evaluation
3. **Balanced Metrics**: Use AUROC + AUPRC for imbalanced data
4. **Multiple Runs**: Report mean ± std over multiple random seeds
5. **Data Leakage**: Be careful with pre-trained models
6. **Domain Knowledge**: Understand what you are predicting
7. **Validation**: Always use a held-out test set
8. **Preprocessing**: Standardize features, handle missing values

skills/torchdrug/references/knowledge_graphs.md

# Knowledge Graph Reasoning

## Overview

Knowledge graphs represent structured information as entities and relations in a graph format. TorchDrug provides comprehensive support for knowledge graph completion (link prediction) using embedding-based models and neural reasoning approaches.

## Available Datasets

### General Knowledge Graphs

**FB15k (Freebase subset):**
- 14,951 entities
- 1,345 relation types
- 592,213 triples
- General world knowledge from Freebase

**FB15k-237:**
- 14,541 entities
- 237 relation types
- 310,116 triples
- Filtered version removing inverse relations
- More challenging benchmark

**WN18 (WordNet):**
- 40,943 entities (word senses)
- 18 relation types (lexical relations)
- 151,442 triples
- Linguistic knowledge graph

**WN18RR:**
- 40,943 entities
- 11 relation types
- 93,003 triples
- Filtered WordNet removing easy inverse patterns

### Biomedical Knowledge Graphs

**Hetionet:**
- 45,158 entities (genes, compounds, diseases, pathways, etc.)
- 24 relation types (treats, causes, binds, etc.)
- 2,250,197 edges
- Integrates 29 public biomedical databases
- Designed for drug repurposing and disease understanding

## Task: KnowledgeGraphCompletion

The primary task for knowledge graphs is link prediction: given a head entity and a relation, predict the tail entity (or vice versa).

### Task Modes

**Head Prediction:**
- Given (?, relation, tail), predict the head entity
- "What can cause disease X?"

**Tail Prediction:**
- Given (head, relation, ?), predict the tail entity
- "What diseases does gene X cause?"

**Both:**
- Predict both head and tail
- The standard evaluation protocol

### Evaluation Metrics

**Ranking Metrics:**
- **Mean Rank (MR)**: Average rank of the correct entity
- **Mean Reciprocal Rank (MRR)**: Average of 1/rank
- **Hits@K**: Percentage of correct entities ranked in the top K
- Typically reported for K = 1, 3, 10

**Filtered vs. Raw:**
- **Filtered**: Remove other known true triples from the ranking
- **Raw**: Rank among all possible entities
- Filtered is the standard for evaluation
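
A minimal sketch of computing filtered ranking metrics for a single query, assuming `known_true` holds the indices of other entities known to complete the triple (excluding the target itself):

```python
import torch

def filtered_ranking_metrics(scores, target_index, known_true):
    """scores: [num_entity] model scores; higher = more plausible."""
    scores = scores.clone()
    scores[known_true] = float("-inf")  # filtered setting: mask known triples
    rank = int((scores > scores[target_index]).sum().item()) + 1
    return {"mr": rank, "mrr": 1.0 / rank, "hits@10": float(rank <= 10)}
```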
## Embedding Models

### Translational Models

**TransE (Translation Embedding):**
- Represents relations as translations in embedding space
- h + r ≈ t (head + relation ≈ tail)
- Simple and effective baseline
- Works well for 1-to-1 relations
- Struggles with N-to-N relations

**RotatE (Rotation Embedding):**
- Relations as rotations in complex space
- Better handles symmetric and inverse relations
- State-of-the-art on many benchmarks
- Can model composition patterns
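
The TransE scoring rule is compact enough to state directly; a hedged sketch, where `h`, `r`, `t` are embedding tensors and higher scores mean more plausible triples:

```python
import torch

def transe_score(h, r, t, p=1):
    # Negative distance between the translated head (h + r) and the tail
    return -torch.norm(h + r - t, p=p, dim=-1)
```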
### Semantic Matching Models

**DistMult:**
- Bilinear scoring function
- Handles symmetric relations naturally
- Cannot model asymmetric relations
- Fast and memory efficient

**ComplEx:**
- Complex-valued embeddings
- Models asymmetric and inverse relations
- Better than DistMult for most graphs
- Balances expressiveness and efficiency

**SimplE:**
- Extends DistMult with inverse relations
- Fully expressive (can represent any relation pattern)
- Two embeddings per entity (canonical and inverse)

### Neural Logic Models

**NeuralLP (Neural Logic Programming):**
- Learns logical rules through differentiable operations
- Interprets predictions via learned rules
- Good for sparse knowledge graphs
- Computationally more expensive

**KBGAT (Knowledge Base Graph Attention):**
- Graph attention networks for KG completion
- Learns entity representations from the neighborhood
- Handles unseen entities through inductive learning
- Better for incomplete graphs

## Training Workflow

### Basic Pipeline

```python
import torch
from torchdrug import core, datasets, models, tasks

# Load dataset
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# Define model
model = models.RotatE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=2000,
    max_score=9
)

# Define task
task = tasks.KnowledgeGraphCompletion(
    model,
    num_negative=128,
    adversarial_temperature=2,
    criterion="bce"
)

# Train with core.Engine, PyTorch Lightning, or a custom loop
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=200)
```

### Negative Sampling

**Strategies:**
- **Uniform**: Sample entities uniformly at random
- **Self-Adversarial**: Weight samples by the current model's scores
- **Type-Constrained**: Sample only valid entity types for the relation

**Parameters:**
- `num_negative`: Number of negative samples per positive triple
- `adversarial_temperature`: Temperature for self-adversarial weighting
- Higher temperature = more focus on hard negatives (see the sketch below)
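
A hedged sketch of self-adversarial weighting as introduced in the RotatE paper (weights are detached so they scale, but do not receive, gradients):

```python
import torch.nn.functional as F

def self_adversarial_weights(neg_scores, temperature=2.0):
    # Softmax over each triple's negatives: negatives the model currently
    # scores highly (hard negatives) get larger weights in the loss
    return F.softmax(neg_scores * temperature, dim=-1).detach()
```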
### Loss Functions

**Binary Cross-Entropy (BCE):**
- Treats each triple independently
- Balanced classification between positive and negative triples

**Margin Loss:**
- Ensures positive scores exceed negative scores by a margin
- `max(0, margin + score_neg - score_pos)`

**Logistic Loss:**
- Smooth version of the margin loss
- Better gradient properties

## Model Selection Guide

### By Relation Patterns

**1-to-1 Relations:**
- TransE works well
- Any model will likely succeed

**1-to-N Relations:**
- DistMult, ComplEx, SimplE
- Avoid TransE

**N-to-1 Relations:**
- DistMult, ComplEx, SimplE
- Avoid TransE

**N-to-N Relations:**
- ComplEx, SimplE, RotatE
- The most challenging pattern

**Symmetric Relations:**
- DistMult, ComplEx
- RotatE with proper initialization

**Antisymmetric Relations:**
- ComplEx, SimplE, RotatE
- Avoid DistMult

**Inverse Relations:**
- ComplEx, SimplE, RotatE
- Important for bidirectional reasoning

**Composition:**
- RotatE (best)
- TransE (reasonable)
- Captures multi-hop paths

### By Dataset Characteristics

**Small Graphs (<50k entities):**
- ComplEx or SimplE
- Lower embedding dimensions (200-500)

**Large Graphs (>100k entities):**
- DistMult for efficiency
- RotatE for accuracy
- Higher dimensions (500-2000)

**Sparse Graphs:**
- NeuralLP (learns rules from limited data)
- Pre-train embeddings on larger graphs

**Dense, Complete Graphs:**
- Any embedding model works well
- Choose based on relation patterns

**Biomedical/Domain Graphs:**
- Consider type constraints in sampling
- Use domain-specific negative sampling
- Hetionet benefits from relation-specific models

## Advanced Techniques

### Multi-Hop Reasoning

Chain multiple relations to answer complex queries:
- "What drugs treat diseases caused by gene X?"
- Requires path-based or rule-based reasoning
- NeuralLP naturally supports this

### Temporal Knowledge Graphs

Extend to time-varying facts:
- Add temporal information to triples
- Predict future facts
- Requires temporal encoding in models

### Few-Shot Learning

Handle relations with few examples:
- Meta-learning approaches
- Transfer from related relations
- Important for emerging knowledge

### Inductive Learning

Generalize to unseen entities:
- KBGAT and other GNN-based methods
- Use entity features/descriptions
- Critical for evolving knowledge graphs

## Biomedical Applications

### Drug Repurposing

Predict "drug treats disease" links in Hetionet:
1. Train on known drug-disease associations
2. Predict new treatment candidates
3. Filter by mechanism (gene, pathway involvement)
4. Validate predictions experimentally

### Disease Gene Discovery

Identify genes associated with diseases:
1. Model gene-disease-pathway networks
2. Predict missing gene-disease links
3. Incorporate protein interactions and expression data
4. Prioritize candidates for validation

### Protein Function Prediction

Link proteins to biological processes:
1. Integrate protein interactions and GO terms
2. Predict missing GO annotations
3. Transfer function from similar proteins

## Common Issues and Solutions

**Issue: Poor performance on specific relation types**
- Solution: Analyze relation patterns, choose an appropriate model, or use relation-specific models

**Issue: Overfitting on small graphs**
- Solution: Reduce the embedding dimension, increase regularization, or use simpler models

**Issue: Slow training on large graphs**
- Solution: Reduce negative samples, use DistMult for efficiency, or implement mini-batch training

**Issue: Cannot handle new entities**
- Solution: Use inductive models (KBGAT), incorporate entity features, or pre-compute embeddings for new entities based on their neighbors

## Best Practices

1. Start with ComplEx or RotatE for most tasks
2. Use self-adversarial negative sampling
3. Tune the embedding dimension (typically 500-2000)
4. Apply regularization to prevent overfitting
5. Use filtered evaluation metrics
6. Analyze performance per relation type
7. Consider relation-specific models for heterogeneous graphs
8. Validate predictions with domain experts

skills/torchdrug/references/models_architectures.md

# Models and Architectures

## Overview

TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs the available models with their characteristics, use cases, and implementation details.

## Graph Neural Networks

### GCN (Graph Convolutional Network)

**Type:** Spatial message passing

**Paper:** Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)

**Characteristics:**
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks

**Best For:**
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure

**Parameters:**
- `input_dim`: Node feature dimension
- `hidden_dims`: List of hidden layer dimensions
- `edge_input_dim`: Edge feature dimension (optional)
- `batch_norm`: Apply batch normalization
- `activation`: Activation function (relu, elu, etc.)
- `dropout`: Dropout rate

**Use Cases:**
- Molecular property prediction
- Citation network classification
- Social network analysis

### GAT (Graph Attention Network)

**Type:** Attention-based message passing

**Paper:** Graph Attention Networks (Veličković et al., 2018)

**Characteristics:**
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally

**Best For:**
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `num_head`: Number of attention heads
- `negative_slope`: LeakyReLU slope
- `concat`: Concatenate or average multi-head outputs

**Use Cases:**
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance

### GIN (Graph Isomorphism Network)

**Type:** Maximally powerful message passing

**Paper:** How Powerful are Graph Neural Networks? (Xu et al., 2019)

**Characteristics:**
- Theoretically the most expressive message-passing GNN
- Injective aggregation function
- Can distinguish graph structures that GCN cannot
- Often the best performance on molecular tasks

**Best For:**
- Molecular property prediction (state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification

**Parameters:**
- `input_dim`, `hidden_dims`: Standard dimensions
- `edge_input_dim`: Include edge features
- `batch_norm`: Typically set to true
- `readout`: Graph pooling ("sum", "mean", "max")
- `eps`: Learnable or fixed epsilon

**Use Cases:**
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
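
A hedged instantiation sketch (the dimensions are illustrative, and `dataset` is assumed to be a loaded TorchDrug molecule dataset):

```python
from torchdrug import models

model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256, 256, 256],
    edge_input_dim=dataset.edge_feature_dim,
    batch_norm=True,
    readout="mean",
)
```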
### RGCN (Relational Graph Convolutional Network)

**Type:** Multi-relational message passing

**Paper:** Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)

**Characteristics:**
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs

**Best For:**
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data

**Parameters:**
- `num_relation`: Number of relation types
- `hidden_dims`: Layer dimensions
- `num_bases`: Basis decomposition (reduces parameters)

**Use Cases:**
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks

### MPNN (Message Passing Neural Network)

**Type:** General message passing framework

**Paper:** Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)

**Characteristics:**
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for the graph representation

**Best For:**
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations

**Parameters:**
- `input_dim`, `hidden_dim`: Feature dimensions
- `edge_input_dim`: Edge feature dimension
- `num_layer`: Message passing iterations
- `num_mlp_layer`: MLP layers in the message function

**Use Cases:**
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks

### SchNet (Continuous-Filter Convolutional Network)

**Type:** 3D geometry-aware convolution

**Paper:** SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., 2017)

**Characteristics:**
- Operates on 3D atomic coordinates
- Continuous-filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry

**Best For:**
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction

**Parameters:**
- `input_dim`: Atom features
- `hidden_dims`: Layer dimensions
- `num_gaussian`: RBF basis functions for distances
- `cutoff`: Interaction cutoff distance

**Use Cases:**
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction

### ChebNet (Chebyshev Spectral CNN)

**Type:** Spectral convolution

**Paper:** Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)

**Characteristics:**
- Spectral graph convolution
- Chebyshev polynomial approximation
- Captures global graph structure
- Computationally efficient

**Best For:**
- Tasks requiring global information
- When the graph Laplacian is informative
- Theoretical analysis

**Parameters:**
- `input_dim`, `hidden_dims`: Dimensions
- `num_cheb`: Order of the Chebyshev polynomial

**Use Cases:**
- Citation network classification
- Brain network analysis
- Signal processing on graphs

### NFP (Neural Fingerprint)

**Type:** Molecular fingerprint learning

**Paper:** Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)

**Characteristics:**
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features

**Best For:**
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important

**Parameters:**
- `input_dim`, `output_dim`: Feature dimensions
- `hidden_dims`: Layer dimensions
- `num_layer`: Circular convolution depth

**Use Cases:**
- Virtual screening
- Molecular similarity search
- QSAR modeling

## Protein-Specific Models

### GearNet (Geometry-Aware Relational Graph Network)

**Type:** Protein structure encoder

**Paper:** Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)

**Characteristics:**
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks

**Best For:**
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures

**Parameters:**
- `input_dim`: Residue features
- `hidden_dims`: Layer dimensions
- `num_relation`: Edge types (sequential, radius, KNN)
- `edge_input_dim`: Geometric features (distances, angles)
- `batch_norm`: Typically true

**Use Cases:**
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification
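
A hedged instantiation sketch; the numbers follow the configuration used in the GearNet paper's downstream experiments (21 residue types, 7 edge relations, 59-dimensional geometric edge features) and should be adapted to your graph construction:

```python
from torchdrug import models

model = models.GearNet(
    input_dim=21,
    hidden_dims=[512, 512, 512, 512, 512, 512],
    num_relation=7,
    edge_input_dim=59,
    batch_norm=True,
    concat_hidden=True,
    short_cut=True,
    readout="sum",
)
```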
### ESM (Evolutionary Scale Modeling)

**Type:** Protein language model (transformer)

**Paper:** Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)

**Characteristics:**
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks

**Best For:**
- Any sequence-based protein task
- When no structure is available
- Transfer learning with limited data

**Variants:**
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)

**Use Cases:**
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)
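
A hedged loading sketch; in TorchDrug the class is registered as `models.EvolutionaryScaleModeling`, and the weights are assumed to be downloaded into `path` on first use:

```python
from torchdrug import models

model = models.EvolutionaryScaleModeling(
    path="~/protein-model-weights/", model="ESM-1b", readout="mean"
)
```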
### ProteinBERT
|
||||
|
||||
**Type:** Masked language model for proteins
|
||||
|
||||
**Characteristics:**
|
||||
- BERT-style pre-training
|
||||
- Masked amino acid prediction
|
||||
- Bidirectional context
|
||||
- Good for sequence-based tasks
|
||||
|
||||
**Use Cases:**
|
||||
- Function annotation
|
||||
- Subcellular localization
|
||||
- Stability prediction
|
||||
|
||||
### ProteinCNN / ProteinResNet
|
||||
|
||||
**Type:** Convolutional networks for sequences
|
||||
|
||||
**Characteristics:**
|
||||
- 1D convolutions on sequences
|
||||
- Local pattern recognition
|
||||
- Faster than transformers
|
||||
- Good for motif detection
|
||||
|
||||
**Use Cases:**
|
||||
- Binding site prediction
|
||||
- Secondary structure prediction
|
||||
- Domain identification
|
||||
|
||||
### ProteinLSTM
|
||||
|
||||
**Type:** Recurrent network for sequences
|
||||
|
||||
**Characteristics:**
|
||||
- Bidirectional LSTM
|
||||
- Captures long-range dependencies
|
||||
- Sequential processing
|
||||
- Good baseline for sequence tasks
|
||||
|
||||
**Use Cases:**
|
||||
- Order prediction
|
||||
- Sequential annotation
|
||||
- Time-series protein data
|
||||
|
||||
## Knowledge Graph Models
|
||||
|
||||
### TransE (Translation Embedding)
|
||||
|
||||
**Type:** Translation-based embedding
|
||||
**Paper:** Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)
|
||||
|
||||
**Characteristics:**
|
||||
- h + r ≈ t (head + relation ≈ tail)
|
||||
- Simple and interpretable
|
||||
- Works well for 1-to-1 relations
|
||||
- Memory efficient
|
||||
|
||||
**Best For:**
|
||||
- Large knowledge graphs
|
||||
- Initial experiments
|
||||
- Interpretable embeddings
|
||||
|
||||
**Parameters:**
|
||||
- `num_entity`, `num_relation`: Graph size
|
||||
- `embedding_dim`: Embedding dimensions (typically 50-500)
|
||||
|
||||
### RotatE (Rotation Embedding)
|
||||
|
||||
**Type:** Rotation in complex space
|
||||
**Paper:** RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)
|
||||
|
||||
**Characteristics:**
|
||||
- Relations as rotations in complex space
|
||||
- Handles symmetric, antisymmetric, inverse, composition
|
||||
- State-of-the-art on many benchmarks
|
||||
|
||||
**Best For:**
|
||||
- Most knowledge graph tasks
|
||||
- Complex relation patterns
|
||||
- When accuracy is critical
|
||||
|
||||
**Parameters:**
|
||||
- `num_entity`, `num_relation`: Graph size
|
||||
- `embedding_dim`: Must be even (complex embeddings)
|
||||
- `max_score`: Score clipping value
|
||||
|
||||
### DistMult
|
||||
|
||||
**Type:** Bilinear model
|
||||
|
||||
**Characteristics:**
|
||||
- Symmetric relation modeling
|
||||
- Fast and efficient
|
||||
- Cannot model antisymmetric relations
|
||||
|
||||
**Best For:**
|
||||
- Symmetric relations (e.g., "similar to")
|
||||
- When speed is critical
|
||||
- Large-scale graphs
|
||||
|
||||
### ComplEx
|
||||
|
||||
**Type:** Complex-valued embeddings
|
||||
|
||||
**Characteristics:**
|
||||
- Handles asymmetric and symmetric relations
|
||||
- Better than DistMult for most graphs
|
||||
- Good balance of expressiveness and efficiency
|
||||
|
||||
**Best For:**
|
||||
- General knowledge graph completion
|
||||
- Mixed relation types
|
||||
- When RotatE is too complex
|
||||
|
||||
### SimplE
|
||||
|
||||
**Type:** Enhanced embedding model
|
||||
|
||||
**Characteristics:**
|
||||
- Two embeddings per entity (canonical + inverse)
|
||||
- Fully expressive
|
||||
- Slightly more parameters than ComplEx
|
||||
|
||||
**Best For:**
|
||||
- When full expressiveness needed
|
||||
- Inverse relations are important
|
||||
|
||||
## Generative Models
|
||||
|
||||
### GraphAutoregressiveFlow
|
||||
|
||||
**Type:** Normalizing flow for molecules
|
||||
|
||||
**Characteristics:**
|
||||
- Exact likelihood computation
|
||||
- Invertible transformations
|
||||
- Stable training (no adversarial)
|
||||
- Conditional generation support
|
||||
|
||||
**Best For:**
|
||||
- Molecular generation
|
||||
- Density estimation
|
||||
- Interpolation between molecules
|
||||
|
||||
**Parameters:**
|
||||
- `input_dim`: Atom features
|
||||
- `hidden_dims`: Coupling layers
|
||||
- `num_flow`: Number of flow transformations
|
||||
|
||||
**Use Cases:**
|
||||
- De novo drug design
|
||||
- Chemical space exploration
|
||||
- Property-targeted generation
|
||||
|
||||
## Pre-training Models

### InfoGraph

**Type:** Contrastive learning

**Characteristics:**
- Maximizes mutual information between graph-level and node-level representations
- Unsupervised pre-training
- Good starting point when labeled data is scarce

**Use Cases:**
- Pre-training molecular encoders
- Few-shot learning
- Transfer learning

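
A pre-training sketch following the pattern of TorchDrug's self-supervised tutorials; the wrapper task (`tasks.Unsupervised`), dataset choice, and hyperparameters are assumptions to adapt to your setup:

```python
import torch
from torchdrug import core, datasets, models, tasks

dataset = datasets.ZINC250k("~/molecule-datasets/")

# GIN encoder to be pre-trained without labels
gin = models.GIN(input_dim=dataset.node_feature_dim,
                 hidden_dims=[256, 256, 256, 256],
                 edge_input_dim=dataset.edge_feature_dim,
                 batch_norm=True, readout="mean")

# InfoGraph wraps the encoder with the mutual-information objective
model = models.InfoGraph(gin, separate_model=False)
task = tasks.Unsupervised(model)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, batch_size=256)
solver.train(num_epoch=100)
# Afterwards, reuse `gin` as the encoder of a supervised downstream task
```
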
### MultiviewContrast

**Type:** Multi-view contrastive learning for proteins

**Characteristics:**
- Contrasts different views of the same protein
- Geometric pre-training that uses 3D structure information
- Excellent starting point for structure-based protein models

**Use Cases:**
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios

## Model Selection Guide

### By Task Type

**Molecular Property Prediction:**
1. GIN (first choice)
2. GAT (when interpretability matters)
3. SchNet (when 3D coordinates are available)

**Protein Tasks:**
1. ESM (sequence only)
2. GearNet (when structure is available)
3. ProteinBERT (sequence, lighter than ESM)

**Knowledge Graphs:**
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)

**Molecular Generation:**
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)

**Retrosynthesis:**
1. GIN (synthon completion)
2. RGCN (center identification with bond types)

### By Dataset Size

**Small (< 1k):**
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization

**Medium (1k-100k):**
- GIN for molecules
- GAT for interpretability
- Standard training

**Large (> 100k):**
- Any model works
- Deeper architectures
- Can train from scratch

### By Computational Budget

**Low:**
- GCN (simplest)
- DistMult (knowledge graphs)
- ProteinLSTM

**Medium:**
- GIN
- GAT
- ComplEx

**High:**
- ESM (large variants)
- SchNet (3D)
- RotatE with high embedding dimensions

## Implementation Tips

1. **Start Simple**: Begin with a GCN or GIN baseline
2. **Use Pre-trained Models**: ESM for proteins, InfoGraph for molecules
3. **Tune Depth**: 3-5 layers are typically sufficient
4. **Batch Normalization**: Usually helps (except for KG embeddings)
5. **Residual Connections**: Important for deep networks
6. **Readout Function**: "mean" usually works well
7. **Edge Features**: Include them when available (bonds, distances)
8. **Regularization**: Dropout, weight decay, early stopping

352
skills/torchdrug/references/molecular_generation.md
Normal file
@@ -0,0 +1,352 @@

# Molecular Generation

## Overview

Molecular generation involves creating novel molecular structures with desired properties. TorchDrug supports both unconditional generation (exploring chemical space) and conditional generation (optimizing for specific properties).

## Task Types

### AutoregressiveGeneration

Generates molecules step by step, sequentially adding atoms and bonds. This approach enables fine-grained control and property optimization during generation.

**Key Features:**
- Sequential atom-by-atom and bond-by-bond construction
- Supports property optimization during generation
- Can incorporate chemical validity constraints
- Enables multi-objective optimization

**Generation Strategies:**
1. **Beam Search**: Keep the top-k candidates at each step
2. **Sampling**: Probabilistic selection for diversity
3. **Greedy**: Always select the highest-probability action

**Property Optimization:**
- Reward shaping based on desired properties
- Real-time constraint satisfaction
- Multi-objective balancing (e.g., potency + drug-likeness)

### GCPNGeneration (Graph Convolutional Policy Network)

Uses reinforcement learning to generate molecules optimized for specific properties.

**Components:**
1. **Policy Network**: Decides which action to take (add atom, add bond)
2. **Reward Function**: Evaluates generated molecule quality
3. **Training**: Reinforcement learning with policy gradients

**Advantages:**
- Direct optimization of non-differentiable objectives
- Can incorporate complex domain knowledge
- Balances exploration and exploitation

**Applications:**
- Drug design with specific targets
- Material discovery with property constraints
- Multi-objective molecular optimization

## Generative Models

### GraphAutoregressiveFlow

Normalizing flow model for molecular generation with exact likelihood computation.

**Architecture:**
- Coupling layers transform a simple base distribution into the complex molecular distribution
- Invertible transformations enable density estimation
- Supports conditional generation

**Key Features:**
- Exact likelihood computation (vs. the approximate likelihood of VAEs)
- Stable training (vs. the adversarial training of GANs)
- Efficient sampling through invertible transformations
- Can generate molecules with specified properties

**Training:**
- Maximum likelihood on a molecule dataset
- Optional property prediction head for conditional generation
- Typically trained on ZINC or QM9, as in the sketch below

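
A construction sketch in the spirit of the TorchDrug generation tutorial: GraphAF pairs an RGCN backbone with node and edge flows inside an autoregressive generation task. Exact keyword names (`atom_feature` vs. `node_feature`, `num_layer`) vary across TorchDrug versions, and all sizes below are illustrative assumptions:

```python
import torch
from torchdrug import core, datasets, models, tasks

dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                            atom_feature="symbol")

# Relational GNN over atoms, with one relation per bond type
model = models.RGCN(input_dim=dataset.num_atom_type,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256], batch_norm=True)

# One flow generates atoms, the other bonds (+1 type for "no edge")
node_flow = models.GraphAF(model, dataset.num_atom_type, num_layer=12)
edge_flow = models.GraphAF(model, dataset.num_bond_type + 1,
                           use_edge=True, num_layer=12)

task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                      max_node=38, max_edge_unroll=12,
                                      criterion="nll")

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, batch_size=128)
solver.train(num_epoch=10)

results = task.generate(num_sample=32)  # sample molecules from the flow
print(results.to_smiles())
```
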
**Use Cases:**
- Generating diverse drug-like molecules
- Interpolation between known molecules
- Density estimation for molecular space

## Generation Workflows

### Unconditional Generation

Generate diverse molecules without specific property targets.

**Workflow:**
1. Train a generative model on a molecule dataset (e.g., ZINC250k)
2. Sample from the learned distribution
3. Post-process for validity and uniqueness
4. Evaluate diversity metrics

**Evaluation Metrics:**
- **Validity**: Percentage of chemically valid molecules
- **Uniqueness**: Percentage of unique molecules among the valid ones
- **Novelty**: Percentage of valid molecules not present in the training set
- **Diversity**: Internal diversity based on fingerprint similarity

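
The count-based metrics above are easy to compute from generated SMILES with RDKit; this helper is a minimal sketch (inputs are assumed to be lists of SMILES strings, with the training set already canonicalized):

```python
from rdkit import Chem

def generation_metrics(generated, training_smiles):
    """Compute validity, uniqueness, and novelty for generated SMILES."""
    # A string counts as valid if RDKit can parse and sanitize it;
    # canonicalize so that duplicates compare equal
    valid = [Chem.MolToSmiles(mol) for s in generated
             if (mol := Chem.MolFromSmiles(s)) is not None]
    unique = set(valid)
    novel = unique - set(training_smiles)
    return {
        "validity": len(valid) / max(len(generated), 1),
        "uniqueness": len(unique) / max(len(valid), 1),
        "novelty": len(novel) / max(len(unique), 1),
    }

print(generation_metrics(["CCO", "CCO", "C1CC1", "x"], ["CCO"]))
```
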
### Conditional Generation

Generate molecules optimized for specific properties.

**Property Targets:**
- **Drug-likeness**: LogP, QED, Lipinski's rule of five
- **Synthesizability**: SA score, retrosynthesis feasibility
- **Bioactivity**: Predicted IC50, binding affinity
- **ADMET**: Absorption, distribution, metabolism, excretion, toxicity
- **Multi-objective**: Balance multiple properties simultaneously

**Workflow:**
1. Define a reward function combining the property objectives
2. Train GCPN or condition a flow model on properties
3. Generate molecules within the desired property ranges
4. Validate generated molecules (in silico → wet lab)

### Scaffold-Based Generation

Generate molecules around a fixed scaffold or core structure.

**Applications:**
- Lead optimization that keeps the core pharmacophore
- R-group enumeration for SAR studies
- Fragment linking and growing

**Approaches:**
- Mask the scaffold during training
- Condition generation on the scaffold
- Post-generation grafting

### Fragment-Based Generation

Build molecules from validated fragments.

**Benefits:**
- Ensures drug-like substructures
- Reduces the search space
- Incorporates medicinal chemistry knowledge

**Methods:**
- Fragment library as building blocks
- Vocabulary-based generation
- Fragment linking with learned linkers

## Property Optimization Strategies

### Single-Objective Optimization

Maximize or minimize a single property (e.g., binding affinity).

**Approach:**
- Define a scalar reward function
- Use GCPN with RL training
- Generate and rank candidates

**Challenges:**
- May sacrifice other important properties
- Risk of adversarial examples (valid but non-drug-like molecules)
- Needs constraints on drug-likeness

### Multi-Objective Optimization

Balance multiple competing objectives (e.g., potency, selectivity, synthesizability).

**Weighting Approaches:**
- **Linear combination**: w1×prop1 + w2×prop2 + ...
- **Pareto optimization**: Find non-dominated solutions
- **Constraint satisfaction**: Threshold on secondary objectives

**Example Objectives:**
- High binding affinity (target)
- Low binding affinity (off-targets)
- High synthesizability (SA score)
- Drug-like properties (QED)
- Low molecular weight

**Workflow:**

```python
from torchdrug import tasks

# Multi-objective reward; predict_binding, calculate_qed, and sa_score
# are placeholders for your own property-scoring functions
def reward_function(mol):
    affinity_score = predict_binding(mol)   # higher is better
    druglikeness = calculate_qed(mol)       # QED in [0, 1]
    synthesizability = sa_score(mol)        # lower SA score = easier to make

    # Weighted combination; the weights are application-specific
    reward = 0.5 * affinity_score + 0.3 * druglikeness + 0.2 * (1 - synthesizability)
    return reward

# GCPN task with a custom reward. NOTE: the exact reward hook varies by
# TorchDrug version (built-in rewards are usually selected via task
# arguments), so treat this wiring as schematic
task = tasks.GCPNGeneration(
    model,
    reward_function=reward_function,
    criterion="ppo"  # proximal policy optimization
)
```

### Constraint-Based Generation

Generate molecules that satisfy hard constraints.

**Common Constraints:**
- Molecular weight range
- LogP range
- Number of rotatable bonds
- Ring count limits
- Substructure inclusion/exclusion
- Synthetic accessibility threshold

**Implementation:**
- Validity checking during generation
- Early stopping for invalid molecules
- Penalty terms in the reward function

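
A post-hoc filter over generated SMILES is the simplest way to enforce such constraints; this sketch uses standard RDKit descriptors, with the thresholds as illustrative assumptions:

```python
from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors

def passes_constraints(smiles,
                       mw_range=(200.0, 500.0),
                       logp_range=(-1.0, 5.0),
                       max_rotatable=10):
    """Return True if the molecule satisfies all hard constraints."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:                      # invalid molecules fail immediately
        return False
    if not mw_range[0] <= Descriptors.MolWt(mol) <= mw_range[1]:
        return False
    if not logp_range[0] <= Crippen.MolLogP(mol) <= logp_range[1]:
        return False
    if Descriptors.NumRotatableBonds(mol) > max_rotatable:
        return False
    return True

candidates = ["CCO", "CC(=O)Nc1ccc(O)cc1", "c1ccccc1C(=O)NC2CCN(CC2)C"]
print([s for s in candidates if passes_constraints(s)])
```
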
## Training Considerations

### Dataset Selection

**ZINC (Drug-like compounds):**
- ZINC250k: 250,000 compounds
- ZINC2M: 2 million compounds
- Pre-filtered for drug-likeness
- Good for drug discovery applications

**QM9 (Small organic molecules):**
- 133,885 molecules
- Includes quantum properties
- Good for property prediction models

**ChEMBL (Bioactive molecules):**
- Millions of bioactive compounds
- Activity data available
- Enables target-specific generation

**Custom Datasets:**
- Focus on a specific chemical space
- Include expert knowledge
- Domain-specific constraints

### Data Augmentation

**SMILES Augmentation:**
- Generate multiple SMILES strings for the same molecule (see the sketch below)
- Helps the model generalize across equivalent representations
- Improves robustness

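
With RDKit, randomized SMILES for augmentation can be produced by disabling canonical output; a minimal sketch:

```python
from rdkit import Chem

def randomized_smiles(smiles, num_variants=5):
    """Enumerate random, equivalent SMILES strings for one molecule."""
    mol = Chem.MolFromSmiles(smiles)
    # doRandom=True walks the molecular graph in a random atom order
    return [Chem.MolToSmiles(mol, canonical=False, doRandom=True)
            for _ in range(num_variants)]

print(randomized_smiles("c1ccccc1O"))  # several spellings of phenol
```
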
**Graph Augmentation:**
- Random node/edge masking
- Subgraph sampling
- Motif substitution

### Model Architecture Choices

**For Small Molecules (< 30 atoms):**
- Simpler architectures are sufficient
- Faster training and generation
- GCN or GIN backbone

**For Drug-like Molecules:**
- Deeper architectures (4-6 layers)
- Attention mechanisms help
- Consider molecular fingerprints as auxiliary features

**For Macrocycles/Polymers:**
- Handle larger graphs
- Ring closure mechanisms are important
- Long-range dependencies matter

## Validation and Filtering

### In Silico Validation

**Chemical Validity:**
- Valence rules
- Aromaticity rules
- Charge neutrality
- Stable substructures

**Drug-likeness Filters:**
- Lipinski's rule of five
- Veber's rules
- PAINS filters (pan-assay interference compounds)
- BRENK filters (toxic/reactive substructures)

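
RDKit ships both the PAINS and Brenk catalogs, so these filters can be applied directly to generated molecules; a minimal sketch:

```python
from rdkit import Chem
from rdkit.Chem import FilterCatalog

params = FilterCatalog.FilterCatalogParams()
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS)
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.BRENK)
catalog = FilterCatalog.FilterCatalog(params)

def flagged(smiles):
    """Return True if the molecule matches a PAINS or Brenk alert."""
    mol = Chem.MolFromSmiles(smiles)
    return mol is None or catalog.HasMatch(mol)

# Rhodanines are a classic PAINS class, so this should be flagged
print(flagged("O=C1CSC(=S)N1"))
```
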
**Synthesizability:**
- SA score (synthetic accessibility)
- Retrosynthesis prediction
- Commercial availability of precursors

**Property Prediction:**
- ADMET properties
- Toxicity prediction
- Off-target binding
- Metabolic stability

### Ranking and Selection

**Criteria:**
1. Predicted target affinity
2. Drug-likeness score
3. Synthesizability
4. Novelty (dissimilarity to known actives)
5. Diversity (within the generated set)
6. Predicted ADMET properties

**Selection Strategies:**
- Pareto frontier selection
- Weighted scoring
- Clustering and representative selection
- Active learning for wet-lab validation

## Best Practices

1. **Start Simple**: Begin with unconditional generation, then add constraints
2. **Validate Chemistry**: Always check for valid molecules and drug-likeness
3. **Diverse Training Data**: Use large, diverse datasets for better generalization
4. **Multi-Objective**: Consider multiple properties from the start
5. **Iterative Refinement**: Generate → validate → retrain with feedback
6. **Domain Expert Review**: Consult medicinal chemists before synthesis
7. **Benchmark**: Compare against known actives and random samples
8. **Synthesizability**: Prioritize molecules that can actually be made
9. **Explainability**: Understand why the model generates certain structures
10. **Wet Lab Validation**: Ultimately validate promising candidates experimentally

## Common Applications

### Drug Discovery
- Lead generation for novel targets
- Lead optimization around active scaffolds
- Bioisostere replacement
- Fragment elaboration

### Materials Science
- Polymer design with target properties
- Catalyst discovery
- Energy storage materials
- Photovoltaic materials

### Chemical Biology
- Probe molecule design
- Degrader (PROTAC) design
- Molecular glue discovery

## Integration with Other Tools

**Docking:**
- Generate molecules → dock to target → retrain with docking scores

**Retrosynthesis:**
- Filter generated molecules by synthetic accessibility
- Plan synthesis routes for top candidates

**Property Prediction:**
- Use trained property prediction models as reward functions
- Multi-task learning with generation and prediction

**Active Learning:**
- Generate candidates → predict properties → synthesize the best → retrain

169
skills/torchdrug/references/molecular_property_prediction.md
Normal file
@@ -0,0 +1,169 @@

# Molecular Property Prediction

## Overview

Molecular property prediction involves predicting chemical, physical, or biological properties of molecules from their structure. TorchDrug provides comprehensive support for both classification and regression tasks on molecular graphs.

## Available Datasets

### Drug Discovery Datasets

**Classification Tasks:**
- **BACE** (1,513 molecules): Binary classification for β-secretase inhibition
- **BBBP** (2,039 molecules): Blood-brain barrier penetration prediction
- **HIV** (41,127 molecules): Ability to inhibit HIV replication
- **Tox21** (7,831 molecules): Toxicity prediction across 12 targets
- **ToxCast** (8,576 molecules): Toxicology screening
- **ClinTox** (1,478 molecules): Clinical trial toxicity
- **SIDER** (1,427 molecules): Drug side effects (27 system organ classes)
- **MUV** (93,087 molecules): Maximum unbiased validation for virtual screening

**Regression Tasks:**
- **ESOL** (1,128 molecules): Water solubility prediction
- **FreeSolv** (642 molecules): Hydration free energy
- **Lipophilicity** (4,200 molecules): Octanol/water distribution coefficient
- **SAMPL** (643 molecules): Solvation free energies

### Large-Scale Datasets

- **QM7** (7,165 molecules): Quantum mechanical properties
- **QM8** (21,786 molecules): Electronic spectra and excited state properties
- **QM9** (133,885 molecules): Geometric, energetic, electronic, and thermodynamic properties
- **PCQM4M** (3,803,453 molecules): Large-scale quantum chemistry dataset
- **ZINC250k/2M** (250k/2M molecules): Drug-like compounds for generative modeling

## Task Types

### PropertyPrediction

Standard task for graph-level property prediction supporting both classification and regression.

**Key Parameters:**
- `model`: Graph representation model (GNN)
- `task`: Target property name(s) to predict (e.g., the dataset's task list)
- `criterion`: Loss function ("mse", "bce", "ce")
- `metric`: Evaluation metrics ("mae", "rmse", "auroc", "auprc")
- `num_mlp_layer`: Number of MLP layers in the readout head

**Example Workflow:**

```python
import torch
from torchdrug import core, models, tasks, datasets

# Load dataset
dataset = datasets.BBBP("~/molecule-datasets/")

# Define model
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256, 256],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")

# Define task
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce",
                                metric=("auprc", "auroc"))
```

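
Training then follows the usual Engine pattern, reusing the imports from the block above; the 80/10/10 random split and the hyperparameters are illustrative (prefer scaffold splits for realistic evaluation, as discussed below):

```python
# Random 80/10/10 split; see "Data Splitting Strategies" for scaffold splits
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=256)  # add gpus=[0] to train on GPU
solver.train(num_epoch=100)
solver.evaluate("valid")
```
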
### MultipleBinaryClassification

Specialized task for multi-label scenarios where each molecule can have multiple binary labels (e.g., Tox21, SIDER).

**Key Features:**
- Handles missing labels gracefully
- Computes metrics per label and averaged across labels
- Supports weighted loss for imbalanced datasets

## Model Selection

### Recommended Models by Task

**Small Datasets (< 1,000 molecules):**
- GIN (Graph Isomorphism Network)
- SchNet (for 3D structures)

**Medium Datasets (1k-100k molecules):**
- GCN, GAT, or GIN
- NFP (Neural Fingerprint)
- MPNN (Message Passing Neural Network)

**Large Datasets (> 100k molecules):**
- Pre-trained models with fine-tuning
- InfoGraph or MultiviewContrast for self-supervised pre-training
- GIN with deeper architectures

**3D Structure Available:**
- SchNet (continuous-filter convolutions)
- GearNet (geometry-aware relational graph network)

## Feature Engineering

### Node Features

TorchDrug automatically extracts atom features:
- Atom type
- Formal charge
- Explicit/implicit hydrogens
- Hybridization
- Aromaticity
- Chirality

### Edge Features

Bond features include:
- Bond type (single, double, triple, aromatic)
- Stereochemistry
- Conjugation
- Ring membership

### Custom Features

Augment graphs with additional structure using transforms:

```python
from torchdrug import datasets, transforms

# Add a virtual node connected to every atom
transform = transforms.VirtualNode()
dataset = datasets.BBBP("~/molecule-datasets/",
                        transform=transform)
```

## Training Workflow

### Basic Pipeline

1. **Load Dataset**: Choose an appropriate dataset
2. **Split Data**: Use scaffold split for drug discovery
3. **Define Model**: Select a GNN architecture
4. **Create Task**: Configure loss and metrics
5. **Setup Optimizer**: Adam typically works well
6. **Train**: Use the `core.Engine` solver, PyTorch Lightning, or a custom loop

### Data Splitting Strategies

- **Random Split**: Standard train/val/test split
- **Scaffold Split**: Group molecules by Bemis-Murcko scaffolds (recommended for drug discovery)
- **Stratified Split**: Maintain label distribution across splits

### Best Practices

- Use scaffold splitting for realistic drug discovery evaluation
- Apply data augmentation (virtual nodes, edges) for small datasets
- Monitor multiple metrics (AUROC, AUPRC for classification; MAE, RMSE for regression)
- Use early stopping based on validation performance
- Consider ensemble methods for critical applications
- Pre-train on large datasets before fine-tuning on small datasets

## Common Issues and Solutions

**Issue: Poor performance on imbalanced datasets**
- Solution: Use weighted loss, focal loss, or over/under-sampling

**Issue: Overfitting on small datasets**
- Solution: Increase regularization, use simpler models, apply data augmentation, or pre-train on larger datasets

**Issue: Large memory consumption**
- Solution: Reduce batch size, use gradient accumulation, or implement graph sampling

**Issue: Slow training**
- Solution: Use GPU acceleration, optimize data loading with multiple workers, or use mixed precision training

272
skills/torchdrug/references/protein_modeling.md
Normal file
@@ -0,0 +1,272 @@

# Protein Modeling

## Overview

TorchDrug provides extensive support for protein-related tasks including sequence analysis, structure prediction, property prediction, and protein-protein interactions. Proteins are represented as graphs where nodes are amino acid residues and edges represent spatial or sequential relationships.

## Available Datasets

### Protein Function Prediction

**Enzyme Function:**
- **EnzymeCommission** (17,562 proteins): EC number classification (7 levels)
- **BetaLactamase** (5,864 sequences): Enzyme activity prediction

**Protein Characteristics:**
- **Fluorescence** (54,025 sequences): GFP fluorescence intensity
- **Stability** (53,614 sequences): Thermostability prediction
- **Solubility** (62,478 sequences): Protein solubility classification
- **BinaryLocalization** (22,168 proteins): Subcellular localization (membrane vs. soluble)
- **SubcellularLocalization** (8,943 proteins): 10-class localization prediction

**Gene Ontology:**
- **GeneOntology** (46,796 proteins): GO term prediction across biological process, molecular function, and cellular component

### Protein Structure Prediction

- **Fold** (16,712 proteins): Structural fold classification (1,195 classes)
- **SecondaryStructure** (8,678 proteins): 3-state or 8-state secondary structure prediction
- **ContactPrediction** via ProteinNet: Residue-residue contact maps

### Protein Interaction

**Protein-Protein Interactions:**
- **HumanPPI** (1,412 proteins, 6,584 interactions): Human protein interaction network
- **YeastPPI** (2,018 proteins, 6,451 interactions): Yeast protein interaction network
- **PPIAffinity** (2,156 protein pairs): Binding affinity measurements

**Protein-Ligand Binding:**
- **BindingDB** (~1.5M entries): Comprehensive binding affinity database
- **PDBBind** (20,000+ complexes): 3D structure-based binding data
  - Refined set: High-quality crystal structures
  - Core set: Diverse benchmark set

### Large-Scale Protein Databases

- **AlphaFoldDB**: Access to 200M+ predicted protein structures
- **ProteinNet**: Standardized dataset for structure prediction

## Task Types

### NodePropertyPrediction

Predict properties at the residue (node) level, such as secondary structure or contact maps.

**Use Cases:**
- Secondary structure prediction (helix, sheet, coil)
- Residue-level disorder prediction
- Post-translational modification sites
- Binding site prediction

### PropertyPrediction

Predict protein-level properties like function, stability, or localization.

**Use Cases:**
- Enzyme function classification
- Subcellular localization
- Protein stability prediction
- Gene Ontology term prediction

### InteractionPrediction

Predict interactions between protein pairs or protein-ligand pairs.

**Key Features:**
- Handles both sequence and structure inputs
- Supports symmetric (PPI) and asymmetric (protein-ligand) interactions
- Multiple negative sampling strategies

### ContactPrediction

Specialized task for predicting spatial proximity between residues in folded structures.

**Applications:**
- Structure prediction from sequence
- Protein folding pathway analysis
- Validation of predicted structures

## Protein Representation Models

### Sequence-Based Models

**ESM (Evolutionary Scale Modeling):**
- Transformer model pre-trained on 250M sequences
- State-of-the-art for sequence-only tasks
- Available in multiple sizes (ESM-1b, ESM-2)
- Captures evolutionary and structural information

**ProteinBERT:**
- BERT-style masked language model
- Pre-trained on UniProt sequences
- Good for transfer learning

**ProteinLSTM:**
- Bidirectional LSTM for sequence encoding
- Lightweight and fast
- Good baseline for sequence tasks

**ProteinCNN / ProteinResNet:**
- Convolutional architectures
- Capture local sequence patterns
- Faster than transformer models

### Structure-Based Models

**GearNet (Geometry-Aware Relational Graph Network):**
- Incorporates 3D geometric information
- Edge types based on sequential distance, spatial radius, and K-nearest neighbors
- State-of-the-art for structure-based tasks
- Supports both backbone and full-atom representations

**GCN/GAT/GIN on Protein Graphs:**
- Standard GNN architectures adapted for proteins
- Flexible edge definitions (sequence, spatial, contact)

**SchNet:**
- Continuous-filter convolutions
- Handles 3D coordinates directly
- Good for structure prediction and protein-ligand binding

### Feature-Based Models

**Statistical Features:**
- Amino acid composition
- Sequence length statistics
- Motif counts

**Physicochemical Features:**
- Hydrophobicity scales
- Charge properties
- Secondary structure propensity
- Molecular weight, pI

## Protein Graph Construction

### Edge Types

**Sequential Edges:**
- Connect residues adjacent in the sequence
- Capture primary structure

**Spatial Edges:**
- K-nearest neighbors in 3D space
- Radius cutoff (e.g., Cα atoms within 10 Å)
- Capture tertiary structure

**Contact Edges:**
- Based on heavy-atom distances
- Typically < 8 Å threshold

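
In TorchDrug, these edge types are typically assembled with the geometry graph-construction layers used in the GearNet examples; the cutoffs below (sequential distance 2, 10 Å radius, k = 10) are illustrative defaults:

```python
from torchdrug import layers
from torchdrug.layers import geometry

# Build a residue-level graph from full-atom protein structures:
# alpha carbons as nodes; sequential, radius, and KNN edges combined
graph_construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],
    edge_layers=[
        geometry.SequentialEdge(max_distance=2),   # neighbors along the chain
        geometry.SpatialEdge(radius=10.0, min_distance=5),
        geometry.KNNEdge(k=10, min_distance=5),
    ],
    edge_feature="gearnet",
)

# Usage: given a batched protein graph `protein`,
# graph = graph_construction(protein)
```
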
### Node Features

**Residue Identity:**
- One-hot encoding of the 20 amino acids
- Learned embeddings

**Position Information:**
- 3D coordinates (Cα, N, C, O)
- Backbone angles (phi, psi, omega)
- Relative spatial position encodings

**Physicochemical Properties:**
- Hydrophobicity
- Charge
- Size
- Secondary structure

## Training Workflows

### Pre-training Strategies

**Self-Supervised Pre-training:**
- Masked residue prediction (like BERT)
- Distance prediction between residues
- Angle prediction (phi, psi, omega)
- Dihedral angle prediction
- Contact map prediction

**Pre-trained Model Usage:**

```python
from torchdrug import models, tasks

# Load pre-trained ESM weights (path points to the downloaded checkpoint)
model = models.ESM(path="esm1b_t33_650M_UR50S.pt")

# Fine-tune on a downstream regression task
task = tasks.PropertyPrediction(
    model, task=["stability"],
    criterion="mse", metric=["mae", "rmse"]
)
```

### Multi-Task Learning

Train on multiple related tasks simultaneously:
- Joint prediction of function, localization, and stability
- Improves generalization and data efficiency
- Shares representations across tasks

### Best Practices

**For Sequence-Only Tasks:**
1. Start with pre-trained ESM or ProteinBERT
2. Fine-tune with a small learning rate (1e-5 to 1e-4)
3. Use frozen embeddings for small datasets
4. Apply dropout for regularization

**For Structure-Based Tasks:**
1. Use GearNet with multiple edge types
2. Include geometric features (angles, dihedrals)
3. Pre-train on large structure databases
4. Use data augmentation (rotations, crops)

**For Small Datasets:**
1. Transfer learning from pre-trained models
2. Multi-task learning with related tasks
3. Data augmentation (sequence mutations, structure perturbations)
4. Strong regularization (dropout, weight decay)

## Common Use Cases

### Enzyme Engineering
- Predict enzyme activity from sequence
- Design mutations to improve stability
- Screen for desired catalytic properties

### Antibody Design
- Predict binding affinity
- Optimize antibody sequences
- Predict immunogenicity

### Drug Target Identification
- Predict protein function
- Identify druggable sites
- Analyze protein-ligand interactions

### Protein Structure Prediction
- Predict secondary structure from sequence
- Generate contact maps for tertiary structure
- Refine AlphaFold predictions

## Integration with Other Tools

### AlphaFold Integration

Load AlphaFold-predicted structures:

```python
from torchdrug import data

# Load an AlphaFold-predicted structure from a PDB file
protein = data.Protein.from_pdb("alphafold_structure.pdb")

# The Protein object plugs into TorchDrug workflows like any other graph,
# e.g., inspect it before building residue-level edges
print(protein.num_residue)
```

### ESMFold Integration

Use ESMFold for structure prediction, then analyze the predicted structures with TorchDrug models.

### Rosetta/PyRosetta

Generate structures with Rosetta, then import them into TorchDrug for analysis.

436
skills/torchdrug/references/retrosynthesis.md
Normal file
@@ -0,0 +1,436 @@

# Retrosynthesis

## Overview

Retrosynthesis is the process of planning synthetic routes from target molecules back to commercially available starting materials. TorchDrug provides tools for learning-based retrosynthesis prediction, breaking the complex task down into manageable subtasks.

## Available Datasets

### USPTO-50K

The standard benchmark dataset for retrosynthesis, derived from US patent literature.

**Statistics:**
- 50,017 reaction examples
- Single-step reactions
- Filtered for quality and canonicalization
- Contains atom mapping for reaction center identification

**Reaction Types:**
- Diverse organic reactions
- Drug-like transformations
- Well balanced across common reaction classes

**Data Splits:**
- Training: ~40k reactions
- Validation: ~5k reactions
- Test: ~5k reactions

**Format:**
- Product → reactants
- SMILES representation
- Atom-mapped reactions for training

## Task Types

TorchDrug decomposes retrosynthesis into a multi-step pipeline:

### 1. CenterIdentification

Identifies the reaction center: the bonds that were formed or broken in the forward reaction.

**Input:** Product molecule
**Output:** For each bond, the probability that it belongs to the reaction center

**Purpose:**
- Locate where the chemistry happened
- Guide subsequent synthon generation
- Dramatically reduce the search space

**Model Architecture:**
- Graph neural network on the product molecule
- Edge-level classification
- Attention mechanisms to highlight reactive regions

**Evaluation Metrics:**
- **Top-K Accuracy**: Correct reaction center appears in the top K predictions
- **Bond-level F1**: Precision and recall for bond predictions

### 2. SynthonCompletion

Given the product and the identified reaction center, predict the reactant structures completed from the synthons.

**Input:**
- Product molecule
- Identified reaction center (broken/formed bonds)

**Output:**
- Predicted reactant molecules

**Process:**
1. Break bonds at the reaction center to obtain synthons
2. Modify atom environments (valence, charges)
3. Determine leaving groups and protecting groups
4. Generate complete reactant structures

**Challenges:**
- Multiple valid reactant sets
- Stereospecificity
- Atom environment changes (hybridization, charge)
- Leaving group selection

**Evaluation:**
- **Exact Match**: Generated reactants exactly match the ground truth
- **Top-K Accuracy**: Correct reactants appear in the top K predictions
- **Chemical Validity**: Generated molecules are valid

### 3. Retrosynthesis (End-to-End)

Combines center identification and synthon completion into a unified pipeline.

**Input:** Target product molecule
**Output:** Ranked list of reactant sets (synthesis pathways)

**Workflow:**
1. Identify the top-K reaction centers
2. For each center, generate reactant candidates
3. Rank combinations by model confidence
4. Filter for commercial availability and feasibility

**Advantages:**
- Single pipeline to train and deploy
- Joint optimization of the subtasks
- Accounts for error propagation from center identification

## Training Workflows

### Basic Pipeline

The sketch below follows the TorchDrug retrosynthesis tutorial; the `feature` arguments and beam sizes are typical settings rather than requirements, and keyword names (e.g., `node_feature` vs. `atom_feature`) vary slightly across TorchDrug versions:

```python
from torchdrug import core, datasets, models, tasks

# Load USPTO-50k twice: once with reaction-level features for center
# identification, once decomposed into synthons for completion
reaction_dataset = datasets.USPTO50k("~/retro-datasets/",
                                     node_feature="center_identification",
                                     kekulize=True)
synthon_dataset = datasets.USPTO50k("~/retro-datasets/", as_synthon=True,
                                    node_feature="synthon_completion",
                                    kekulize=True)

# Center identification: relational GNN over bond types
reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                             num_relation=reaction_dataset.num_bond_type,
                             hidden_dims=[256, 256, 256])
reaction_task = tasks.CenterIdentification(reaction_model,
                                           feature=("graph", "atom", "bond"))

# Synthon completion with a GIN backbone
synthon_model = models.GIN(input_dim=synthon_dataset.node_feature_dim,
                           hidden_dims=[256, 256, 256])
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

# End-to-end retrosynthesis: beam over centers, then over synthons
task = tasks.Retrosynthesis(reaction_task, synthon_task,
                            center_topk=2, num_synthon_beam=5,
                            max_prediction=10)
```

### Transfer Learning

Pre-train on large reaction datasets (e.g., USPTO-full with 1M+ reactions), then fine-tune on specific reaction classes.

**Benefits:**
- Better generalization to rare reaction types
- Improved performance on small datasets
- Learns general reaction patterns

### Multi-Task Learning

Train jointly on:
- Forward reaction prediction
- Retrosynthesis
- Reaction type classification
- Yield prediction

**Advantages:**
- Shared representations of chemistry
- Better sample efficiency
- Improved robustness

## Model Architectures

### Graph Neural Networks

**RGCN (Relational Graph Convolutional Network):**
- Handles multiple bond types (single, double, triple, aromatic)
- Edge-type-specific transformations
- Good for reaction center identification

**GIN (Graph Isomorphism Network):**
- Powerful message passing
- Captures structural patterns
- Works well for synthon completion

**GAT (Graph Attention Network):**
- Attention weights highlight important atoms/bonds
- Interpretable reaction center predictions
- Flexible across reaction types

### Sequence-Based Models

**Transformer Models:**
- SMILES-to-SMILES translation
- Capture long-range dependencies
- Require large datasets

**LSTM/GRU:**
- Sequence generation for reactants
- Autoregressive decoding
- Good for small molecules

### Hybrid Approaches

Combine graph and sequence representations:
- Graph encoder for products
- Sequence decoder for reactants
- Best of both representations

## Reaction Chemistry Considerations

### Reaction Classes

**Common Transformations:**
- C-C bond formation (coupling, addition)
- Functional group interconversions (oxidation, reduction)
- Heterocycle synthesis (cyclizations)
- Protection/deprotection
- Aromatic substitutions

**Rare Reactions:**
- Novel coupling methods
- Complex rearrangements
- Multi-component reactions

### Selectivity Issues

**Regioselectivity:**
- Which position on the molecule reacts
- Requires understanding of electronics and sterics

**Stereoselectivity:**
- Control of stereochemistry
- Diastereoselectivity and enantioselectivity
- Critical for drug synthesis

**Chemoselectivity:**
- Which functional group reacts
- Often requires protecting-group strategies

### Reaction Conditions

While TorchDrug focuses on reaction connectivity, also consider:
- Temperature and pressure
- Catalysts and reagents
- Solvents
- Reaction time
- Work-up and purification

## Multi-Step Synthesis Planning

### Single-Step Retrosynthesis

Predict the immediate precursors of a target molecule.

**Use Cases:**
- Late-stage transformations
- Simple molecules (1-2 steps from commercial materials)
- Initial route scouting

### Multi-Step Planning

Recursively apply retrosynthesis to each predicted reactant until reaching commercial building blocks.

**Tree Search Strategies:**

**Breadth-First Search:**
- Explores all routes to the same depth
- Finds the shortest routes
- Memory intensive

**Depth-First Search:**
- Follows each route to completion
- Memory efficient
- May miss optimal routes

**Monte Carlo Tree Search (MCTS):**
- Balances exploration and exploitation
- Guided by model confidence
- State-of-the-art for multi-step planning

**A\* Search:**
- Heuristic-guided search
- Optimizes for cost, complexity, or feasibility
- Efficient for finding the best routes

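
To make the recursion concrete, here is a minimal depth-first planner over a mock single-step model; `predict_reactants` and the `COMMERCIAL` stock set are stand-ins (assumptions) for a trained retrosynthesis model and a vendor catalog:

```python
# Mock single-step model: product SMILES -> list of precursor sets
MOCK_MODEL = {
    "CC(=O)Nc1ccc(O)cc1": [["CC(=O)O", "Nc1ccc(O)cc1"]],  # paracetamol amidation
    "Nc1ccc(O)cc1": [["O=[N+]([O-])c1ccc(O)cc1"]],        # nitro reduction
}
COMMERCIAL = {"CC(=O)O", "O=[N+]([O-])c1ccc(O)cc1"}

def predict_reactants(smiles):
    return MOCK_MODEL.get(smiles, [])

def plan_routes(target, max_depth=6):
    """Depth-first enumeration of routes ending in commercial compounds."""
    if target in COMMERCIAL:
        return [[]]                      # already buyable: empty route
    if max_depth == 0:
        return []                        # depth budget exhausted
    routes = []
    for reactants in predict_reactants(target):
        # Every reactant must itself be solvable
        sub_plans = [plan_routes(r, max_depth - 1) for r in reactants]
        if all(sub_plans):
            step = (target, tuple(reactants))
            # Combine the first sub-route of each reactant for brevity
            merged = [s for plans in sub_plans for s in plans[0]]
            routes.append([step] + merged)
    return routes

for route in plan_routes("CC(=O)Nc1ccc(O)cc1"):
    print(route)
```
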
### Route Scoring

Rank synthetic routes by:
1. **Number of Steps**: Fewer is better (efficiency)
2. **Convergent vs. Linear**: Convergent routes are preferred
3. **Commercial Availability**: How many steps to buyable compounds
4. **Reaction Feasibility**: Likelihood that each step works
5. **Overall Yield**: Estimated end-to-end yield
6. **Cost**: Reagents, labor, equipment
7. **Green Chemistry**: Environmental impact, safety

### Stopping Criteria

Stop the retrosynthetic recursion upon reaching:
- **Commercial Compounds**: Available from vendors (e.g., Sigma-Aldrich, Enamine)
- **Building Blocks**: Standard synthetic intermediates
- **Max Depth**: e.g., 6-10 steps
- **Low Confidence**: Model uncertainty too high

## Validation and Filtering

### Chemical Validity

Check each predicted reaction:
- Reactants are valid molecules
- The reaction is chemically reasonable
- Atom mapping is consistent
- Stoichiometry balances

### Synthetic Feasibility

**Filters:**
- Reaction precedent (literature examples)
- Functional group compatibility
- Typical reaction conditions
- Expected yield ranges

**Expert Systems:**
- Rule-based validation (e.g., ARChem Route Designer)
- Check for incompatible functional groups
- Identify protection/deprotection needs

### Commercial Availability

**Databases:**
- eMolecules: 10M+ commercial compounds
- ZINC: Annotated with vendor information
- Reaxys: Commercially available building blocks

**Considerations:**
- Cost per gram
- Purity and quality
- Lead time for delivery
- Minimum order quantities

## Integration with Other Tools

### Reaction Prediction (Forward)

Train forward reaction prediction models to validate retrosynthetic proposals:
- Predict products from the proposed reactants
- Validate reaction feasibility
- Estimate yields

### Retrosynthesis Software

**Integration with:**
- SciFinder (CAS)
- Reaxys (Elsevier)
- ARChem Route Designer
- IBM RXN for Chemistry

**TorchDrug as a Component:**
- Use TorchDrug models within larger planning systems
- Combine ML predictions with rule-based systems
- Hybrid AI + expert-system approaches

### Experimental Validation

**High-Throughput Screening:**
- Rapid testing of predicted reactions
- Automated synthesis platforms
- Feedback loop to improve models

**Robotic Synthesis:**
- Automated execution of planned routes
- Real-time optimization
- Data generation for model improvement

## Best Practices

1. **Ensemble Predictions**: Use multiple models for robustness
2. **Reaction Validation**: Always validate predictions against chemistry rules
3. **Commercial Check**: Verify building block availability early
4. **Diversity**: Generate multiple diverse routes, not just the top-1
5. **Expert Review**: Have chemists evaluate proposed routes
6. **Literature Search**: Check for precedents of key steps
7. **Iterative Refinement**: Update models with experimental feedback
8. **Interpretability**: Understand why the model suggests each step
9. **Edge Cases**: Handle unusual functional groups and scaffolds
10. **Benchmarking**: Compare against known synthesis routes

## Common Applications

### Drug Synthesis Planning

- Small molecule drugs
- Natural product total synthesis
- Late-stage functionalization strategies

### Library Enumeration

- Virtual library design
- Retrosynthetic filtering of generated molecules
- Prioritizing synthesizable compounds

### Process Chemistry

- Route scouting for large-scale synthesis
- Cost optimization
- Green chemistry alternatives

### Synthetic Method Development

- Identify gaps in synthetic methodology
- Guide development of new reactions
- Expand retrosynthesis model capabilities

## Challenges and Future Directions

### Current Limitations

- Most models are limited to single-step predictions
- Reaction conditions are not considered explicitly
- Stereochemistry handling remains challenging
- Rare reaction types are underrepresented

### Active Research Areas

- End-to-end multi-step planning
- Incorporation of reaction conditions
- Stereoselective retrosynthesis
- Integration with robotics for closed-loop optimization
- Semi-template methods (balancing template-based and template-free approaches)
- Uncertainty quantification for predictions

### Emerging Techniques

- Large language models for chemistry (ChemGPT, MolT5)
- Reinforcement learning for route optimization
- Graph transformers for long-range interactions
- Self-supervised pre-training on reaction databases