#!/usr/bin/env python3
"""
TDC Data Loading and Splitting Template
This script demonstrates how to load TDC datasets and apply different
splitting strategies for model training and evaluation.
Usage:
python load_and_split_data.py
"""
import numpy as np
import pandas as pd

from tdc import Evaluator
from tdc.multi_pred import DTI
from tdc.single_pred import ADME


def load_single_pred_example():
"""
Example: Loading and splitting a single-prediction dataset (ADME)
"""
print("=" * 60)
print("Example 1: Single-Prediction Task (ADME)")
print("=" * 60)
# Load Caco2 dataset (intestinal permeability)
print("\nLoading Caco2_Wang dataset...")
data = ADME(name='Caco2_Wang')
# Get basic dataset info
print(f"\nDataset size: {len(data.get_data())} molecules")
data.print_stats()
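    # Note: TDC's scaffold split groups molecules by their Bemis-Murcko scaffold,
    # so test-set compounds are structurally distinct from training compounds.
    # This gives a harder, more realistic estimate of generalization than a
    # random split.
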
    # Method 1: Scaffold split (default, recommended)
    print("\n--- Scaffold Split ---")
    split = data.get_split(method='scaffold', seed=42, frac=[0.7, 0.1, 0.2])
    train = split['train']
    valid = split['valid']
    test = split['test']
    print(f"Train: {len(train)} molecules")
    print(f"Valid: {len(valid)} molecules")
    print(f"Test: {len(test)} molecules")

    # Display sample data
    print("\nSample training data:")
    print(train.head(3))

    # Method 2: Random split
    print("\n--- Random Split ---")
    split_random = data.get_split(method='random', seed=42, frac=[0.8, 0.1, 0.1])
    print(f"Train: {len(split_random['train'])} molecules")
    print(f"Valid: {len(split_random['valid'])} molecules")
    print(f"Test: {len(split_random['test'])} molecules")
    return split


def load_multi_pred_example():
"""
Example: Loading and splitting a multi-prediction dataset (DTI)
"""
print("\n" + "=" * 60)
print("Example 2: Multi-Prediction Task (DTI)")
print("=" * 60)
# Load BindingDB Kd dataset (drug-target interactions)
print("\nLoading BindingDB_Kd dataset...")
data = DTI(name='BindingDB_Kd')
# Get basic dataset info
full_data = data.get_data()
print(f"\nDataset size: {len(full_data)} drug-target pairs")
print(f"Unique drugs: {full_data['Drug_ID'].nunique()}")
print(f"Unique targets: {full_data['Target_ID'].nunique()}")
# Method 1: Random split
print("\n--- Random Split ---")
split_random = data.get_split(method='random', seed=42)
print(f"Train: {len(split_random['train'])} pairs")
print(f"Valid: {len(split_random['valid'])} pairs")
print(f"Test: {len(split_random['test'])} pairs")
# Method 2: Cold drug split (unseen drugs in test)
print("\n--- Cold Drug Split ---")
split_cold_drug = data.get_split(method='cold_drug', seed=42)
train = split_cold_drug['train']
test = split_cold_drug['test']
# Verify no drug overlap
train_drugs = set(train['Drug_ID'])
test_drugs = set(test['Drug_ID'])
overlap = train_drugs & test_drugs
print(f"Train: {len(train)} pairs, {len(train_drugs)} unique drugs")
print(f"Test: {len(test)} pairs, {len(test_drugs)} unique drugs")
print(f"Drug overlap: {len(overlap)} (should be 0)")
# Method 3: Cold target split (unseen targets in test)
print("\n--- Cold Target Split ---")
split_cold_target = data.get_split(method='cold_target', seed=42)
train = split_cold_target['train']
test = split_cold_target['test']
train_targets = set(train['Target_ID'])
test_targets = set(test['Target_ID'])
overlap = train_targets & test_targets
print(f"Train: {len(train)} pairs, {len(train_targets)} unique targets")
print(f"Test: {len(test)} pairs, {len(test_targets)} unique targets")
print(f"Target overlap: {len(overlap)} (should be 0)")
    # Display sample data
    print("\nSample DTI data:")
    print(full_data.head(3))

    return split_cold_drug


def evaluation_example(split):
"""
Example: Evaluating model predictions with TDC evaluators
"""
print("\n" + "=" * 60)
print("Example 3: Model Evaluation")
print("=" * 60)
test = split['test']
# For demonstration, create dummy predictions
# In practice, replace with your model's predictions
import numpy as np
np.random.seed(42)
# Simulate predictions (replace with model.predict(test['Drug']))
y_true = test['Y'].values
y_pred = y_true + np.random.normal(0, 0.5, len(y_true)) # Add noise
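    # In a real workflow the two lines above would come from a trained model,
    # e.g. y_pred = model.predict(featurize(test['Drug'])), where `model` and
    # `featurize` are hypothetical placeholders not defined in this script.
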
    # Evaluate with different metrics
    print("\nEvaluating predictions...")

    # Regression metrics
    mae_evaluator = Evaluator(name='MAE')
    mae = mae_evaluator(y_true, y_pred)
    print(f"MAE: {mae:.4f}")

    rmse_evaluator = Evaluator(name='RMSE')
    rmse = rmse_evaluator(y_true, y_pred)
    print(f"RMSE: {rmse:.4f}")

    r2_evaluator = Evaluator(name='R2')
    r2 = r2_evaluator(y_true, y_pred)
    print(f"R²: {r2:.4f}")

    spearman_evaluator = Evaluator(name='Spearman')
    spearman = spearman_evaluator(y_true, y_pred)
    print(f"Spearman: {spearman:.4f}")
def custom_split_example():
"""
Example: Creating custom splits with different fractions
"""
print("\n" + "=" * 60)
print("Example 4: Custom Split Fractions")
print("=" * 60)
data = ADME(name='HIA_Hou')
# Custom split fractions
custom_fracs = [
([0.6, 0.2, 0.2], "60/20/20 split"),
([0.8, 0.1, 0.1], "80/10/10 split"),
([0.7, 0.15, 0.15], "70/15/15 split")
]
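    # Each frac is [train, valid, test]; the fractions are expected to sum to 1.0.
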
    for frac, description in custom_fracs:
        split = data.get_split(method='scaffold', seed=42, frac=frac)
        print(f"\n{description}:")
        print(f"  Train: {len(split['train'])} ({frac[0]*100:.0f}%)")
        print(f"  Valid: {len(split['valid'])} ({frac[1]*100:.0f}%)")
        print(f"  Test: {len(split['test'])} ({frac[2]*100:.0f}%)")


def main():
"""
Main function to run all examples
"""
print("\n" + "=" * 60)
print("TDC Data Loading and Splitting Examples")
print("=" * 60)
# Example 1: Single prediction with scaffold split
split = load_single_pred_example()
# Example 2: Multi prediction with cold splits
dti_split = load_multi_pred_example()
# Example 3: Model evaluation
evaluation_example(split)
# Example 4: Custom split fractions
custom_split_example()
print("\n" + "=" * 60)
print("Examples completed!")
print("=" * 60)
if __name__ == "__main__":
    main()