#!/usr/bin/env python3
|
|
"""
|
|
Graph Neural Network Training Script
|
|
|
|
This script demonstrates training Graph Convolutional Networks (GCNs) and other
|
|
graph-based models for molecular property prediction.
|
|
|
|
Usage:
|
|
python graph_neural_network.py --dataset tox21 --model gcn
|
|
python graph_neural_network.py --dataset bbbp --model attentivefp
|
|
python graph_neural_network.py --data custom.csv --task-type regression
|
|
"""
|
|
|
|
import argparse
|
|
import deepchem as dc
|
|
import sys
|
|
|
|
|
|
# Short model keys mapped to human-readable names.  The keys double as the
# valid values for the --model CLI flag and the dispatch keys in create_model().
AVAILABLE_MODELS = {
    'gcn': 'Graph Convolutional Network',
    'gat': 'Graph Attention Network',
    'attentivefp': 'Attentive Fingerprint',
    'mpnn': 'Message Passing Neural Network',
    'dmpnn': 'Directed Message Passing Neural Network'
}

# MoleculeNet benchmark datasets: name -> (task type, nominal number of tasks).
# The keys double as the valid values for the --dataset CLI flag; the task
# count is re-derived from the loader at runtime (see train_on_molnet).
MOLNET_DATASETS = {
    'tox21': ('classification', 12),
    'bbbp': ('classification', 1),
    'bace': ('classification', 1),
    'hiv': ('classification', 1),
    'delaney': ('regression', 1),
    'freesolv': ('regression', 1),
    'lipo': ('regression', 1)
}
|
|
|
|
|
|
def create_model(model_type, n_tasks, mode='classification'):
    """
    Create a graph neural network model.

    Args:
        model_type: Type of model ('gcn', 'gat', 'attentivefp', 'mpnn', 'dmpnn')
        n_tasks: Number of prediction tasks
        mode: 'classification' or 'regression'

    Returns:
        DeepChem model instance

    Raises:
        ValueError: If model_type is not a recognized model name.
    """
    # Map each short name to its DeepChem class name.  Resolving the class
    # lazily via getattr keeps the unknown-type check free of any deepchem
    # attribute access and removes the five near-identical if/elif branches.
    class_names = {
        'gcn': 'GCNModel',
        'gat': 'GATModel',
        'attentivefp': 'AttentiveFPModel',
        'mpnn': 'MPNNModel',
        'dmpnn': 'DMPNNModel',
    }
    if model_type not in class_names:
        raise ValueError(f"Unknown model type: {model_type}")

    # All models share the same training hyperparameters.
    kwargs = dict(n_tasks=n_tasks, mode=mode, batch_size=128,
                  learning_rate=0.001)
    if model_type == 'gcn':
        # The original code passed an explicit dropout only for GCN;
        # preserved for backward compatibility.
        kwargs['dropout'] = 0.0

    model_cls = getattr(dc.models, class_names[model_type])
    return model_cls(**kwargs)
