#!/usr/bin/env python3
"""
Transfer Learning Script for DeepChem

Use pretrained models (ChemBERTa, GROVER, MolFormer) for molecular property
prediction with transfer learning. Particularly useful for small datasets.

Usage:
    python transfer_learning.py --model chemberta --data my_data.csv --target activity
    python transfer_learning.py --model grover --dataset bbbp
"""

import argparse
import sys

import deepchem as dc


PRETRAINED_MODELS = {
    'chemberta': {
        'name': 'ChemBERTa',
        'description': 'BERT pretrained on 77M molecules from ZINC15',
        'model_id': 'seyonec/ChemBERTa-zinc-base-v1'
    },
    'grover': {
        'name': 'GROVER',
        'description': 'Graph transformer pretrained on 10M molecules',
        'model_id': None  # GROVER uses its own loading mechanism
    },
    'molformer': {
        'name': 'MolFormer',
        'description': 'Transformer pretrained on molecular structures',
        'model_id': 'ibm/MoLFormer-XL-both-10pct'
    }
}


def train_chemberta(train_dataset, valid_dataset, test_dataset,
                    task_type='classification', n_tasks=1, n_epochs=10):
    """
    Fine-tune ChemBERTa on a dataset.

    Args:
        train_dataset: Training dataset
        valid_dataset: Validation dataset
        test_dataset: Test dataset
        task_type: 'classification' or 'regression'
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Trained model and evaluation results
    """
    print("=" * 70)
    print("Fine-tuning ChemBERTa")
    print("=" * 70)
    print("\nChemBERTa is a BERT model pretrained on 77M molecules from ZINC15.")
    print("It uses SMILES strings as input and has learned rich molecular")
    print("representations that transfer well to downstream tasks.")

    print("\nLoading pretrained ChemBERTa model...")
    # transformers is only needed on the ChemBERTa path, so import it here.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_id = PRETRAINED_MODELS['chemberta']['model_id']
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # dc.models.HuggingFaceModel wraps a transformers model plus its tokenizer;
    # it does not load from a model ID string by itself. The head size below is
    # an assumption: transformers expects num_labels=2 for single-task binary
    # classification; adjust for multi-task or regression setups.
    num_labels = 2 if (task_type == 'classification' and n_tasks == 1) else n_tasks
    hf_model = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=num_labels
    )
    model = dc.models.HuggingFaceModel(
        model=hf_model,
        tokenizer=tokenizer,
        task=task_type,
        batch_size=32,
        learning_rate=2e-5  # Lower LR for fine-tuning
    )

    print(f"\nFine-tuning for {n_epochs} epochs...")
    print("(This may take a while on the first run as the model is downloaded)")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)
    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R2'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    results = {}
    for name, dataset in [('Train', train_dataset), ('Valid', valid_dataset),
                          ('Test', test_dataset)]:
        print(f"\n{name} Set:")
        scores = model.evaluate(dataset, metrics)
        results[name] = scores
        for metric_name, score in scores.items():
            print(f"  {metric_name}: {score:.4f}")
    return model, results
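

# --- Hedged sketch: early stopping via a validation callback -----------------
# The tips printed at the end of this script mention early stopping. One way
# to get it with DeepChem is dc.models.ValidationCallback, which periodically
# evaluates on a validation set and can checkpoint the best-scoring model.
# This helper is a minimal sketch and is not wired into the training functions
# above; `interval` (in training steps) and `save_dir` are illustrative values.
def fit_with_early_stopping(model, train_dataset, valid_dataset, n_epochs=10):
    """Fit `model` while tracking validation ROC-AUC and keeping the best checkpoint."""
    metric = dc.metrics.Metric(dc.metrics.roc_auc_score)
    callback = dc.models.ValidationCallback(
        valid_dataset,
        interval=100,                  # evaluate every 100 training steps
        metrics=[metric],
        save_dir='./best_checkpoint',  # checkpoint the best model here
        save_on_minimum=False          # ROC-AUC: higher is better
    )
    model.fit(train_dataset, nb_epoch=n_epochs, callbacks=[callback])
    return model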


def train_grover(train_dataset, test_dataset, task_type='classification',
                 n_tasks=1, n_epochs=20):
    """
    Fine-tune GROVER on a dataset.

    Args:
        train_dataset: Training dataset
        test_dataset: Test dataset
        task_type: 'classification' or 'regression'
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Trained model and evaluation results
    """
    print("=" * 70)
    print("Fine-tuning GROVER")
    print("=" * 70)
    print("\nGROVER is a graph transformer pretrained on 10M molecules using")
    print("self-supervised learning. It learns both node- and graph-level")
    print("representations through masked atom/bond prediction tasks.")

    print("\nCreating GROVER model...")
    # NOTE: depending on your DeepChem version, GroverModel may require extra
    # arguments (e.g. feature dimensions and atom/bond vocabularies built from
    # the data); this minimal construction is a sketch.
    model = dc.models.GroverModel(
        task=task_type,
        n_tasks=n_tasks,
        model_dir='./grover_pretrained'
    )

    print(f"\nFine-tuning for {n_epochs} epochs...")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)
    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R2'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    results = {}
    for name, dataset in [('Train', train_dataset), ('Test', test_dataset)]:
        print(f"\n{name} Set:")
        scores = model.evaluate(dataset, metrics)
        results[name] = scores
        for metric_name, score in scores.items():
            print(f"  {metric_name}: {score:.4f}")
    return model, results
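

# --- Hedged sketch: persisting and reloading a fine-tuned model --------------
# Both wrappers above subclass DeepChem's TorchModel, which writes checkpoints
# under its model_dir. A minimal save/reload round trip looks like this; the
# helper name is illustrative, not part of the script's CLI.
def save_and_restore_example(model):
    """Checkpoint `model`, then reload the most recent checkpoint in place."""
    model.save_checkpoint()  # writes a checkpoint under model.model_dir
    model.restore()          # reloads the latest checkpoint from model_dir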


def load_molnet_dataset(dataset_name, model_type):
    """
    Load a MoleculeNet dataset with appropriate featurization.

    Args:
        dataset_name: Name of MoleculeNet dataset
        model_type: Type of pretrained model being used

    Returns:
        tasks, (train, valid, test) datasets, transformers
    """
    # Map of supported MoleculeNet loaders
    molnet_datasets = {
        'tox21': dc.molnet.load_tox21,
        'bbbp': dc.molnet.load_bbbp,
        'bace': dc.molnet.load_bace_classification,
        'hiv': dc.molnet.load_hiv,
        'delaney': dc.molnet.load_delaney,
        'freesolv': dc.molnet.load_freesolv,
        'lipo': dc.molnet.load_lipo
    }
    if dataset_name not in molnet_datasets:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # ChemBERTa and MolFormer tokenize raw SMILES themselves
    if model_type in ['chemberta', 'molformer']:
        featurizer = 'Raw'
    # GROVER needs graph features. 'GraphConv' is a stand-in here; recent
    # DeepChem versions also provide a dedicated dc.feat.GroverFeaturizer.
    elif model_type == 'grover':
        featurizer = 'GraphConv'
    else:
        featurizer = 'ECFP'

    print(f"\nLoading {dataset_name} dataset...")
    load_func = molnet_datasets[dataset_name]
    tasks, datasets, transformers = load_func(
        featurizer=featurizer,
        splitter='scaffold'
    )
    return tasks, datasets, transformers


def load_custom_dataset(data_path, target_cols, smiles_col, model_type):
    """
    Load a custom CSV dataset.

    Args:
        data_path: Path to CSV file
        target_cols: List of target column names
        smiles_col: Name of SMILES column
        model_type: Type of pretrained model being used

    Returns:
        train, valid, test datasets
    """
    print(f"\nLoading custom data from {data_path}...")

    # Choose featurizer based on model
    if model_type in ['chemberta', 'molformer']:
        featurizer = dc.feat.DummyFeaturizer()  # These models tokenize SMILES themselves
    elif model_type == 'grover':
        featurizer = dc.feat.MolGraphConvFeaturizer()
    else:
        featurizer = dc.feat.CircularFingerprint()

    loader = dc.data.CSVLoader(
        tasks=target_cols,
        feature_field=smiles_col,
        featurizer=featurizer
    )
    dataset = loader.create_dataset(data_path)
    print(f"Loaded {len(dataset)} molecules")

    # Split data
    print("Splitting data with scaffold splitter...")
    splitter = dc.splits.ScaffoldSplitter()
    train, valid, test = splitter.train_valid_test_split(
        dataset,
        frac_train=0.8,
        frac_valid=0.1,
        frac_test=0.1
    )
    print(f"  Training:   {len(train)}")
    print(f"  Validation: {len(valid)}")
    print(f"  Test:       {len(test)}")
    return train, valid, test
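

# --- Hedged sketch: predicting on new molecules -------------------------------
# After fine-tuning, predictions on fresh SMILES only need the same
# featurization the model was trained with. For the ChemBERTa path the
# featurizer is a no-op (the model tokenizes SMILES itself), so raw strings
# can be wrapped in a NumpyDataset directly. This helper is illustrative and
# not called by main().
def predict_smiles(model, smiles_list):
    """Return model predictions for a list of SMILES strings."""
    import numpy as np
    dataset = dc.data.NumpyDataset(X=np.array(smiles_list))
    return model.predict(dataset)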


def main():
    parser = argparse.ArgumentParser(
        description='Transfer learning for molecular property prediction'
    )
    parser.add_argument(
        '--model',
        type=str,
        choices=list(PRETRAINED_MODELS.keys()),
        required=True,
        help='Pretrained model to use'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        choices=['tox21', 'bbbp', 'bace', 'hiv', 'delaney', 'freesolv', 'lipo'],
        default=None,
        help='MoleculeNet dataset to use'
    )
    parser.add_argument(
        '--data',
        type=str,
        default=None,
        help='Path to custom CSV file'
    )
    parser.add_argument(
        '--target',
        nargs='+',
        default=['target'],
        help='Target column name(s) for custom data'
    )
    parser.add_argument(
        '--smiles-col',
        type=str,
        default='smiles',
        help='SMILES column name for custom data'
    )
    parser.add_argument(
        '--task-type',
        type=str,
        choices=['classification', 'regression'],
        default='classification',
        help='Type of prediction task'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=10,
        help='Number of fine-tuning epochs'
    )
    args = parser.parse_args()

    # Validate arguments
    if args.dataset is None and args.data is None:
        print("Error: Must specify either --dataset or --data", file=sys.stderr)
        return 1
    if args.dataset and args.data:
        print("Error: Cannot specify both --dataset and --data", file=sys.stderr)
        return 1

    # Print model info
    model_info = PRETRAINED_MODELS[args.model]
    print("\n" + "=" * 70)
    print(f"Transfer Learning with {model_info['name']}")
    print("=" * 70)
    print(f"\n{model_info['description']}")

    try:
        # Load dataset
        if args.dataset:
            tasks, datasets, transformers = load_molnet_dataset(args.dataset, args.model)
            train, valid, test = datasets
            task_type = ('classification'
                         if args.dataset in ['tox21', 'bbbp', 'bace', 'hiv']
                         else 'regression')
            n_tasks = len(tasks)
        else:
            train, valid, test = load_custom_dataset(
                args.data,
                args.target,
                args.smiles_col,
                args.model
            )
            task_type = args.task_type
            n_tasks = len(args.target)

        # Train model
        if args.model == 'chemberta':
            model, results = train_chemberta(
                train, valid, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )
        elif args.model == 'grover':
            model, results = train_grover(
                train, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )
        else:
            print(f"Error: Model {args.model} not yet implemented", file=sys.stderr)
            return 1

        print("\n" + "=" * 70)
        print("Transfer Learning Complete!")
        print("=" * 70)
        print("\nTip: Pretrained models often work best with:")
        print("  - Small datasets (< 1000 samples)")
        print("  - Lower learning rates (1e-5 to 5e-5)")
        print("  - Fewer epochs (5-20)")
        print("  - Early stopping to avoid overfitting")
        print("    (see the fit_with_early_stopping sketch above)")
        return 0
    except Exception as e:
        print(f"\nError: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    sys.exit(main())