Initial commit
This commit is contained in:
338
skills/deepchem/scripts/graph_neural_network.py
Normal file
338
skills/deepchem/scripts/graph_neural_network.py
Normal file
@@ -0,0 +1,338 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Neural Network Training Script
|
||||
|
||||
This script demonstrates training Graph Convolutional Networks (GCNs) and other
|
||||
graph-based models for molecular property prediction.
|
||||
|
||||
Usage:
|
||||
python graph_neural_network.py --dataset tox21 --model gcn
|
||||
python graph_neural_network.py --dataset bbbp --model attentivefp
|
||||
python graph_neural_network.py --data custom.csv --task-type regression
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import deepchem as dc
|
||||
import sys
|
||||
|
||||
|
||||
# Command-line model keys mapped to the human-readable names used in banners.
# Insertion order is the order in which the keys are listed as CLI choices.
AVAILABLE_MODELS = dict(
    gcn='Graph Convolutional Network',
    gat='Graph Attention Network',
    attentivefp='Attentive Fingerprint',
    mpnn='Message Passing Neural Network',
    dmpnn='Directed Message Passing Neural Network',
)
|
||||
|
||||
# Supported MoleculeNet benchmarks: name -> (task type, number of tasks).
MOLNET_DATASETS = {
    name: (task_type, task_count)
    for name, task_type, task_count in (
        ('tox21', 'classification', 12),
        ('bbbp', 'classification', 1),
        ('bace', 'classification', 1),
        ('hiv', 'classification', 1),
        ('delaney', 'regression', 1),
        ('freesolv', 'regression', 1),
        ('lipo', 'regression', 1),
    )
}
|
||||
|
||||
|
||||
def create_model(model_type, n_tasks, mode='classification'):
    """
    Build the DeepChem graph model selected by ``model_type``.

    Args:
        model_type: One of 'gcn', 'gat', 'attentivefp', 'mpnn', 'dmpnn'
        n_tasks: Number of prediction tasks
        mode: 'classification' or 'regression'

    Returns:
        DeepChem model

    Raises:
        ValueError: If ``model_type`` is not one of the supported keys
    """
    # Class names are kept as strings and resolved only for the requested
    # model, so a class missing from the installed DeepChem version does not
    # affect the other model types.
    class_names = {
        'gcn': 'GCNModel',
        'gat': 'GATModel',
        'attentivefp': 'AttentiveFPModel',
        'mpnn': 'MPNNModel',
        'dmpnn': 'DMPNNModel',
    }
    if model_type not in class_names:
        raise ValueError(f"Unknown model type: {model_type}")

    # Hyperparameters shared by every model type.
    hyperparams = dict(
        n_tasks=n_tasks,
        mode=mode,
        batch_size=128,
        learning_rate=0.001,
    )
    if model_type == 'gcn':
        # Only the GCN is created with an explicit dropout setting.
        hyperparams['dropout'] = 0.0

    model_class = getattr(dc.models, class_names[model_type])
    return model_class(**hyperparams)
