Files
gh-k-dense-ai-claude-scient…/skills/deepchem/scripts/predict_solubility.py
2025-11-30 08:30:10 +08:00

225 lines
6.4 KiB
Python

#!/usr/bin/env python3
"""
Molecular Solubility Prediction Script
This script trains a model to predict aqueous solubility from SMILES strings
using the Delaney (ESOL) dataset as an example. Can be adapted for custom datasets.
Usage:
python predict_solubility.py --data custom_data.csv --smiles-col smiles --target-col solubility
python predict_solubility.py # Uses Delaney dataset by default
"""
import argparse
import deepchem as dc
import numpy as np
import sys
def train_solubility_model(data_path=None, smiles_col='smiles', target_col='measured log solubility in mols per litre'):
    """
    Train a solubility prediction model on circular (ECFP-style) fingerprints.

    Args:
        data_path: Path to CSV file with SMILES and solubility data. If None,
            the Delaney (ESOL) benchmark dataset is used.
        smiles_col: Name of column containing SMILES strings (only used when
            data_path is given).
        target_col: Name of column containing solubility values (only used
            when data_path is given).

    Returns:
        Tuple of (trained model, test dataset, transformers). The test
        dataset is returned in its transformed form; the transformers are
        the ones applied to the training data.
    """
    print("=" * 60)
    print("DeepChem Solubility Prediction")
    print("=" * 60)

    # Use one explicit fingerprint definition for both data paths so the
    # feature width always matches the network's input layer. MoleculeNet's
    # 'ECFP' string shortcut is documented to build 1024-bit fingerprints,
    # which would not match a 2048-input model.
    fingerprint_size = 2048
    featurizer = dc.feat.CircularFingerprint(radius=2, size=fingerprint_size)

    # Load data
    if data_path is None:
        print("\nUsing Delaney (ESOL) benchmark dataset...")
        tasks, datasets, transformers = dc.molnet.load_delaney(
            featurizer=featurizer,
            splitter='scaffold'
        )
        train, valid, test = datasets
    else:
        print(f"\nLoading custom data from {data_path}...")
        loader = dc.data.CSVLoader(
            tasks=[target_col],
            feature_field=smiles_col,
            featurizer=featurizer
        )
        dataset = loader.create_dataset(data_path)

        # Split data
        print("Splitting data with scaffold splitter...")
        splitter = dc.splits.ScaffoldSplitter()
        train, valid, test = splitter.train_valid_test_split(
            dataset,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1
        )

        # Normalize data: only the targets are transformed (transform_y=True),
        # with statistics taken from the training split alone.
        print("Normalizing targets...")
        transformers = [
            dc.trans.NormalizationTransformer(
                transform_y=True,
                dataset=train
            )
        ]
        for transformer in transformers:
            train = transformer.transform(train)
            valid = transformer.transform(valid)
            test = transformer.transform(test)
        tasks = [target_col]

    print("\nDataset sizes:")
    print(f" Training: {len(train)} molecules")
    print(f" Validation: {len(valid)} molecules")
    print(f" Test: {len(test)} molecules")

    # Create model
    print("\nCreating multitask regressor...")
    model = dc.models.MultitaskRegressor(
        n_tasks=len(tasks),
        n_features=fingerprint_size,
        layer_sizes=[1000, 500],
        dropouts=0.25,
        learning_rate=0.001,
        batch_size=50
    )

    # Train model
    print("\nTraining model...")
    model.fit(train, nb_epoch=50)
    print("Training complete!")

    # Evaluate model
    print("\n" + "=" * 60)
    print("Model Evaluation")
    print("=" * 60)
    metrics = [
        dc.metrics.Metric(dc.metrics.r2_score, name='R2'),
        dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        # rms_score is DeepChem's root-mean-squared-error helper.
        dc.metrics.Metric(dc.metrics.rms_score, name='RMSE'),
    ]
    for dataset_name, dataset in [('Train', train), ('Valid', valid), ('Test', test)]:
        print(f"\n{dataset_name} Set:")
        # Passing the transformers lets DeepChem undo the target
        # normalization, so MAE/RMSE are reported in the original units
        # rather than in normalized units.
        scores = model.evaluate(dataset, metrics, transformers)
        for metric_name, score in scores.items():
            print(f" {metric_name}: {score:.4f}")

    return model, test, transformers