|
|
|
|
|
|
def train_on_molnet(dataset_name, model_type, n_epochs=50):
    """
    Train a graph neural network on a MoleculeNet benchmark dataset.

    Args:
        dataset_name: Name of MoleculeNet dataset (a key of MOLNET_DATASETS)
        model_type: Type of model to train (a key of AVAILABLE_MODELS)
        n_epochs: Number of training epochs

    Returns:
        Tuple of (trained model, dict mapping split name to metric scores)
    """
    print("=" * 70)
    print(f"Training {AVAILABLE_MODELS[model_type]} on {dataset_name.upper()}")
    print("=" * 70)

    # Get dataset info (actual task count is re-derived from the loader below)
    task_type, n_tasks_default = MOLNET_DATASETS[dataset_name]

    # BUG FIX: the original passed featurizer='GraphConv', which yields ConvMol
    # features intended for the legacy GraphConvModel.  The models built by
    # create_model() (GCN/GAT/AttentiveFP/MPNN/DMPNN) consume GraphData, so use
    # the matching featurizer objects instead — the same approach already used
    # in train_on_custom_data().
    if model_type == 'dmpnn':
        featurizer = dc.feat.DMPNNFeaturizer()
    else:
        # use_edges=True supplies the bond features AttentiveFP/MPNN require;
        # GCN/GAT ignore edge features, so it is safe for all four.
        featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)

    # Load dataset with graph featurization and a scaffold split.
    print(f"\nLoading {dataset_name} dataset with graph featurizer...")
    load_func = getattr(dc.molnet, f'load_{dataset_name}')
    tasks, datasets, transformers = load_func(
        featurizer=featurizer,
        splitter='scaffold'
    )
    train, valid, test = datasets

    n_tasks = len(tasks)
    print(f"\nDataset Information:")
    print(f" Task type: {task_type}")
    print(f" Number of tasks: {n_tasks}")
    print(f" Training samples: {len(train)}")
    print(f" Validation samples: {len(valid)}")
    print(f" Test samples: {len(test)}")

    # Create model
    print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...")
    model = create_model(model_type, n_tasks, mode=task_type)

    # Train
    print(f"\nTraining for {n_epochs} epochs...")
    model.fit(train, nb_epoch=n_epochs)
    print("Training complete!")

    # Evaluate each split with task-appropriate metrics.
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
            dc.metrics.Metric(dc.metrics.f1_score, name='F1'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
            dc.metrics.Metric(dc.metrics.root_mean_squared_error, name='RMSE'),
        ]

    results = {}
    # Loop variable renamed so it no longer shadows the dataset_name parameter.
    for split_name, split_data in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{split_name} Set:")
        scores = model.evaluate(split_data, metrics)
        results[split_name] = scores
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, results
|
|
|
|
|
|
def train_on_custom_data(data_path, model_type, task_type, target_cols, smiles_col='smiles', n_epochs=50):
    """
    Train a graph neural network on custom CSV data.

    Args:
        data_path: Path to CSV file
        model_type: Type of model to train (a key of AVAILABLE_MODELS)
        task_type: 'classification' or 'regression'
        target_cols: List of target column names
        smiles_col: Name of SMILES column
        n_epochs: Number of training epochs

    Returns:
        Tuple of (trained model, test dataset)
    """
    print("=" * 70)
    print(f"Training {AVAILABLE_MODELS[model_type]} on Custom Data")
    print("=" * 70)

    # Load and featurize data
    print(f"\nLoading data from {data_path}...")
    # BUG FIX: a plain MolGraphConvFeaturizer() omits the bond (edge) features
    # that AttentiveFP/MPNN require, and DMPNN needs its dedicated featurizer.
    # Pick the featurizer that matches the model being trained.
    if model_type == 'dmpnn':
        featurizer = dc.feat.DMPNNFeaturizer()
    else:
        # use_edges=True adds bond features; GCN/GAT simply ignore them.
        featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
    loader = dc.data.CSVLoader(
        tasks=target_cols,
        feature_field=smiles_col,
        featurizer=featurizer
    )
    dataset = loader.create_dataset(data_path)

    print(f"Loaded {len(dataset)} molecules")

    # Split 80/10/10 by Bemis-Murcko scaffold to limit train/test leakage.
    print("\nSplitting data with scaffold splitter...")
    splitter = dc.splits.ScaffoldSplitter()
    train, valid, test = splitter.train_valid_test_split(
        dataset,
        frac_train=0.8,
        frac_valid=0.1,
        frac_test=0.1
    )

    print(f" Training: {len(train)}")
    print(f" Validation: {len(valid)}")
    print(f" Test: {len(test)}")

    # Create model
    print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...")
    n_tasks = len(target_cols)
    model = create_model(model_type, n_tasks, mode=task_type)

    # Train
    print(f"\nTraining for {n_epochs} epochs...")
    model.fit(train, nb_epoch=n_epochs)
    print("Training complete!")

    # Evaluate each split with task-appropriate metrics.
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    # Loop variables renamed so they no longer shadow the local `dataset`
    # (full set) or the outer `dataset_name` naming used elsewhere.
    for split_name, split_data in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{split_name} Set:")
        scores = model.evaluate(split_data, metrics)
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, test
|
|
|
|
|
|
def main():
    """Parse CLI arguments, validate them, and run the chosen training path.

    Returns a process exit code: 0 on success, 1 on bad arguments or on any
    exception raised during training.
    """
    args = _build_parser().parse_args()

    # Exactly one data source must be supplied.
    if args.dataset is None and args.data is None:
        print("Error: Must specify either --dataset (MoleculeNet) or --data (custom CSV)",
              file=sys.stderr)
        return 1
    if args.dataset and args.data:
        print("Error: Cannot specify both --dataset and --data",
              file=sys.stderr)
        return 1

    try:
        if args.dataset:
            model, results = train_on_molnet(args.dataset, args.model,
                                             n_epochs=args.epochs)
        else:
            model, test_set = train_on_custom_data(
                args.data, args.model, args.task_type, args.targets,
                smiles_col=args.smiles_col, n_epochs=args.epochs)

        print("\n" + "=" * 70)
        print("Training Complete!")
        print("=" * 70)
        return 0
    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero.
        print(f"\nError: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


def _build_parser():
    """Construct the command-line parser (extracted from main for readability)."""
    parser = argparse.ArgumentParser(
        description='Train graph neural networks for molecular property prediction')
    parser.add_argument('--model', type=str,
                        choices=list(AVAILABLE_MODELS.keys()), default='gcn',
                        help='Type of graph neural network model')
    parser.add_argument('--dataset', type=str,
                        choices=list(MOLNET_DATASETS.keys()), default=None,
                        help='MoleculeNet dataset to use')
    parser.add_argument('--data', type=str, default=None,
                        help='Path to custom CSV file')
    parser.add_argument('--task-type', type=str,
                        choices=['classification', 'regression'],
                        default='classification',
                        help='Type of prediction task (for custom data)')
    parser.add_argument('--targets', nargs='+', default=['target'],
                        help='Names of target columns (for custom data)')
    parser.add_argument('--smiles-col', type=str, default='smiles',
                        help='Name of SMILES column')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs')
    return parser
|
|
|
|
|
|
# Script entry point: propagate main()'s return code (0 success, 1 failure)
# as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
|