|
||||
|
||||
|
||||
def train_on_molnet(dataset_name, model_type, n_epochs=50):
    """
    Train a graph neural network on a MoleculeNet benchmark dataset.

    Args:
        dataset_name: Name of MoleculeNet dataset (a key of MOLNET_DATASETS)
        model_type: Type of model to train (a key of AVAILABLE_MODELS)
        n_epochs: Number of training epochs

    Returns:
        Trained model and a dict mapping 'Train'/'Valid'/'Test' to the scores
        returned by ``model.evaluate`` for that split

    Raises:
        KeyError: If ``dataset_name`` or ``model_type`` is not a known key
    """
    print("=" * 70)
    print(f"Training {AVAILABLE_MODELS[model_type]} on {dataset_name.upper()}")
    print("=" * 70)

    # Only the task type is needed; the task count comes from the loader.
    task_type = MOLNET_DATASETS[dataset_name][0]

    # The models built by create_model consume GraphData objects, which the
    # 'GraphConv' preset (ConvMol objects) does not produce. Per the DeepChem
    # model documentation, AttentiveFP and MPNN additionally need edge
    # features and DMPNN has its own featurizer.
    if model_type == 'dmpnn':
        featurizer = dc.feat.DMPNNFeaturizer()
    else:
        featurizer = dc.feat.MolGraphConvFeaturizer(
            use_edges=model_type in ('attentivefp', 'mpnn')
        )

    # MoleculeNet exposes BACE as load_bace_classification / _regression;
    # there is no plain load_bace.
    if dataset_name == 'bace':
        loader_name = 'load_bace_classification'
    else:
        loader_name = f'load_{dataset_name}'

    print(f"\nLoading {dataset_name} dataset with {type(featurizer).__name__}...")
    load_func = getattr(dc.molnet, loader_name)
    tasks, datasets, transformers = load_func(
        featurizer=featurizer,
        splitter='scaffold'
    )
    train, valid, test = datasets

    n_tasks = len(tasks)
    print(f"\nDataset Information:")
    print(f" Task type: {task_type}")
    print(f" Number of tasks: {n_tasks}")
    print(f" Training samples: {len(train)}")
    print(f" Validation samples: {len(valid)}")
    print(f" Test samples: {len(test)}")

    # Create model
    print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...")
    model = create_model(model_type, n_tasks, mode=task_type)

    # Train
    print(f"\nTraining for {n_epochs} epochs...")
    model.fit(train, nb_epoch=n_epochs)
    print("Training complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
            dc.metrics.Metric(dc.metrics.f1_score, name='F1'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
            # DeepChem's RMSE score function is named rms_score.
            dc.metrics.Metric(dc.metrics.rms_score, name='RMSE'),
        ]

    results = {}
    for split_name, split in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{split_name} Set:")
        # Passing the loader's transformers undoes any label transformation,
        # so the scores are reported in the dataset's original units.
        scores = model.evaluate(split, metrics, transformers)
        results[split_name] = scores
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, results
|
||||
|
||||
|
||||
def train_on_custom_data(data_path, model_type, task_type, target_cols, smiles_col='smiles', n_epochs=50):
    """
    Train a graph neural network on custom CSV data.

    Args:
        data_path: Path to CSV file
        model_type: Type of model to train (a key of AVAILABLE_MODELS)
        task_type: 'classification' or 'regression'
        target_cols: List of target column names
        smiles_col: Name of SMILES column
        n_epochs: Number of training epochs

    Returns:
        Trained model and test dataset

    Raises:
        KeyError: If ``model_type`` is not a key of AVAILABLE_MODELS
    """
    print("=" * 70)
    print(f"Training {AVAILABLE_MODELS[model_type]} on Custom Data")
    print("=" * 70)

    # Load and featurize data
    print(f"\nLoading data from {data_path}...")
    # The featurizer has to match the model: per the DeepChem model
    # documentation, AttentiveFP and MPNN need edge features and DMPNN has
    # its own featurizer; GCN and GAT use node features only.
    if model_type == 'dmpnn':
        featurizer = dc.feat.DMPNNFeaturizer()
    else:
        featurizer = dc.feat.MolGraphConvFeaturizer(
            use_edges=model_type in ('attentivefp', 'mpnn')
        )
    loader = dc.data.CSVLoader(
        tasks=target_cols,
        feature_field=smiles_col,
        featurizer=featurizer
    )
    dataset = loader.create_dataset(data_path)

    print(f"Loaded {len(dataset)} molecules")

    # Split data
    print("\nSplitting data with scaffold splitter...")
    splitter = dc.splits.ScaffoldSplitter()
    train, valid, test = splitter.train_valid_test_split(
        dataset,
        frac_train=0.8,
        frac_valid=0.1,
        frac_test=0.1
    )

    print(f" Training: {len(train)}")
    print(f" Validation: {len(valid)}")
    print(f" Test: {len(test)}")

    # Create model
    print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...")
    n_tasks = len(target_cols)
    model = create_model(model_type, n_tasks, mode=task_type)

    # Train
    print(f"\nTraining for {n_epochs} epochs...")
    model.fit(train, nb_epoch=n_epochs)
    print("Training complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    # Distinct loop names keep the full `dataset` loaded above from being
    # rebound to the last split.
    for split_name, split in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{split_name} Set:")
        scores = model.evaluate(split, metrics)
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, test
|
||||
|
||||
|
||||
def main():
    """
    Parse the command line and run the requested training job.

    Returns:
        Process exit code: 0 on success, 1 when the arguments are
        inconsistent or training raises an exception
    """
    parser = argparse.ArgumentParser(
        description='Train graph neural networks for molecular property prediction'
    )
    # Each entry is (flag, keyword arguments for add_argument).
    option_table = (
        ('--model', dict(
            type=str,
            choices=list(AVAILABLE_MODELS),
            default='gcn',
            help='Type of graph neural network model',
        )),
        ('--dataset', dict(
            type=str,
            choices=list(MOLNET_DATASETS),
            default=None,
            help='MoleculeNet dataset to use',
        )),
        ('--data', dict(
            type=str,
            default=None,
            help='Path to custom CSV file',
        )),
        ('--task-type', dict(
            type=str,
            choices=['classification', 'regression'],
            default='classification',
            help='Type of prediction task (for custom data)',
        )),
        ('--targets', dict(
            nargs='+',
            default=['target'],
            help='Names of target columns (for custom data)',
        )),
        ('--smiles-col', dict(
            type=str,
            default='smiles',
            help='Name of SMILES column',
        )),
        ('--epochs', dict(
            type=int,
            default=50,
            help='Number of training epochs',
        )),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Exactly one data source has to be selected.
    if args.dataset is None and args.data is None:
        print("Error: Must specify either --dataset (MoleculeNet) or --data (custom CSV)",
              file=sys.stderr)
        return 1
    if args.dataset and args.data:
        print("Error: Cannot specify both --dataset and --data",
              file=sys.stderr)
        return 1

    separator = "=" * 70
    try:
        if args.dataset:
            train_on_molnet(args.dataset, args.model, n_epochs=args.epochs)
        else:
            train_on_custom_data(
                args.data,
                args.model,
                args.task_type,
                args.targets,
                smiles_col=args.smiles_col,
                n_epochs=args.epochs,
            )

        print("\n" + separator)
        print("Training Complete!")
        print(separator)
        return 0
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and turn
        # it into a non-zero exit code.
        print(f"\nError: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
# Script entry point: the return value of main() becomes the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|
||||
224
skills/deepchem/scripts/predict_solubility.py
Normal file
224
skills/deepchem/scripts/predict_solubility.py
Normal file
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Molecular Solubility Prediction Script
|
||||
|
||||
This script trains a model to predict aqueous solubility from SMILES strings
|
||||
using the Delaney (ESOL) dataset as an example. Can be adapted for custom datasets.
|
||||
|
||||
Usage:
|
||||
python predict_solubility.py --data custom_data.csv --smiles-col smiles --target-col solubility
|
||||
python predict_solubility.py # Uses Delaney dataset by default
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import deepchem as dc
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
|
||||
def train_solubility_model(data_path=None, smiles_col='smiles', target_col='measured log solubility in mols per litre'):
    """
    Train a solubility prediction model.

    Args:
        data_path: Path to CSV file with SMILES and solubility data. If None, uses Delaney dataset.
        smiles_col: Name of column containing SMILES strings (custom data only)
        target_col: Name of column containing solubility values (custom data only)

    Returns:
        Trained model, test dataset, and the list of transformers that were
        applied to the data
    """
    print("=" * 60)
    print("DeepChem Solubility Prediction")
    print("=" * 60)

    # One featurizer for both data sources, so the fingerprint width always
    # matches the model's n_features below and the 2048-bit fingerprints
    # that predict_new_molecules computes.
    fingerprint_size = 2048
    featurizer = dc.feat.CircularFingerprint(radius=2, size=fingerprint_size)

    # Load data
    if data_path is None:
        print("\nUsing Delaney (ESOL) benchmark dataset...")
        # An explicit featurizer is passed instead of the 'ECFP' preset,
        # whose fingerprint size does not match the 2048 inputs of the model.
        tasks, datasets, transformers = dc.molnet.load_delaney(
            featurizer=featurizer,
            splitter='scaffold'
        )
        train, valid, test = datasets
    else:
        print(f"\nLoading custom data from {data_path}...")
        loader = dc.data.CSVLoader(
            tasks=[target_col],
            feature_field=smiles_col,
            featurizer=featurizer
        )
        dataset = loader.create_dataset(data_path)

        # Split data
        print("Splitting data with scaffold splitter...")
        splitter = dc.splits.ScaffoldSplitter()
        train, valid, test = splitter.train_valid_test_split(
            dataset,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1
        )

        # Normalize data. Only the targets are transformed; the statistics
        # come from the training split alone.
        print("Normalizing targets...")
        transformers = [
            dc.trans.NormalizationTransformer(
                transform_y=True,
                dataset=train
            )
        ]
        for transformer in transformers:
            train = transformer.transform(train)
            valid = transformer.transform(valid)
            test = transformer.transform(test)

        tasks = [target_col]

    print(f"\nDataset sizes:")
    print(f" Training: {len(train)} molecules")
    print(f" Validation: {len(valid)} molecules")
    print(f" Test: {len(test)} molecules")

    # Create model
    print("\nCreating multitask regressor...")
    model = dc.models.MultitaskRegressor(
        n_tasks=len(tasks),
        n_features=fingerprint_size,
        layer_sizes=[1000, 500],
        dropouts=0.25,
        learning_rate=0.001,
        batch_size=50
    )

    # Train model
    print("\nTraining model...")
    model.fit(train, nb_epoch=50)
    print("Training complete!")

    # Evaluate model
    print("\n" + "=" * 60)
    print("Model Evaluation")
    print("=" * 60)

    metrics = [
        dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
        dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        # DeepChem's RMSE score function is named rms_score.
        dc.metrics.Metric(dc.metrics.rms_score, name='RMSE'),
    ]

    for split_name, split in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{split_name} Set:")
        # Passing the transformers undoes the target normalization, so MAE
        # and RMSE are reported in the original solubility units.
        scores = model.evaluate(split, metrics, transformers)
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, test, transformers
|
||||
|
||||
|
||||
def predict_new_molecules(model, smiles_list, transformers=None):
    """
    Predict solubility for new molecules.

    Args:
        model: Trained DeepChem model
        smiles_list: List of SMILES strings
        transformers: Transformers that were applied to the training data.
            Input transformers are applied to the new features; target
            transformers are undone on the predictions.

    Returns:
        Array of predictions, in the untransformed target units
    """
    print("\n" + "=" * 60)
    print("Predicting New Molecules")
    print("=" * 60)

    # Featurize new molecules
    featurizer = dc.feat.CircularFingerprint(radius=2, size=2048)
    features = featurizer.featurize(smiles_list)

    # Create dataset
    new_dataset = dc.data.NumpyDataset(X=features)

    transformers = list(transformers) if transformers else []

    # Only input-side transformers belong on unlabeled data; applying a
    # target transformer here would merely rescale placeholder labels.
    for transformer in transformers:
        if getattr(transformer, 'transform_X', False):
            new_dataset = transformer.transform(new_dataset)

    # Handing the transformers to predict() undoes the target transformation,
    # so the values printed below really are log(mol/L) rather than
    # normalized scores.
    predictions = model.predict(new_dataset, transformers)

    # Display results
    print("\nPredictions:")
    for smiles, pred in zip(smiles_list, predictions):
        # ravel() makes the lookup independent of the exact output shape.
        value = float(np.ravel(pred)[0])
        print(f" {smiles:30s} -> {value:.3f} log(mol/L)")

    return predictions
|
||||
|
||||
|
||||
def main():
    """
    Train the solubility model and print predictions.

    Predictions are made for the SMILES given with --predict, or for a small
    set of example molecules when none are given.

    Returns:
        Process exit code: 0 on success, 1 when training or prediction raises
    """
    parser = argparse.ArgumentParser(
        description='Train a molecular solubility prediction model'
    )
    parser.add_argument(
        '--data',
        type=str,
        default=None,
        help='Path to CSV file with molecular data'
    )
    parser.add_argument(
        '--smiles-col',
        type=str,
        default='smiles',
        help='Name of column containing SMILES strings'
    )
    parser.add_argument(
        '--target-col',
        type=str,
        default='solubility',
        help='Name of column containing target values'
    )
    parser.add_argument(
        '--predict',
        nargs='+',
        default=None,
        help='SMILES strings to predict after training'
    )

    args = parser.parse_args()

    # Train model
    try:
        model, _, transformers = train_solubility_model(
            data_path=args.data,
            smiles_col=args.smiles_col,
            target_col=args.target_col
        )
    except Exception as e:
        print(f"\nError during training: {e}", file=sys.stderr)
        return 1

    # Use the molecules from the command line, or fall back to examples.
    smiles_to_predict = args.predict or [
        'CCO',  # Ethanol
        'CC(=O)O',  # Acetic acid
        'c1ccccc1',  # Benzene
        'CN1C=NC2=C1C(=O)N(C(=O)N2C)C',  # Caffeine
    ]

    # Both the user-supplied and the example molecules go through the same
    # error handling, so a prediction failure always yields exit code 1.
    try:
        predict_new_molecules(model, smiles_to_predict, transformers)
    except Exception as e:
        print(f"\nError during prediction: {e}", file=sys.stderr)
        return 1

    print("\n" + "=" * 60)
    print("Complete!")
    print("=" * 60)
    return 0
|
||||
|
||||
|
||||
# Script entry point: the return value of main() becomes the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|
||||
375
skills/deepchem/scripts/transfer_learning.py
Normal file
375
skills/deepchem/scripts/transfer_learning.py
Normal file
@@ -0,0 +1,375 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Transfer Learning Script for DeepChem
|
||||
|
||||
Use pretrained models (ChemBERTa, GROVER, MolFormer) for molecular property prediction
|
||||
with transfer learning. Particularly useful for small datasets.
|
||||
|
||||
Usage:
|
||||
python transfer_learning.py --model chemberta --data my_data.csv --target activity
|
||||
python transfer_learning.py --model grover --dataset bbbp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import deepchem as dc
|
||||
import sys
|
||||
|
||||
|
||||
# Registry of the pretrained models selectable with --model. Each entry holds
# a display name, a one-line description and the checkpoint identifier.
PRETRAINED_MODELS = dict(
    chemberta=dict(
        name='ChemBERTa',
        description='BERT pretrained on 77M molecules from ZINC15',
        model_id='seyonec/ChemBERTa-zinc-base-v1',
    ),
    grover=dict(
        name='GROVER',
        description='Graph transformer pretrained on 10M molecules',
        model_id=None,  # GROVER uses its own loading mechanism
    ),
    molformer=dict(
        name='MolFormer',
        description='Transformer pretrained on molecular structures',
        model_id='ibm/MoLFormer-XL-both-10pct',
    ),
)
|
||||
|
||||
|
||||
def train_chemberta(train_dataset, valid_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=10):
    """
    Fine-tune ChemBERTa on a dataset.

    Args:
        train_dataset: Training dataset
        valid_dataset: Validation dataset (evaluated only, not used for training)
        test_dataset: Test dataset
        task_type: 'classification' or 'regression'
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Trained model and a dict mapping 'Train'/'Valid'/'Test' to the scores
        returned by ``model.evaluate`` for that split
    """
    # Imported here so that the other code paths of the script do not depend
    # on the optional HuggingFace stack being installed.
    from deepchem.models.torch_models.chemberta import Chemberta

    print("=" * 70)
    print("Fine-tuning ChemBERTa")
    print("=" * 70)
    print("\nChemBERTa is a BERT model pretrained on 77M molecules from ZINC15.")
    print("It uses SMILES strings as input and has learned rich molecular")
    print("representations that transfer well to downstream tasks.")

    print(f"\nLoading pretrained ChemBERTa model...")
    checkpoint = PRETRAINED_MODELS['chemberta']['model_id']
    # HuggingFaceModel expects an instantiated transformers model plus a
    # tokenizer, not a hub identifier. The Chemberta wrapper builds both from
    # the tokenizer path and a task name.
    model = Chemberta(
        task=task_type,
        tokenizer_path=checkpoint,
        n_tasks=n_tasks,
        batch_size=32,
        learning_rate=2e-5  # Lower LR for fine-tuning
    )
    # Replace the freshly initialised weights with the pretrained checkpoint.
    # NOTE(review): confirm against the installed DeepChem version that
    # from_hf_checkpoint is supported and that the loaded head matches
    # n_tasks for multi-task runs.
    model.load_from_pretrained(checkpoint, from_hf_checkpoint=True)

    print(f"\nFine-tuning for {n_epochs} epochs...")
    print("(This may take a while on the first run as the model is downloaded)")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    results = {}
    for name, dataset in [('Train', train_dataset), ('Valid', valid_dataset), ('Test', test_dataset)]:
        print(f"\n{name} Set:")
        scores = model.evaluate(dataset, metrics)
        results[name] = scores
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, results
|
||||
|
||||
|
||||
def train_grover(train_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=20):
    """
    Fine-tune a GROVER model and report its scores.

    Args:
        train_dataset: Dataset used for fitting and for the 'Train' scores
        test_dataset: Dataset used for the 'Test' scores
        task_type: 'classification' or 'regression'; selects the metrics and
            is forwarded to the model constructor
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Trained model and a dict mapping 'Train'/'Test' to the scores
        returned by ``model.evaluate`` for that split
    """
    separator = "=" * 70

    print(separator)
    print("Fine-tuning GROVER")
    print(separator)
    print("\nGROVER is a graph transformer pretrained on 10M molecules using")
    print("self-supervised learning. It learns both node and graph-level")
    print("representations through masked atom/bond prediction tasks.")

    print("\nCreating GROVER model...")
    # NOTE(review): the DeepChem documentation for GroverModel shows atom/bond
    # vocabulary arguments and a 'pretraining'/'finetuning' task value --
    # confirm this constructor call against the installed version.
    model = dc.models.GroverModel(
        task=task_type,
        n_tasks=n_tasks,
        model_dir='./grover_pretrained',
    )

    print(f"\nFine-tuning for {n_epochs} epochs...")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + separator)
    print("Model Evaluation")
    print(separator)

    is_classification = task_type == 'classification'
    metrics = (
        [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
        if is_classification
        else [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]
    )

    results = {}
    for split_label, split in (('Train', train_dataset), ('Test', test_dataset)):
        print(f"\n{split_label} Set:")
        split_scores = model.evaluate(split, metrics)
        results[split_label] = split_scores
        for metric_name, value in split_scores.items():
            print(f" {metric_name}: {value:.4f}")

    return model, results
|
||||
|
||||
|
||||
def load_molnet_dataset(dataset_name, model_type):
    """
    Load a MoleculeNet dataset with appropriate featurization.

    Args:
        dataset_name: Name of MoleculeNet dataset
        model_type: Type of pretrained model being used

    Returns:
        tasks, train/valid/test datasets, transformers

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset
    """
    # Dataset name -> name of the loader function in dc.molnet. Names are
    # kept as strings so an unknown dataset is rejected before DeepChem is
    # touched and only the requested loader has to exist.
    loader_names = {
        'tox21': 'load_tox21',
        'bbbp': 'load_bbbp',
        'bace': 'load_bace_classification',
        'hiv': 'load_hiv',
        'delaney': 'load_delaney',
        'freesolv': 'load_freesolv',
        'lipo': 'load_lipo',
    }

    if dataset_name not in loader_names:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # ChemBERTa and MolFormer tokenize SMILES themselves, so the features
    # must stay plain SMILES strings. DummyFeaturizer passes them through
    # unchanged, matching what load_custom_dataset does for these models.
    if model_type in ('chemberta', 'molformer'):
        featurizer = dc.feat.DummyFeaturizer()
    # GROVER needs graph features
    # NOTE(review): load_custom_dataset uses MolGraphConvFeaturizer for
    # GROVER; confirm which featurizer the GROVER model actually expects.
    elif model_type == 'grover':
        featurizer = 'GraphConv'
    else:
        featurizer = 'ECFP'

    print(f"\nLoading {dataset_name} dataset...")
    load_func = getattr(dc.molnet, loader_names[dataset_name])
    tasks, datasets, transformers = load_func(
        featurizer=featurizer,
        splitter='scaffold'
    )

    return tasks, datasets, transformers
|
||||
|
||||
|
||||
def load_custom_dataset(data_path, target_cols, smiles_col, model_type):
    """
    Read a CSV file, featurize it for the chosen model and split it.

    Args:
        data_path: Path to CSV file
        target_cols: List of target column names
        smiles_col: Name of SMILES column
        model_type: Type of pretrained model being used; selects the featurizer

    Returns:
        train, valid, test datasets (80/10/10 scaffold split)
    """
    print(f"\nLoading custom data from {data_path}...")

    # Choose featurizer based on model
    if model_type == 'grover':
        molecule_featurizer = dc.feat.MolGraphConvFeaturizer()
    elif model_type in ('chemberta', 'molformer'):
        # Models handle featurization
        molecule_featurizer = dc.feat.DummyFeaturizer()
    else:
        molecule_featurizer = dc.feat.CircularFingerprint()

    full_dataset = dc.data.CSVLoader(
        tasks=target_cols,
        feature_field=smiles_col,
        featurizer=molecule_featurizer,
    ).create_dataset(data_path)

    print(f"Loaded {len(full_dataset)} molecules")

    # Split data
    print("Splitting data with scaffold splitter...")
    train, valid, test = dc.splits.ScaffoldSplitter().train_valid_test_split(
        full_dataset,
        frac_train=0.8,
        frac_valid=0.1,
        frac_test=0.1,
    )

    for label, subset in (('Training', train), ('Validation', valid), ('Test', test)):
        print(f" {label}: {len(subset)}")

    return train, valid, test
|
||||
|
||||
|
||||
def main():
    """
    Parse the command line and fine-tune the selected pretrained model.

    Returns:
        Process exit code: 0 on success, 1 when the arguments are
        inconsistent, the model has no training routine yet, or loading or
        training raises an exception
    """
    parser = argparse.ArgumentParser(
        description='Transfer learning for molecular property prediction'
    )
    parser.add_argument(
        '--model',
        type=str,
        choices=list(PRETRAINED_MODELS.keys()),
        required=True,
        help='Pretrained model to use'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        choices=['tox21', 'bbbp', 'bace', 'hiv', 'delaney', 'freesolv', 'lipo'],
        default=None,
        help='MoleculeNet dataset to use'
    )
    parser.add_argument(
        '--data',
        type=str,
        default=None,
        help='Path to custom CSV file'
    )
    parser.add_argument(
        '--target',
        nargs='+',
        default=['target'],
        help='Target column name(s) for custom data'
    )
    parser.add_argument(
        '--smiles-col',
        type=str,
        default='smiles',
        help='SMILES column name for custom data'
    )
    parser.add_argument(
        '--task-type',
        type=str,
        choices=['classification', 'regression'],
        default='classification',
        help='Type of prediction task'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=10,
        help='Number of fine-tuning epochs'
    )

    args = parser.parse_args()

    # Validate arguments
    if args.dataset is None and args.data is None:
        print("Error: Must specify either --dataset or --data", file=sys.stderr)
        return 1

    if args.dataset and args.data:
        print("Error: Cannot specify both --dataset and --data", file=sys.stderr)
        return 1

    # Print model info
    model_info = PRETRAINED_MODELS[args.model]
    print("\n" + "=" * 70)
    print(f"Transfer Learning with {model_info['name']}")
    print("=" * 70)
    print(f"\n{model_info['description']}")

    # Fail fast for models that are listed but have no training routine, so
    # no dataset is downloaded and featurized only to be thrown away.
    if args.model not in ('chemberta', 'grover'):
        print(f"Error: Model {args.model} not yet implemented", file=sys.stderr)
        return 1

    # MoleculeNet datasets that are classification tasks; the others offered
    # by --dataset are treated as regression.
    classification_datasets = ('tox21', 'bbbp', 'bace', 'hiv')

    try:
        # Load dataset
        if args.dataset:
            tasks, datasets, _ = load_molnet_dataset(args.dataset, args.model)
            train, valid, test = datasets
            task_type = 'classification' if args.dataset in classification_datasets else 'regression'
            n_tasks = len(tasks)
        else:
            train, valid, test = load_custom_dataset(
                args.data,
                args.target,
                args.smiles_col,
                args.model
            )
            task_type = args.task_type
            n_tasks = len(args.target)

        # Train model
        if args.model == 'chemberta':
            train_chemberta(
                train, valid, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )
        else:
            # Only 'grover' remains after the fail-fast check above.
            train_grover(
                train, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )

        print("\n" + "=" * 70)
        print("Transfer Learning Complete!")
        print("=" * 70)
        print("\nTip: Pretrained models often work best with:")
        print(" - Small datasets (< 1000 samples)")
        print(" - Lower learning rates (1e-5 to 5e-5)")
        print(" - Fewer epochs (5-20)")
        print(" - Avoiding overfitting through early stopping")

        return 0

    except Exception as e:
        # Top-level boundary: report the failure with its traceback and turn
        # it into a non-zero exit code.
        print(f"\nError: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
# Script entry point: the return value of main() becomes the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user