def predict_new_molecules(model, smiles_list, transformers=None):
    """
    Predict solubility for new molecules.

    Args:
        model: Trained DeepChem model
        smiles_list: List of SMILES strings
        transformers: List of data transformers that were applied to the
            training data. Feature transformers are applied to the new
            inputs; all of them are handed to the model so that target
            transforms are undone on the predictions.

    Returns:
        Array of predictions as returned by ``model.predict`` after undoing
        the given transformers.

    Raises:
        ValueError: If one or more SMILES strings could not be featurized.
    """
    transformers = list(transformers) if transformers else []

    print("\n" + "=" * 60)
    print("Predicting New Molecules")
    print("=" * 60)

    # Featurize new molecules
    featurizer = dc.feat.CircularFingerprint(radius=2, size=2048)
    features = featurizer.featurize(smiles_list)

    # NOTE(review): DeepChem is expected to hand back an empty feature row
    # for a SMILES it cannot featurize -- confirm for the installed version.
    # Report such inputs by name instead of failing later inside the model.
    failed = [smi for smi, row in zip(smiles_list, features) if np.size(row) == 0]
    if failed:
        raise ValueError(f"Could not featurize SMILES: {', '.join(failed)}")

    # Create dataset
    new_dataset = dc.data.NumpyDataset(X=features)

    # Only feature transformers make sense on unlabeled inputs; target
    # transformers would merely rescale the dataset's placeholder labels.
    for transformer in transformers:
        if getattr(transformer, 'transform_X', False):
            new_dataset = transformer.transform(new_dataset)

    # Predict. Handing the transformers to predict() undoes the target
    # normalization, so the values are on the original solubility scale
    # that the printed unit claims.
    predictions = model.predict(new_dataset, transformers)

    # Display results
    print("\nPredictions:")
    for smiles, pred in zip(smiles_list, predictions):
        # Flatten the row so the first task's value formats as a scalar
        # regardless of any trailing singleton dimensions.
        value = float(np.ravel(pred)[0])
        print(f" {smiles:30s} -> {value:.3f} log(mol/L)")

    return predictions
def main():
    """
    Command-line entry point: train a model, then predict for some molecules.

    Parses ``--data``, ``--smiles-col``, ``--target-col`` and ``--predict``,
    trains via ``train_solubility_model`` and runs ``predict_new_molecules``
    on either the user-supplied SMILES or a small built-in example set.

    Returns:
        Process exit code: 0 on success, 1 if training or prediction raised.
    """
    parser = argparse.ArgumentParser(
        description='Train a molecular solubility prediction model'
    )
    parser.add_argument(
        '--data',
        type=str,
        default=None,
        help='Path to CSV file with molecular data'
    )
    parser.add_argument(
        '--smiles-col',
        type=str,
        default='smiles',
        help='Name of column containing SMILES strings'
    )
    parser.add_argument(
        '--target-col',
        type=str,
        default='solubility',
        help='Name of column containing target values'
    )
    parser.add_argument(
        '--predict',
        nargs='+',
        default=None,
        help='SMILES strings to predict after training'
    )
    args = parser.parse_args()

    # Train model. This is the top-level boundary of the script, so any
    # failure is reported on stderr and converted into a non-zero exit code.
    try:
        model, _test_set, transformers = train_solubility_model(
            data_path=args.data,
            smiles_col=args.smiles_col,
            target_col=args.target_col
        )
    except Exception as e:
        print(f"\nError during training: {e}", file=sys.stderr)
        return 1

    # Fall back to a few well-known molecules when none were requested.
    smiles_to_predict = args.predict or [
        'CCO',  # Ethanol
        'CC(=O)O',  # Acetic acid
        'c1ccccc1',  # Benzene
        'CN1C=NC2=C1C(=O)N(C(=O)N2C)C',  # Caffeine
    ]

    # Both the user-supplied and the example path share the same error
    # handling, so a prediction failure never escapes as a raw traceback.
    try:
        predict_new_molecules(model, smiles_to_predict, transformers)
    except Exception as e:
        print(f"\nError during prediction: {e}", file=sys.stderr)
        return 1

    print("\n" + "=" * 60)
    print("Complete!")
    print("=" * 60)
    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())