Initial commit

commit f0bd18fb4e
Author: Zhongwei Li
Date:   2025-11-30 08:30:10 +08:00
824 changed files with 331919 additions and 0 deletions

analyze_results.py

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
DiffDock Results Analysis Script
This script analyzes DiffDock prediction results, extracting confidence scores,
ranking predictions, and generating summary reports.
Usage:
python analyze_results.py results/output_dir/
python analyze_results.py results/ --top 50 --threshold 0.0
python analyze_results.py results/ --export summary.csv
"""
import argparse
import os
import sys
import json
from pathlib import Path
from collections import defaultdict
import re
def parse_confidence_scores(results_dir):
"""
Parse confidence scores from DiffDock output directory.
Args:
results_dir: Path to DiffDock results directory
Returns:
dict: Dictionary mapping complex names to their predictions and scores
"""
results = {}
results_path = Path(results_dir)
# Check if this is a single complex or batch results
sdf_files = list(results_path.glob("*.sdf"))
if sdf_files:
# Single complex output
results['single_complex'] = parse_single_complex(results_path)
else:
# Batch output - multiple subdirectories
for subdir in results_path.iterdir():
if subdir.is_dir():
complex_results = parse_single_complex(subdir)
if complex_results:
results[subdir.name] = complex_results
return results
def parse_single_complex(complex_dir):
"""Parse results for a single complex."""
predictions = []
# Look for SDF files with rank information
for sdf_file in complex_dir.glob("*.sdf"):
filename = sdf_file.name
# Extract rank from filename (e.g., "rank_1.sdf" or "index_0_rank_1.sdf")
rank_match = re.search(r'rank_(\d+)', filename)
if rank_match:
rank = int(rank_match.group(1))
# Try to extract confidence score from filename or separate file
confidence = extract_confidence_score(sdf_file, complex_dir)
predictions.append({
'rank': rank,
'file': sdf_file.name,
'path': str(sdf_file),
'confidence': confidence
})
# Sort by rank
predictions.sort(key=lambda x: x['rank'])
return {'predictions': predictions} if predictions else None
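# confidence_scores.txt, when present, is assumed to hold one score per line in
# rank order (line 1 -> rank_1.sdf, line 2 -> rank_2.sdf, ...), for example:
#   0.87
#   -0.42
#   -1.93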
def extract_confidence_score(sdf_file, complex_dir):
"""
Extract confidence score for a prediction.
Tries multiple methods:
1. Read from confidence_scores.txt file
2. Parse from SDF file properties
3. Extract from filename if present
"""
# Method 1: confidence_scores.txt
confidence_file = complex_dir / "confidence_scores.txt"
if confidence_file.exists():
try:
with open(confidence_file) as f:
lines = f.readlines()
# Extract rank from filename
rank_match = re.search(r'rank_(\d+)', sdf_file.name)
if rank_match:
rank = int(rank_match.group(1))
if rank <= len(lines):
return float(lines[rank - 1].strip())
except Exception:
pass
# Method 2: Parse from SDF file
try:
with open(sdf_file) as f:
content = f.read()
# Look for confidence score in SDF properties
conf_match = re.search(r'confidence[:\s]+(-?\d+\.?\d*)', content, re.IGNORECASE)
if conf_match:
return float(conf_match.group(1))
except Exception:
pass
# Method 3: Filename (e.g., "rank_1_conf_0.95.sdf")
conf_match = re.search(r'conf_(-?\d+\.?\d*)', sdf_file.name)
if conf_match:
return float(conf_match.group(1))
return None
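# The buckets below follow the confidence guidance in the DiffDock README:
# scores above 0 are treated as high confidence, -1.5 to 0 as moderate, and
# below -1.5 as low. The scores are unbounded model outputs (they can be
# negative), not probabilities.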
def classify_confidence(score):
"""Classify confidence score into categories."""
if score is None:
return "Unknown"
elif score > 0:
return "High"
elif score > -1.5:
return "Moderate"
else:
return "Low"
def print_summary(results, top_n=None, min_confidence=None):
"""Print a formatted summary of results."""
print("\n" + "="*80)
print("DiffDock Results Summary")
print("="*80)
all_predictions = []
for complex_name, data in results.items():
predictions = data.get('predictions', [])
print(f"\n{complex_name}")
print("-" * 80)
if not predictions:
print(" No predictions found")
continue
# Filter by confidence if specified
filtered_predictions = predictions
if min_confidence is not None:
filtered_predictions = [p for p in predictions if p['confidence'] is not None and p['confidence'] >= min_confidence]
# Limit to top N if specified
if top_n is not None:
filtered_predictions = filtered_predictions[:top_n]
for pred in filtered_predictions:
confidence = pred['confidence']
confidence_class = classify_confidence(confidence)
conf_str = f"{confidence:>7.3f}" if confidence is not None else " N/A"
print(f" Rank {pred['rank']:2d}: Confidence = {conf_str} ({confidence_class:8s}) | {pred['file']}")
# Add to all predictions for overall statistics
if confidence is not None:
all_predictions.append((complex_name, pred['rank'], confidence))
# Show statistics for this complex
if filtered_predictions and any(p['confidence'] is not None for p in filtered_predictions):
confidences = [p['confidence'] for p in filtered_predictions if p['confidence'] is not None]
print(f"\n Statistics: {len(filtered_predictions)} predictions")
print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}")
print(f" Max confidence: {max(confidences):.3f}")
print(f" Min confidence: {min(confidences):.3f}")
# Overall statistics
if all_predictions:
print("\n" + "="*80)
print("Overall Statistics")
print("="*80)
confidences = [conf for _, _, conf in all_predictions]
print(f" Total predictions: {len(all_predictions)}")
print(f" Total complexes: {len(results)}")
print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}")
print(f" Max confidence: {max(confidences):.3f}")
print(f" Min confidence: {min(confidences):.3f}")
# Confidence distribution
high = sum(1 for c in confidences if c > 0)
moderate = sum(1 for c in confidences if -1.5 < c <= 0)
low = sum(1 for c in confidences if c <= -1.5)
print(f"\n Confidence distribution:")
print(f" High (> 0): {high:4d} ({100*high/len(confidences):5.1f}%)")
print(f" Moderate (-1.5 to 0): {moderate:4d} ({100*moderate/len(confidences):5.1f}%)")
print(f" Low (< -1.5): {low:4d} ({100*low/len(confidences):5.1f}%)")
print("\n" + "="*80)
def export_to_csv(results, output_path):
"""Export results to CSV file."""
import csv
with open(output_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['complex_name', 'rank', 'confidence', 'confidence_class', 'file_path'])
for complex_name, data in results.items():
predictions = data.get('predictions', [])
for pred in predictions:
confidence = pred['confidence']
confidence_class = classify_confidence(confidence)
conf_value = confidence if confidence is not None else ''
writer.writerow([
complex_name,
pred['rank'],
conf_value,
confidence_class,
pred['path']
])
print(f"✓ Exported results to: {output_path}")
def get_top_predictions(results, n=10, sort_by='confidence'):
"""Get top N predictions across all complexes."""
all_predictions = []
for complex_name, data in results.items():
predictions = data.get('predictions', [])
for pred in predictions:
if pred['confidence'] is not None:
all_predictions.append({
'complex': complex_name,
**pred
})
# Sort by confidence (descending)
all_predictions.sort(key=lambda x: x['confidence'], reverse=True)
return all_predictions[:n]
def print_top_predictions(results, n=10):
"""Print top N predictions across all complexes."""
top_preds = get_top_predictions(results, n)
print("\n" + "="*80)
print(f"Top {n} Predictions Across All Complexes")
print("="*80)
for i, pred in enumerate(top_preds, 1):
confidence_class = classify_confidence(pred['confidence'])
print(f"{i:2d}. {pred['complex']:30s} | Rank {pred['rank']:2d} | "
f"Confidence: {pred['confidence']:7.3f} ({confidence_class})")
print("="*80)
def main():
parser = argparse.ArgumentParser(
description='Analyze DiffDock prediction results',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Analyze all results in directory
python analyze_results.py results/output_dir/
# Show only top 5 predictions per complex
python analyze_results.py results/ --top 5
# Filter by confidence threshold
python analyze_results.py results/ --threshold 0.0
# Export to CSV
python analyze_results.py results/ --export summary.csv
# Show top 20 predictions across all complexes
python analyze_results.py results/ --best 20
"""
)
parser.add_argument('results_dir', help='Path to DiffDock results directory')
parser.add_argument('--top', '-t', type=int,
help='Show only top N predictions per complex')
parser.add_argument('--threshold', type=float,
help='Minimum confidence threshold')
parser.add_argument('--export', '-e', metavar='FILE',
help='Export results to CSV file')
parser.add_argument('--best', '-b', type=int, metavar='N',
help='Show top N predictions across all complexes')
args = parser.parse_args()
# Validate results directory
if not os.path.exists(args.results_dir):
print(f"Error: Results directory not found: {args.results_dir}")
return 1
# Parse results
print(f"Analyzing results in: {args.results_dir}")
results = parse_confidence_scores(args.results_dir)
if not results:
print("No DiffDock results found in directory")
return 1
# Print summary
print_summary(results, top_n=args.top, min_confidence=args.threshold)
# Print top predictions across all complexes
if args.best:
print_top_predictions(results, args.best)
# Export to CSV if requested
if args.export:
export_to_csv(results, args.export)
return 0
if __name__ == '__main__':
sys.exit(main())

prepare_batch_csv.py

@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
DiffDock Batch CSV Preparation and Validation Script
This script helps prepare and validate CSV files for DiffDock batch processing.
It checks for required columns, validates file paths, and ensures SMILES strings
are properly formatted.
Usage:
python prepare_batch_csv.py input.csv --validate
python prepare_batch_csv.py --create --output batch_input.csv
"""
import argparse
import os
import sys
import pandas as pd
from pathlib import Path
try:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
RDKIT_AVAILABLE = True
except ImportError:
RDKIT_AVAILABLE = False
print("Warning: RDKit not available. SMILES validation will be skipped.")
def validate_smiles(smiles_string):
"""Validate a SMILES string using RDKit."""
if not RDKIT_AVAILABLE:
return True, "RDKit not available for validation"
try:
mol = Chem.MolFromSmiles(smiles_string)
if mol is None:
return False, "Invalid SMILES structure"
return True, "Valid SMILES"
except Exception as e:
return False, str(e)
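# Example (assuming RDKit is installed):
#   validate_smiles("CC(=O)Oc1ccccc1C(=O)O")  -> (True, "Valid SMILES")
#   validate_smiles("not-a-smiles")           -> (False, "Invalid SMILES structure")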
def validate_file_path(file_path, base_dir=None):
"""Validate that a file path exists."""
if pd.isna(file_path) or file_path == "":
return True, "Empty (will use protein_sequence)"
# Handle relative paths
if base_dir:
full_path = Path(base_dir) / file_path
else:
full_path = Path(file_path)
if full_path.exists():
return True, f"File exists: {full_path}"
else:
return False, f"File not found: {full_path}"
def validate_csv(csv_path, base_dir=None):
"""
Validate a DiffDock batch input CSV file.
Args:
csv_path: Path to CSV file
base_dir: Base directory for relative paths (default: CSV directory)
Returns:
bool: True if validation passes
list: List of validation messages
"""
messages = []
valid = True
# Read CSV
try:
df = pd.read_csv(csv_path)
messages.append(f"✓ Successfully read CSV with {len(df)} rows")
except Exception as e:
messages.append(f"✗ Error reading CSV: {e}")
return False, messages
# Check required columns
required_cols = ['complex_name', 'protein_path', 'ligand_description', 'protein_sequence']
missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        messages.append(f"✗ Missing required columns: {', '.join(missing_cols)}")
        # Return early: the per-row checks below index these columns directly
        # and would raise KeyError if any were missing.
        return False, messages
    messages.append("✓ All required columns present")
# Set base directory
if base_dir is None:
base_dir = Path(csv_path).parent
# Validate each row
for idx, row in df.iterrows():
row_msgs = []
# Check complex name
if pd.isna(row['complex_name']) or row['complex_name'] == "":
row_msgs.append("Missing complex_name")
valid = False
# Check that either protein_path or protein_sequence is provided
has_protein_path = not pd.isna(row['protein_path']) and row['protein_path'] != ""
has_protein_seq = not pd.isna(row['protein_sequence']) and row['protein_sequence'] != ""
if not has_protein_path and not has_protein_seq:
row_msgs.append("Must provide either protein_path or protein_sequence")
valid = False
elif has_protein_path and has_protein_seq:
row_msgs.append("Warning: Both protein_path and protein_sequence provided, will use protein_path")
# Validate protein path if provided
if has_protein_path:
file_valid, msg = validate_file_path(row['protein_path'], base_dir)
if not file_valid:
row_msgs.append(f"Protein file issue: {msg}")
valid = False
# Validate ligand description
if pd.isna(row['ligand_description']) or row['ligand_description'] == "":
row_msgs.append("Missing ligand_description")
valid = False
else:
ligand_desc = row['ligand_description']
            # Decide between a file path and a SMILES string. A bare "/" test is
            # unreliable here because stereo SMILES (e.g. "C/C=C/C") also contain
            # slashes, so check for an existing file or a known ligand extension.
            if os.path.exists(ligand_desc) or str(ligand_desc).lower().endswith(('.sdf', '.mol2', '.mol', '.pdb')):
# Likely a file path
file_valid, msg = validate_file_path(ligand_desc, base_dir)
if not file_valid:
row_msgs.append(f"Ligand file issue: {msg}")
valid = False
else:
# Likely a SMILES string
smiles_valid, msg = validate_smiles(ligand_desc)
if not smiles_valid:
row_msgs.append(f"SMILES issue: {msg}")
valid = False
if row_msgs:
messages.append(f"\nRow {idx + 1} ({row.get('complex_name', 'unnamed')}):")
for msg in row_msgs:
messages.append(f" - {msg}")
# Summary
messages.append(f"\n{'='*60}")
if valid:
messages.append("✓ CSV validation PASSED - ready for DiffDock")
else:
messages.append("✗ CSV validation FAILED - please fix issues above")
return valid, messages
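# Example programmatic use (the file name is hypothetical):
#   ok, msgs = validate_csv("batch_input.csv")
#   print("\n".join(msgs))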
def create_template_csv(output_path, num_examples=3):
"""Create a template CSV file with example entries."""
examples = {
'complex_name': ['example1', 'example2', 'example3'][:num_examples],
'protein_path': ['protein1.pdb', '', 'protein3.pdb'][:num_examples],
'ligand_description': [
'CC(=O)Oc1ccccc1C(=O)O', # Aspirin SMILES
'COc1ccc(C#N)cc1', # Example SMILES
'ligand.sdf' # Example file path
][:num_examples],
'protein_sequence': [
'', # Empty - using PDB file
'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK', # GFP sequence
'' # Empty - using PDB file
][:num_examples]
}
df = pd.DataFrame(examples)
df.to_csv(output_path, index=False)
return df
def main():
parser = argparse.ArgumentParser(
description='Prepare and validate DiffDock batch CSV files',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Validate existing CSV
python prepare_batch_csv.py input.csv --validate
# Create template CSV
python prepare_batch_csv.py --create --output batch_template.csv
# Create template with 5 example rows
python prepare_batch_csv.py --create --output template.csv --num-examples 5
# Validate with custom base directory for relative paths
python prepare_batch_csv.py input.csv --validate --base-dir /path/to/data/
"""
)
parser.add_argument('csv_file', nargs='?', help='CSV file to validate')
parser.add_argument('--validate', action='store_true',
help='Validate the CSV file')
parser.add_argument('--create', action='store_true',
help='Create a template CSV file')
parser.add_argument('--output', '-o', help='Output path for template CSV')
parser.add_argument('--num-examples', type=int, default=3,
help='Number of example rows in template (default: 3)')
parser.add_argument('--base-dir', help='Base directory for relative file paths')
args = parser.parse_args()
# Create template
if args.create:
output_path = args.output or 'diffdock_batch_template.csv'
df = create_template_csv(output_path, args.num_examples)
print(f"✓ Created template CSV: {output_path}")
print(f"\nTemplate contents:")
print(df.to_string(index=False))
print(f"\nEdit this file with your protein-ligand pairs and run with:")
print(f" python -m inference --config default_inference_args.yaml \\")
print(f" --protein_ligand_csv {output_path} --out_dir results/")
return 0
# Validate CSV
if args.validate or args.csv_file:
if not args.csv_file:
print("Error: CSV file required for validation")
parser.print_help()
return 1
if not os.path.exists(args.csv_file):
print(f"Error: CSV file not found: {args.csv_file}")
return 1
print(f"Validating: {args.csv_file}")
print("="*60)
valid, messages = validate_csv(args.csv_file, args.base_dir)
for msg in messages:
print(msg)
return 0 if valid else 1
# No action specified
parser.print_help()
return 1
if __name__ == '__main__':
sys.exit(main())

setup_check.py

@@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""
DiffDock Environment Setup Checker
This script verifies that the DiffDock environment is properly configured
and all dependencies are available.
Usage:
python setup_check.py
python setup_check.py --verbose
"""
import argparse
import sys
import os
from pathlib import Path
def check_python_version():
"""Check Python version."""
import sys
version = sys.version_info
print("Checking Python version...")
if version.major == 3 and version.minor >= 8:
print(f" ✓ Python {version.major}.{version.minor}.{version.micro}")
return True
else:
print(f" ✗ Python {version.major}.{version.minor}.{version.micro} "
f"(requires Python 3.8 or higher)")
return False
def check_package(package_name, import_name=None, version_attr='__version__'):
"""Check if a Python package is installed."""
if import_name is None:
import_name = package_name
try:
module = __import__(import_name)
version = getattr(module, version_attr, 'unknown')
print(f"{package_name:20s} (version: {version})")
return True
except ImportError:
print(f"{package_name:20s} (not installed)")
return False
def check_pytorch():
"""Check PyTorch installation and CUDA availability."""
print("\nChecking PyTorch...")
try:
import torch
print(f" ✓ PyTorch version: {torch.__version__}")
# Check CUDA
if torch.cuda.is_available():
print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}")
print(f" - CUDA version: {torch.version.cuda}")
print(f" - Number of GPUs: {torch.cuda.device_count()}")
return True, True
else:
print(f" ⚠ CUDA not available (will run on CPU)")
return True, False
except ImportError:
print(f" ✗ PyTorch not installed")
return False, False
def check_pytorch_geometric():
"""Check PyTorch Geometric installation."""
print("\nChecking PyTorch Geometric...")
packages = [
('torch-geometric', 'torch_geometric'),
('torch-scatter', 'torch_scatter'),
('torch-sparse', 'torch_sparse'),
('torch-cluster', 'torch_cluster'),
]
all_ok = True
for pkg_name, import_name in packages:
if not check_package(pkg_name, import_name):
all_ok = False
return all_ok
def check_core_dependencies():
"""Check core DiffDock dependencies."""
print("\nChecking core dependencies...")
dependencies = [
('numpy', 'numpy'),
('scipy', 'scipy'),
('pandas', 'pandas'),
('rdkit', 'rdkit', 'rdBase.__version__'),
('biopython', 'Bio', '__version__'),
('pytorch-lightning', 'pytorch_lightning'),
('PyYAML', 'yaml'),
]
all_ok = True
for dep in dependencies:
pkg_name = dep[0]
import_name = dep[1]
version_attr = dep[2] if len(dep) > 2 else '__version__'
if not check_package(pkg_name, import_name, version_attr):
all_ok = False
return all_ok
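# Note: check_package() resolves version_attr with a single getattr call, so a
# dotted attribute such as 'rdBase.__version__' (used for rdkit above) falls
# back to 'unknown' rather than reporting a real version or failing the check.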
def check_esm():
"""Check ESM (protein language model) installation."""
print("\nChecking ESM (for protein sequence folding)...")
try:
import esm
print(f" ✓ ESM installed (version: {esm.__version__ if hasattr(esm, '__version__') else 'unknown'})")
return True
except ImportError:
print(f" ⚠ ESM not installed (needed for protein sequence folding)")
print(f" Install with: pip install fair-esm")
return False
def check_diffdock_installation():
"""Check if DiffDock is properly installed/cloned."""
print("\nChecking DiffDock installation...")
# Look for key files
key_files = [
'inference.py',
'default_inference_args.yaml',
'environment.yml',
]
found_files = []
missing_files = []
for filename in key_files:
if os.path.exists(filename):
found_files.append(filename)
else:
missing_files.append(filename)
if found_files:
print(f" ✓ Found DiffDock files in current directory:")
for f in found_files:
print(f" - {f}")
else:
print(f" ⚠ DiffDock files not found in current directory")
print(f" Current directory: {os.getcwd()}")
print(f" Make sure you're in the DiffDock repository root")
# Check for model checkpoints
model_dir = Path('./workdir/v1.1/score_model')
confidence_dir = Path('./workdir/v1.1/confidence_model')
if model_dir.exists() and confidence_dir.exists():
print(f" ✓ Model checkpoints found")
else:
print(f" ⚠ Model checkpoints not found in ./workdir/v1.1/")
print(f" Models will be downloaded on first run")
return len(found_files) > 0
def print_installation_instructions():
"""Print installation instructions if setup is incomplete."""
print("\n" + "="*80)
print("Installation Instructions")
print("="*80)
print("""
If DiffDock is not installed, follow these steps:
1. Clone the repository:
git clone https://github.com/gcorso/DiffDock.git
cd DiffDock
2. Create conda environment:
conda env create --file environment.yml
conda activate diffdock
3. Verify installation:
python setup_check.py
For Docker installation:
docker pull rbgcsail/diffdock
docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock
micromamba activate diffdock
For more information, visit: https://github.com/gcorso/DiffDock
""")
def print_performance_notes(has_cuda):
"""Print performance notes based on available hardware."""
print("\n" + "="*80)
print("Performance Notes")
print("="*80)
if has_cuda:
print("""
✓ GPU detected - DiffDock will run efficiently
Expected performance:
- First run: ~2-5 minutes (pre-computing SO(2)/SO(3) tables)
- Subsequent runs: ~10-60 seconds per complex (depending on settings)
- Batch processing: Highly efficient with GPU
""")
else:
print("""
⚠ No GPU detected - DiffDock will run on CPU
Expected performance:
- CPU inference is SIGNIFICANTLY slower than GPU
- Single complex: Several minutes to hours
- Batch processing: Not recommended on CPU
Recommendation: Use GPU for practical applications
- Cloud options: Google Colab, AWS, or other cloud GPU services
- Local: Install CUDA-capable GPU
""")
def main():
parser = argparse.ArgumentParser(
description='Check DiffDock environment setup',
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument('--verbose', '-v', action='store_true',
help='Show detailed version information')
args = parser.parse_args()
print("="*80)
print("DiffDock Environment Setup Checker")
print("="*80)
checks = []
# Run all checks
checks.append(("Python version", check_python_version()))
pytorch_ok, has_cuda = check_pytorch()
checks.append(("PyTorch", pytorch_ok))
checks.append(("PyTorch Geometric", check_pytorch_geometric()))
checks.append(("Core dependencies", check_core_dependencies()))
checks.append(("ESM", check_esm()))
checks.append(("DiffDock files", check_diffdock_installation()))
# Summary
print("\n" + "="*80)
print("Summary")
print("="*80)
all_passed = all(result for _, result in checks)
for check_name, result in checks:
status = "✓ PASS" if result else "✗ FAIL"
print(f" {status:8s} - {check_name}")
if all_passed:
print("\n✓ All checks passed! DiffDock is ready to use.")
print_performance_notes(has_cuda)
return 0
else:
print("\n✗ Some checks failed. Please install missing dependencies.")
print_installation_instructions()
return 1
if __name__ == '__main__':
sys.exit(main())