225 lines
6.4 KiB
Python
225 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Molecular Solubility Prediction Script
|
|
|
|
This script trains a model to predict aqueous solubility from SMILES strings
|
|
using the Delaney (ESOL) dataset as an example. Can be adapted for custom datasets.
|
|
|
|
Usage:
|
|
python predict_solubility.py --data custom_data.csv --smiles-col smiles --target-col solubility
|
|
python predict_solubility.py # Uses Delaney dataset by default
|
|
"""
|
|
|
|
import argparse
|
|
import deepchem as dc
|
|
import numpy as np
|
|
import sys
|
|
|
|
|
|
def train_solubility_model(data_path=None, smiles_col='smiles', target_col='measured log solubility in mols per litre'):
|
|
"""
|
|
Train a solubility prediction model.
|
|
|
|
Args:
|
|
data_path: Path to CSV file with SMILES and solubility data. If None, uses Delaney dataset.
|
|
smiles_col: Name of column containing SMILES strings
|
|
target_col: Name of column containing solubility values
|
|
|
|
Returns:
|
|
Trained model, test dataset, and transformers
|
|
"""
|
|
print("=" * 60)
|
|
print("DeepChem Solubility Prediction")
|
|
print("=" * 60)
|
|
|
|
# Load data
|
|
if data_path is None:
|
|
print("\nUsing Delaney (ESOL) benchmark dataset...")
|
|
tasks, datasets, transformers = dc.molnet.load_delaney(
|
|
featurizer='ECFP',
|
|
splitter='scaffold'
|
|
)
|
|
train, valid, test = datasets
|
|
else:
|
|
print(f"\nLoading custom data from {data_path}...")
|
|
featurizer = dc.feat.CircularFingerprint(radius=2, size=2048)
|
|
loader = dc.data.CSVLoader(
|
|
tasks=[target_col],
|
|
feature_field=smiles_col,
|
|
featurizer=featurizer
|
|
)
|
|
dataset = loader.create_dataset(data_path)
|
|
|
|
# Split data
|
|
print("Splitting data with scaffold splitter...")
|
|
splitter = dc.splits.ScaffoldSplitter()
|
|
train, valid, test = splitter.train_valid_test_split(
|
|
dataset,
|
|
frac_train=0.8,
|
|
frac_valid=0.1,
|
|
frac_test=0.1
|
|
)
|
|
|
|
# Normalize data
|
|
print("Normalizing features and targets...")
|
|
transformers = [
|
|
dc.trans.NormalizationTransformer(
|
|
transform_y=True,
|
|
dataset=train
|
|
)
|
|
]
|
|
for transformer in transformers:
|
|
train = transformer.transform(train)
|
|
valid = transformer.transform(valid)
|
|
test = transformer.transform(test)
|
|
|
|
tasks = [target_col]
|
|
|
|
print(f"\nDataset sizes:")
|
|
print(f" Training: {len(train)} molecules")
|
|
print(f" Validation: {len(valid)} molecules")
|
|
print(f" Test: {len(test)} molecules")
|
|
|
|
# Create model
|
|
print("\nCreating multitask regressor...")
|
|
model = dc.models.MultitaskRegressor(
|
|
n_tasks=len(tasks),
|
|
n_features=2048, # ECFP fingerprint size
|
|
layer_sizes=[1000, 500],
|
|
dropouts=0.25,
|
|
learning_rate=0.001,
|
|
batch_size=50
|
|
)
|
|
|
|
# Train model
|
|
print("\nTraining model...")
|
|
model.fit(train, nb_epoch=50)
|
|
print("Training complete!")
|
|
|
|
# Evaluate model
|
|
print("\n" + "=" * 60)
|
|
print("Model Evaluation")
|
|
print("=" * 60)
|
|
|
|
metrics = [
|
|
dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
|
|
dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
|
|
dc.metrics.Metric(dc.metrics.root_mean_squared_error, name='RMSE'),
|
|
]
|
|
|
|
for dataset_name, dataset in [('Train', train), ('Valid', valid), ('Test', test)]:
|
|
print(f"\n{dataset_name} Set:")
|
|
scores = model.evaluate(dataset, metrics)
|
|
for metric_name, score in scores.items():
|
|
print(f" {metric_name}: {score:.4f}")
|
|
|
|
return model, test, transformers
|
|
|
|
|
|
def predict_new_molecules(model, smiles_list, transformers=None):
|
|
"""
|
|
Predict solubility for new molecules.
|
|
|
|
Args:
|
|
model: Trained DeepChem model
|
|
smiles_list: List of SMILES strings
|
|
transformers: List of data transformers to apply
|
|
|
|
Returns:
|
|
Array of predictions
|
|
"""
|
|
print("\n" + "=" * 60)
|
|
print("Predicting New Molecules")
|
|
print("=" * 60)
|
|
|
|
# Featurize new molecules
|
|
featurizer = dc.feat.CircularFingerprint(radius=2, size=2048)
|
|
features = featurizer.featurize(smiles_list)
|
|
|
|
# Create dataset
|
|
new_dataset = dc.data.NumpyDataset(X=features)
|
|
|
|
# Apply transformers (if any)
|
|
if transformers:
|
|
for transformer in transformers:
|
|
new_dataset = transformer.transform(new_dataset)
|
|
|
|
# Predict
|
|
predictions = model.predict(new_dataset)
|
|
|
|
# Display results
|
|
print("\nPredictions:")
|
|
for smiles, pred in zip(smiles_list, predictions):
|
|
print(f" {smiles:30s} -> {pred[0]:.3f} log(mol/L)")
|
|
|
|
return predictions
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description='Train a molecular solubility prediction model'
|
|
)
|
|
parser.add_argument(
|
|
'--data',
|
|
type=str,
|
|
default=None,
|
|
help='Path to CSV file with molecular data'
|
|
)
|
|
parser.add_argument(
|
|
'--smiles-col',
|
|
type=str,
|
|
default='smiles',
|
|
help='Name of column containing SMILES strings'
|
|
)
|
|
parser.add_argument(
|
|
'--target-col',
|
|
type=str,
|
|
default='solubility',
|
|
help='Name of column containing target values'
|
|
)
|
|
parser.add_argument(
|
|
'--predict',
|
|
nargs='+',
|
|
default=None,
|
|
help='SMILES strings to predict after training'
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Train model
|
|
try:
|
|
model, test_set, transformers = train_solubility_model(
|
|
data_path=args.data,
|
|
smiles_col=args.smiles_col,
|
|
target_col=args.target_col
|
|
)
|
|
except Exception as e:
|
|
print(f"\nError during training: {e}", file=sys.stderr)
|
|
return 1
|
|
|
|
# Make predictions on new molecules if provided
|
|
if args.predict:
|
|
try:
|
|
predict_new_molecules(model, args.predict, transformers)
|
|
except Exception as e:
|
|
print(f"\nError during prediction: {e}", file=sys.stderr)
|
|
return 1
|
|
else:
|
|
# Example predictions
|
|
example_smiles = [
|
|
'CCO', # Ethanol
|
|
'CC(=O)O', # Acetic acid
|
|
'c1ccccc1', # Benzene
|
|
'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', # Caffeine
|
|
]
|
|
predict_new_molecules(model, example_smiles, transformers)
|
|
|
|
print("\n" + "=" * 60)
|
|
print("Complete!")
|
|
print("=" * 60)
|
|
return 0
|
|
|
|
|
|
if __name__ == '__main__':
|
|
sys.exit(main())
|