Files
gh-k-dense-ai-claude-scient…/skills/diffdock/scripts/prepare_batch_csv.py
2025-11-30 08:30:10 +08:00

255 lines
8.5 KiB
Python
Executable File

#!/usr/bin/env python3
"""
DiffDock Batch CSV Preparation and Validation Script
This script helps prepare and validate CSV files for DiffDock batch processing.
It checks for required columns, validates file paths, and ensures SMILES strings
are properly formatted.
Usage:
python prepare_batch_csv.py input.csv --validate
python prepare_batch_csv.py --create --output batch_input.csv
"""
import argparse
import os
import sys
import pandas as pd
from pathlib import Path
try:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
RDKIT_AVAILABLE = True
except ImportError:
RDKIT_AVAILABLE = False
print("Warning: RDKit not available. SMILES validation will be skipped.")
def validate_smiles(smiles_string):
    """Validate a SMILES string using RDKit.

    Args:
        smiles_string: Candidate SMILES text.

    Returns:
        tuple: ``(is_valid, message)``. ``is_valid`` is False for non-string
        or blank input and for strings RDKit cannot parse. When RDKit is not
        installed, any non-blank string is accepted and the message says that
        validation was skipped.
    """
    # Reject values that can never be a SMILES before consulting RDKit.
    # Without this guard a blank string passes when RDKit is missing, and
    # RDKit is understood to parse "" into an atom-less molecule rather than
    # returning None, so the None check below would not catch it either.
    if not isinstance(smiles_string, str):
        return False, f"SMILES must be a string, got {type(smiles_string).__name__}"
    if not smiles_string.strip():
        return False, "Empty SMILES string"

    if not RDKIT_AVAILABLE:
        return True, "RDKit not available for validation"

    try:
        mol = Chem.MolFromSmiles(smiles_string)
    except Exception as e:
        # Surface any parser failure as a validation message instead of
        # aborting the whole CSV check.
        return False, str(e)

    if mol is None:
        return False, "Invalid SMILES structure"
    return True, "Valid SMILES"
def validate_file_path(file_path, base_dir=None):
    """Validate that a path points to an existing file.

    Args:
        file_path: Path taken from a CSV cell. NaN or an empty string means
            "no file given" and is reported as valid.
        base_dir: Directory that ``file_path`` is joined onto when given;
            otherwise ``file_path`` is used as-is.

    Returns:
        tuple: ``(is_valid, message)``. ``is_valid`` is True for an empty
        value or an existing file, and False when the path is missing or
        exists but is not a file (for example a directory).
    """
    if pd.isna(file_path) or file_path == "":
        return True, "Empty (will use protein_sequence)"

    # A cell that pandas parsed as a number would make the Path join below
    # raise TypeError; treat it as the text the user typed instead.
    if not isinstance(file_path, (str, os.PathLike)):
        file_path = str(file_path)

    # Handle relative paths
    full_path = Path(base_dir) / file_path if base_dir else Path(file_path)

    # os.path.isfile (not Path.exists) so a directory is not reported as an
    # existing file, and OS-level lookup errors read as "not found".
    if os.path.isfile(full_path):
        return True, f"File exists: {full_path}"
    if os.path.exists(full_path):
        return False, f"Not a regular file: {full_path}"
    return False, f"File not found: {full_path}"
# File extensions that mark a ligand_description as a structure file rather
# than SMILES text.
# NOTE(review): presumed to match the ligand formats DiffDock can read;
# confirm against the DiffDock version in use.
_LIGAND_FILE_SUFFIXES = frozenset({'.sdf', '.mol2', '.mol', '.pdb', '.pdbqt'})


def _is_blank(value):
    """Return True when a CSV cell is NaN or the empty string."""
    return bool(pd.isna(value)) or value == ""


def _looks_like_ligand_file(ligand_desc, base_dir):
    """Return True when a ligand description should be checked as a file path."""
    # "/" and "\" are not used as a signal: both are legal SMILES bond
    # symbols (double-bond stereo), so they cannot distinguish paths.
    if Path(ligand_desc).suffix.lower() in _LIGAND_FILE_SUFFIXES:
        return True
    # os.path.isfile returns False instead of raising on odd or very long
    # strings, which arbitrary SMILES text can be.
    return os.path.isfile(os.path.join(base_dir, ligand_desc))


def validate_csv(csv_path, base_dir=None):
    """
    Validate a DiffDock batch input CSV file.

    Checks that the file can be read, that all required columns exist, and
    that every row has a complex name, a protein (path or sequence) and a
    ligand description that is either an existing file or a valid SMILES.

    Args:
        csv_path: Path to CSV file
        base_dir: Base directory for relative paths (default: CSV directory)

    Returns:
        bool: True if validation passes
        list: List of validation messages
    """
    messages = []
    valid = True

    # Read CSV
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        messages.append(f"✗ Error reading CSV: {e}")
        return False, messages
    messages.append(f"✓ Successfully read CSV with {len(df)} rows")

    # Check required columns
    required_cols = ['complex_name', 'protein_path', 'ligand_description', 'protein_sequence']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        messages.append(f"✗ Missing required columns: {', '.join(missing_cols)}")
        # The per-row checks index every required column, so running them
        # now would raise KeyError; report the failure and stop.
        messages.append(f"\n{'='*60}")
        messages.append("✗ CSV validation FAILED - please fix issues above")
        return False, messages
    messages.append("✓ All required columns present")

    # Set base directory
    if base_dir is None:
        base_dir = Path(csv_path).parent

    # Validate each row
    for row_number, (_, row) in enumerate(df.iterrows(), start=1):
        row_msgs = []

        # Check complex name
        if _is_blank(row['complex_name']):
            row_msgs.append("Missing complex_name")
            valid = False

        # Check that either protein_path or protein_sequence is provided
        has_protein_path = not _is_blank(row['protein_path'])
        has_protein_seq = not _is_blank(row['protein_sequence'])
        if not has_protein_path and not has_protein_seq:
            row_msgs.append("Must provide either protein_path or protein_sequence")
            valid = False
        elif has_protein_path and has_protein_seq:
            # Informational only: does not fail validation.
            row_msgs.append("Warning: Both protein_path and protein_sequence provided, will use protein_path")

        # Validate protein path if provided
        if has_protein_path:
            # str() guards against cells pandas parsed as numbers.
            file_valid, msg = validate_file_path(str(row['protein_path']), base_dir)
            if not file_valid:
                row_msgs.append(f"Protein file issue: {msg}")
                valid = False

        # Validate ligand description
        if _is_blank(row['ligand_description']):
            row_msgs.append("Missing ligand_description")
            valid = False
        else:
            ligand_desc = str(row['ligand_description'])
            if _looks_like_ligand_file(ligand_desc, base_dir):
                file_valid, msg = validate_file_path(ligand_desc, base_dir)
                if not file_valid:
                    row_msgs.append(f"Ligand file issue: {msg}")
                    valid = False
            else:
                smiles_valid, msg = validate_smiles(ligand_desc)
                if not smiles_valid:
                    row_msgs.append(f"SMILES issue: {msg}")
                    valid = False

        if row_msgs:
            messages.append(f"\nRow {row_number} ({row.get('complex_name', 'unnamed')}):")
            messages.extend(f" - {msg}" for msg in row_msgs)

    # Summary
    messages.append(f"\n{'='*60}")
    if valid:
        messages.append("✓ CSV validation PASSED - ready for DiffDock")
    else:
        messages.append("✗ CSV validation FAILED - please fix issues above")
    return valid, messages
def create_template_csv(output_path, num_examples=3):
    """Create a template CSV file with example entries.

    Three example patterns are cycled to produce exactly ``num_examples``
    rows: a PDB file with a SMILES ligand, a protein sequence with a SMILES
    ligand, and a PDB file with a ligand file. For ``num_examples`` of 1-3
    the rows are the same as the original fixed examples.

    Args:
        output_path: Where the CSV is written.
        num_examples: Number of rows to generate. Values below zero are
            treated as zero, which writes a header-only file.

    Returns:
        pandas.DataFrame: The frame that was written to ``output_path``.
    """
    gfp_sequence = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'

    # (protein_path pattern, ligand_description, protein_sequence)
    patterns = [
        ('protein{n}.pdb', 'CC(=O)Oc1ccccc1C(=O)O', ''),   # Aspirin SMILES, PDB file
        ('', 'COc1ccc(C#N)cc1', gfp_sequence),             # Example SMILES, GFP sequence
        ('protein{n}.pdb', 'ligand.sdf', ''),              # Example ligand file, PDB file
    ]
    columns = ['complex_name', 'protein_path', 'ligand_description', 'protein_sequence']

    # Slicing fixed 3-element lists capped the output at three rows and let
    # negative counts slice from the end; generate rows by index instead.
    rows = []
    for i in range(max(0, num_examples)):
        protein_pattern, ligand, sequence = patterns[i % len(patterns)]
        rows.append({
            'complex_name': f'example{i + 1}',
            'protein_path': protein_pattern.format(n=i + 1),
            'ligand_description': ligand,
            'protein_sequence': sequence,
        })

    # Passing columns= keeps the header even when there are no rows.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(output_path, index=False)
    return df
def main():
    """Run the command-line interface and return a process exit code.

    With ``--create`` a template CSV is written and 0 is returned. With a
    CSV file and/or ``--validate`` the file is validated and 0 is returned
    only if validation passes. In every other case (no action, missing
    argument, missing file) usage or an error is printed and 1 is returned.
    """
    usage_examples = (
        "\n"
        "Examples:\n"
        "# Validate existing CSV\n"
        "python prepare_batch_csv.py input.csv --validate\n"
        "# Create template CSV\n"
        "python prepare_batch_csv.py --create --output batch_template.csv\n"
        "# Create template with 5 example rows\n"
        "python prepare_batch_csv.py --create --output template.csv --num-examples 5\n"
        "# Validate with custom base directory for relative paths\n"
        "python prepare_batch_csv.py input.csv --validate --base-dir /path/to/data/\n"
    )
    cli = argparse.ArgumentParser(
        description='Prepare and validate DiffDock batch CSV files',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument('csv_file', nargs='?', help='CSV file to validate')
    cli.add_argument('--validate', action='store_true',
                     help='Validate the CSV file')
    cli.add_argument('--create', action='store_true',
                     help='Create a template CSV file')
    cli.add_argument('--output', '-o', help='Output path for template CSV')
    cli.add_argument('--num-examples', type=int, default=3,
                     help='Number of example rows in template (default: 3)')
    cli.add_argument('--base-dir', help='Base directory for relative file paths')
    opts = cli.parse_args()

    # Template creation takes precedence over validation.
    if opts.create:
        destination = opts.output if opts.output else 'diffdock_batch_template.csv'
        template = create_template_csv(destination, opts.num_examples)
        print(f"✓ Created template CSV: {destination}")
        print("\nTemplate contents:")
        print(template.to_string(index=False))
        print("\nEdit this file with your protein-ligand pairs and run with:")
        print(" python -m inference --config default_inference_args.yaml \\")
        print(f" --protein_ligand_csv {destination} --out_dir results/")
        return 0

    # Neither --create, --validate nor a CSV file: nothing to do.
    if not (opts.validate or opts.csv_file):
        cli.print_help()
        return 1

    # Validation was requested; bail out early on unusable input.
    if not opts.csv_file:
        print("Error: CSV file required for validation")
        cli.print_help()
        return 1
    if not os.path.exists(opts.csv_file):
        print(f"Error: CSV file not found: {opts.csv_file}")
        return 1

    print(f"Validating: {opts.csv_file}")
    print("=" * 60)
    passed, report = validate_csv(opts.csv_file, opts.base_dir)
    for line in report:
        print(line)
    if passed:
        return 0
    return 1


if __name__ == '__main__':
    sys.exit(main())