Initial commit
skills/diffdock/scripts/analyze_results.py (new executable file, +334 lines)
@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
DiffDock Results Analysis Script

This script analyzes DiffDock prediction results, extracting confidence scores,
ranking predictions, and generating summary reports.

Usage:
    python analyze_results.py results/output_dir/
    python analyze_results.py results/ --top 50 --threshold 0.0
    python analyze_results.py results/ --export summary.csv
"""
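
# Note: the parsing below assumes one of two output layouts, inferred from this
# script's own logic rather than from DiffDock documentation:
#   results_dir/*.sdf                 -> treated as a single complex
#   results_dir/<complex_name>/*.sdf  -> treated as batch output, one folder per complex
# Pose files are matched on a "rank_<N>" pattern in the filename; confidence scores
# are read from confidence_scores.txt, from SDF properties, or from the filename.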

import argparse
import os
import sys
from pathlib import Path
import re


def parse_confidence_scores(results_dir):
    """
    Parse confidence scores from DiffDock output directory.

    Args:
        results_dir: Path to DiffDock results directory

    Returns:
        dict: Dictionary mapping complex names to their predictions and scores
    """
    results = {}
    results_path = Path(results_dir)

    # Check if this is a single complex or batch results
    sdf_files = list(results_path.glob("*.sdf"))

    if sdf_files:
        # Single complex output
        results['single_complex'] = parse_single_complex(results_path)
    else:
        # Batch output - multiple subdirectories
        for subdir in results_path.iterdir():
            if subdir.is_dir():
                complex_results = parse_single_complex(subdir)
                if complex_results:
                    results[subdir.name] = complex_results

    return results


def parse_single_complex(complex_dir):
    """Parse results for a single complex."""
    predictions = []

    # Look for SDF files with rank information
    for sdf_file in complex_dir.glob("*.sdf"):
        filename = sdf_file.name

        # Extract rank from filename (e.g., "rank_1.sdf" or "index_0_rank_1.sdf")
        rank_match = re.search(r'rank_(\d+)', filename)
        if rank_match:
            rank = int(rank_match.group(1))

            # Try to extract confidence score from filename or separate file
            confidence = extract_confidence_score(sdf_file, complex_dir)

            predictions.append({
                'rank': rank,
                'file': sdf_file.name,
                'path': str(sdf_file),
                'confidence': confidence
            })

    # Sort by rank
    predictions.sort(key=lambda x: x['rank'])

    return {'predictions': predictions} if predictions else None
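
# Illustrative shape of the value returned above (values are made up):
#   {'predictions': [
#       {'rank': 1, 'file': 'rank_1.sdf', 'path': 'results/cmpx/rank_1.sdf', 'confidence': 0.35},
#       {'rank': 2, 'file': 'rank_2.sdf', 'path': 'results/cmpx/rank_2.sdf', 'confidence': -0.12},
#   ]}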


def extract_confidence_score(sdf_file, complex_dir):
    """
    Extract confidence score for a prediction.

    Tries multiple methods:
    1. Read from confidence_scores.txt file
    2. Parse from SDF file properties
    3. Extract from filename if present
    """
    # Method 1: confidence_scores.txt
    confidence_file = complex_dir / "confidence_scores.txt"
    if confidence_file.exists():
        try:
            with open(confidence_file) as f:
                lines = f.readlines()
                # Extract rank from filename
                rank_match = re.search(r'rank_(\d+)', sdf_file.name)
                if rank_match:
                    rank = int(rank_match.group(1))
                    if rank <= len(lines):
                        return float(lines[rank - 1].strip())
        except Exception:
            pass

    # Method 2: Parse from SDF file
    try:
        with open(sdf_file) as f:
            content = f.read()
            # Look for confidence score in SDF properties
            conf_match = re.search(r'confidence[:\s]+(-?\d+\.?\d*)', content, re.IGNORECASE)
            if conf_match:
                return float(conf_match.group(1))
    except Exception:
        pass

    # Method 3: Filename (e.g., "rank_1_conf_0.95.sdf")
    conf_match = re.search(r'conf_(-?\d+\.?\d*)', sdf_file.name)
    if conf_match:
        return float(conf_match.group(1))

    return None
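
# Assumption made explicit: Method 1 treats confidence_scores.txt as one score per
# line, with line N corresponding to the pose ranked N. Methods 2 and 3 are fallbacks
# used when that file is absent or does not cover the requested rank.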


def classify_confidence(score):
    """Classify confidence score into categories."""
    if score is None:
        return "Unknown"
    elif score > 0:
        return "High"
    elif score > -1.5:
        return "Moderate"
    else:
        return "Low"
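
# Worked examples for the thresholds above (illustrative values):
#   classify_confidence(0.4)  -> "High"      (score > 0)
#   classify_confidence(-0.7) -> "Moderate"  (-1.5 < score <= 0)
#   classify_confidence(-2.3) -> "Low"       (score <= -1.5)
#   classify_confidence(None) -> "Unknown"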


def print_summary(results, top_n=None, min_confidence=None):
    """Print a formatted summary of results."""

    print("\n" + "="*80)
    print("DiffDock Results Summary")
    print("="*80)

    all_predictions = []

    for complex_name, data in results.items():
        predictions = data.get('predictions', [])

        print(f"\n{complex_name}")
        print("-" * 80)

        if not predictions:
            print(" No predictions found")
            continue

        # Filter by confidence if specified
        filtered_predictions = predictions
        if min_confidence is not None:
            filtered_predictions = [p for p in predictions if p['confidence'] is not None and p['confidence'] >= min_confidence]

        # Limit to top N if specified
        if top_n is not None:
            filtered_predictions = filtered_predictions[:top_n]

        for pred in filtered_predictions:
            confidence = pred['confidence']
            confidence_class = classify_confidence(confidence)

            conf_str = f"{confidence:>7.3f}" if confidence is not None else " N/A"
            print(f" Rank {pred['rank']:2d}: Confidence = {conf_str} ({confidence_class:8s}) | {pred['file']}")

            # Add to all predictions for overall statistics
            if confidence is not None:
                all_predictions.append((complex_name, pred['rank'], confidence))

        # Show statistics for this complex
        if filtered_predictions and any(p['confidence'] is not None for p in filtered_predictions):
            confidences = [p['confidence'] for p in filtered_predictions if p['confidence'] is not None]
            print(f"\n Statistics: {len(filtered_predictions)} predictions")
            print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}")
            print(f" Max confidence: {max(confidences):.3f}")
            print(f" Min confidence: {min(confidences):.3f}")

    # Overall statistics
    if all_predictions:
        print("\n" + "="*80)
        print("Overall Statistics")
        print("="*80)

        confidences = [conf for _, _, conf in all_predictions]
        print(f" Total predictions: {len(all_predictions)}")
        print(f" Total complexes: {len(results)}")
        print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}")
        print(f" Max confidence: {max(confidences):.3f}")
        print(f" Min confidence: {min(confidences):.3f}")

        # Confidence distribution
        high = sum(1 for c in confidences if c > 0)
        moderate = sum(1 for c in confidences if -1.5 < c <= 0)
        low = sum(1 for c in confidences if c <= -1.5)

        print(f"\n Confidence distribution:")
        print(f" High (> 0): {high:4d} ({100*high/len(confidences):5.1f}%)")
        print(f" Moderate (-1.5 to 0): {moderate:4d} ({100*moderate/len(confidences):5.1f}%)")
        print(f" Low (< -1.5): {low:4d} ({100*low/len(confidences):5.1f}%)")

    print("\n" + "="*80)


def export_to_csv(results, output_path):
    """Export results to CSV file."""
    import csv

    with open(output_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['complex_name', 'rank', 'confidence', 'confidence_class', 'file_path'])

        for complex_name, data in results.items():
            predictions = data.get('predictions', [])
            for pred in predictions:
                confidence = pred['confidence']
                confidence_class = classify_confidence(confidence)
                conf_value = confidence if confidence is not None else ''

                writer.writerow([
                    complex_name,
                    pred['rank'],
                    conf_value,
                    confidence_class,
                    pred['path']
                ])

    print(f"✓ Exported results to: {output_path}")
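
# Example of the CSV produced above (illustrative values):
#   complex_name,rank,confidence,confidence_class,file_path
#   complex_A,1,0.35,High,results/complex_A/rank_1.sdf
#   complex_A,2,-0.12,Moderate,results/complex_A/rank_2.sdf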


def get_top_predictions(results, n=10, sort_by='confidence'):
    """Get top N predictions across all complexes."""
    all_predictions = []

    for complex_name, data in results.items():
        predictions = data.get('predictions', [])
        for pred in predictions:
            if pred['confidence'] is not None:
                all_predictions.append({
                    'complex': complex_name,
                    **pred
                })

    # Sort by confidence (descending)
    all_predictions.sort(key=lambda x: x['confidence'], reverse=True)

    return all_predictions[:n]
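
# Usage sketch (illustrative): the n highest-confidence poses across every complex,
# regardless of their per-complex rank; poses without a confidence score are skipped.
#   best = get_top_predictions(results, n=10)
#   best[0]['complex'], best[0]['rank'], best[0]['confidence']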


def print_top_predictions(results, n=10):
    """Print top N predictions across all complexes."""
    top_preds = get_top_predictions(results, n)

    print("\n" + "="*80)
    print(f"Top {n} Predictions Across All Complexes")
    print("="*80)

    for i, pred in enumerate(top_preds, 1):
        confidence_class = classify_confidence(pred['confidence'])
        print(f"{i:2d}. {pred['complex']:30s} | Rank {pred['rank']:2d} | "
              f"Confidence: {pred['confidence']:7.3f} ({confidence_class})")

    print("="*80)


def main():
    parser = argparse.ArgumentParser(
        description='Analyze DiffDock prediction results',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Analyze all results in directory
  python analyze_results.py results/output_dir/

  # Show only top 5 predictions per complex
  python analyze_results.py results/ --top 5

  # Filter by confidence threshold
  python analyze_results.py results/ --threshold 0.0

  # Export to CSV
  python analyze_results.py results/ --export summary.csv

  # Show top 20 predictions across all complexes
  python analyze_results.py results/ --best 20
        """
    )

    parser.add_argument('results_dir', help='Path to DiffDock results directory')
    parser.add_argument('--top', '-t', type=int,
                        help='Show only top N predictions per complex')
    parser.add_argument('--threshold', type=float,
                        help='Minimum confidence threshold')
    parser.add_argument('--export', '-e', metavar='FILE',
                        help='Export results to CSV file')
    parser.add_argument('--best', '-b', type=int, metavar='N',
                        help='Show top N predictions across all complexes')

    args = parser.parse_args()

    # Validate results directory
    if not os.path.exists(args.results_dir):
        print(f"Error: Results directory not found: {args.results_dir}")
        return 1

    # Parse results
    print(f"Analyzing results in: {args.results_dir}")
    results = parse_confidence_scores(args.results_dir)

    if not results:
        print("No DiffDock results found in directory")
        return 1

    # Print summary
    print_summary(results, top_n=args.top, min_confidence=args.threshold)

    # Print top predictions across all complexes
    if args.best:
        print_top_predictions(results, args.best)

    # Export to CSV if requested
    if args.export:
        export_to_csv(results, args.export)

    return 0


if __name__ == '__main__':
    sys.exit(main())
skills/diffdock/scripts/prepare_batch_csv.py (new executable file, +254 lines)
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
DiffDock Batch CSV Preparation and Validation Script

This script helps prepare and validate CSV files for DiffDock batch processing.
It checks for required columns, validates file paths, and ensures SMILES strings
are properly formatted.

Usage:
    python prepare_batch_csv.py input.csv --validate
    python prepare_batch_csv.py --create --output batch_input.csv
"""
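
# Expected input layout (inferred from validate_csv below; values are illustrative):
#   complex_name,protein_path,ligand_description,protein_sequence
#   my_complex,protein.pdb,CC(=O)Oc1ccccc1C(=O)O,
#   seq_only,,COc1ccc(C#N)cc1,MSKGEELFTG...
# Each row needs either protein_path or protein_sequence, and ligand_description
# may be a SMILES string or a path to a ligand file.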

import argparse
import os
import sys
import pandas as pd
from pathlib import Path

try:
    from rdkit import Chem
    from rdkit import RDLogger
    RDLogger.DisableLog('rdApp.*')
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False
    print("Warning: RDKit not available. SMILES validation will be skipped.")


def validate_smiles(smiles_string):
    """Validate a SMILES string using RDKit."""
    if not RDKIT_AVAILABLE:
        return True, "RDKit not available for validation"

    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None:
            return False, "Invalid SMILES structure"
        return True, "Valid SMILES"
    except Exception as e:
        return False, str(e)
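
# Usage sketch (assuming RDKit is installed; return values follow the code above):
#   validate_smiles('CC(=O)Oc1ccccc1C(=O)O')  # -> (True, "Valid SMILES")
#   validate_smiles('not-a-smiles((')         # -> (False, "Invalid SMILES structure")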


def validate_file_path(file_path, base_dir=None):
    """Validate that a file path exists."""
    if pd.isna(file_path) or file_path == "":
        return True, "Empty (will use protein_sequence)"

    # Handle relative paths
    if base_dir:
        full_path = Path(base_dir) / file_path
    else:
        full_path = Path(file_path)

    if full_path.exists():
        return True, f"File exists: {full_path}"
    else:
        return False, f"File not found: {full_path}"


def validate_csv(csv_path, base_dir=None):
    """
    Validate a DiffDock batch input CSV file.

    Args:
        csv_path: Path to CSV file
        base_dir: Base directory for relative paths (default: CSV directory)

    Returns:
        bool: True if validation passes
        list: List of validation messages
    """
    messages = []
    valid = True

    # Read CSV
    try:
        df = pd.read_csv(csv_path)
        messages.append(f"✓ Successfully read CSV with {len(df)} rows")
    except Exception as e:
        messages.append(f"✗ Error reading CSV: {e}")
        return False, messages

    # Check required columns
    required_cols = ['complex_name', 'protein_path', 'ligand_description', 'protein_sequence']
    missing_cols = [col for col in required_cols if col not in df.columns]

    if missing_cols:
        messages.append(f"✗ Missing required columns: {', '.join(missing_cols)}")
        # Row-level checks below rely on these columns, so stop here
        return False, messages
    messages.append("✓ All required columns present")

    # Set base directory
    if base_dir is None:
        base_dir = Path(csv_path).parent

    # Validate each row
    for idx, row in df.iterrows():
        row_msgs = []

        # Check complex name
        if pd.isna(row['complex_name']) or row['complex_name'] == "":
            row_msgs.append("Missing complex_name")
            valid = False

        # Check that either protein_path or protein_sequence is provided
        has_protein_path = not pd.isna(row['protein_path']) and row['protein_path'] != ""
        has_protein_seq = not pd.isna(row['protein_sequence']) and row['protein_sequence'] != ""

        if not has_protein_path and not has_protein_seq:
            row_msgs.append("Must provide either protein_path or protein_sequence")
            valid = False
        elif has_protein_path and has_protein_seq:
            row_msgs.append("Warning: Both protein_path and protein_sequence provided, will use protein_path")

        # Validate protein path if provided
        if has_protein_path:
            file_valid, msg = validate_file_path(row['protein_path'], base_dir)
            if not file_valid:
                row_msgs.append(f"Protein file issue: {msg}")
                valid = False

        # Validate ligand description
        if pd.isna(row['ligand_description']) or row['ligand_description'] == "":
            row_msgs.append("Missing ligand_description")
            valid = False
        else:
            ligand_desc = row['ligand_description']
            # Check if it's a file path or SMILES
            if os.path.exists(ligand_desc) or "/" in ligand_desc or "\\" in ligand_desc:
                # Likely a file path
                file_valid, msg = validate_file_path(ligand_desc, base_dir)
                if not file_valid:
                    row_msgs.append(f"Ligand file issue: {msg}")
                    valid = False
            else:
                # Likely a SMILES string
                smiles_valid, msg = validate_smiles(ligand_desc)
                if not smiles_valid:
                    row_msgs.append(f"SMILES issue: {msg}")
                    valid = False

        if row_msgs:
            messages.append(f"\nRow {idx + 1} ({row.get('complex_name', 'unnamed')}):")
            for msg in row_msgs:
                messages.append(f" - {msg}")

    # Summary
    messages.append(f"\n{'='*60}")
    if valid:
        messages.append("✓ CSV validation PASSED - ready for DiffDock")
    else:
        messages.append("✗ CSV validation FAILED - please fix issues above")

    return valid, messages


def create_template_csv(output_path, num_examples=3):
    """Create a template CSV file with example entries."""

    examples = {
        'complex_name': ['example1', 'example2', 'example3'][:num_examples],
        'protein_path': ['protein1.pdb', '', 'protein3.pdb'][:num_examples],
        'ligand_description': [
            'CC(=O)Oc1ccccc1C(=O)O',  # Aspirin SMILES
            'COc1ccc(C#N)cc1',  # Example SMILES
            'ligand.sdf'  # Example file path
        ][:num_examples],
        'protein_sequence': [
            '',  # Empty - using PDB file
            'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',  # GFP sequence
            ''  # Empty - using PDB file
        ][:num_examples]
    }

    df = pd.DataFrame(examples)
    df.to_csv(output_path, index=False)

    return df


def main():
    parser = argparse.ArgumentParser(
        description='Prepare and validate DiffDock batch CSV files',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Validate existing CSV
  python prepare_batch_csv.py input.csv --validate

  # Create template CSV
  python prepare_batch_csv.py --create --output batch_template.csv

  # Create template with 5 example rows
  python prepare_batch_csv.py --create --output template.csv --num-examples 5

  # Validate with custom base directory for relative paths
  python prepare_batch_csv.py input.csv --validate --base-dir /path/to/data/
        """
    )

    parser.add_argument('csv_file', nargs='?', help='CSV file to validate')
    parser.add_argument('--validate', action='store_true',
                        help='Validate the CSV file')
    parser.add_argument('--create', action='store_true',
                        help='Create a template CSV file')
    parser.add_argument('--output', '-o', help='Output path for template CSV')
    parser.add_argument('--num-examples', type=int, default=3,
                        help='Number of example rows in template (default: 3)')
    parser.add_argument('--base-dir', help='Base directory for relative file paths')

    args = parser.parse_args()

    # Create template
    if args.create:
        output_path = args.output or 'diffdock_batch_template.csv'
        df = create_template_csv(output_path, args.num_examples)
        print(f"✓ Created template CSV: {output_path}")
        print(f"\nTemplate contents:")
        print(df.to_string(index=False))
        print(f"\nEdit this file with your protein-ligand pairs and run with:")
        print(f" python -m inference --config default_inference_args.yaml \\")
        print(f" --protein_ligand_csv {output_path} --out_dir results/")
        return 0

    # Validate CSV
    if args.validate or args.csv_file:
        if not args.csv_file:
            print("Error: CSV file required for validation")
            parser.print_help()
            return 1

        if not os.path.exists(args.csv_file):
            print(f"Error: CSV file not found: {args.csv_file}")
            return 1

        print(f"Validating: {args.csv_file}")
        print("="*60)

        valid, messages = validate_csv(args.csv_file, args.base_dir)

        for msg in messages:
            print(msg)

        return 0 if valid else 1

    # No action specified
    parser.print_help()
    return 1


if __name__ == '__main__':
    sys.exit(main())
skills/diffdock/scripts/setup_check.py (new executable file, +278 lines)
@@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""
DiffDock Environment Setup Checker

This script verifies that the DiffDock environment is properly configured
and all dependencies are available.

Usage:
    python setup_check.py
    python setup_check.py --verbose
"""

import argparse
import sys
import os
from pathlib import Path


def check_python_version():
    """Check Python version."""
    version = sys.version_info

    print("Checking Python version...")
    if version.major == 3 and version.minor >= 8:
        print(f" ✓ Python {version.major}.{version.minor}.{version.micro}")
        return True
    else:
        print(f" ✗ Python {version.major}.{version.minor}.{version.micro} "
              f"(requires Python 3.8 or higher)")
        return False


def check_package(package_name, import_name=None, version_attr='__version__'):
    """Check if a Python package is installed."""
    if import_name is None:
        import_name = package_name

    try:
        module = __import__(import_name)
        # Resolve dotted attribute paths such as "rdBase.__version__" one level at a time
        version = module
        for attr in version_attr.split('.'):
            version = getattr(version, attr, None)
            if version is None:
                break
        if version is None:
            version = 'unknown'
        print(f" ✓ {package_name:20s} (version: {version})")
        return True
    except ImportError:
        print(f" ✗ {package_name:20s} (not installed)")
        return False
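
# Usage sketch (illustrative): the PyPI name and the import name can differ.
#   check_package('numpy')                                  # import name == package name
#   check_package('PyYAML', import_name='yaml')             # different import name
#   check_package('rdkit', 'rdkit', 'rdBase.__version__')   # dotted version attribute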


def check_pytorch():
    """Check PyTorch installation and CUDA availability."""
    print("\nChecking PyTorch...")
    try:
        import torch
        print(f" ✓ PyTorch version: {torch.__version__}")

        # Check CUDA
        if torch.cuda.is_available():
            print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}")
            print(f" - CUDA version: {torch.version.cuda}")
            print(f" - Number of GPUs: {torch.cuda.device_count()}")
            return True, True
        else:
            print(f" ⚠ CUDA not available (will run on CPU)")
            return True, False
    except ImportError:
        print(f" ✗ PyTorch not installed")
        return False, False
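
# The tuple returned above is (pytorch_installed, cuda_available); main() below uses
# the second element to pick which performance notes to print, e.g.:
#   pytorch_ok, has_cuda = check_pytorch()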


def check_pytorch_geometric():
    """Check PyTorch Geometric installation."""
    print("\nChecking PyTorch Geometric...")
    packages = [
        ('torch-geometric', 'torch_geometric'),
        ('torch-scatter', 'torch_scatter'),
        ('torch-sparse', 'torch_sparse'),
        ('torch-cluster', 'torch_cluster'),
    ]

    all_ok = True
    for pkg_name, import_name in packages:
        if not check_package(pkg_name, import_name):
            all_ok = False

    return all_ok


def check_core_dependencies():
    """Check core DiffDock dependencies."""
    print("\nChecking core dependencies...")

    dependencies = [
        ('numpy', 'numpy'),
        ('scipy', 'scipy'),
        ('pandas', 'pandas'),
        ('rdkit', 'rdkit', 'rdBase.__version__'),
        ('biopython', 'Bio', '__version__'),
        ('pytorch-lightning', 'pytorch_lightning'),
        ('PyYAML', 'yaml'),
    ]

    all_ok = True
    for dep in dependencies:
        pkg_name = dep[0]
        import_name = dep[1]
        version_attr = dep[2] if len(dep) > 2 else '__version__'

        if not check_package(pkg_name, import_name, version_attr):
            all_ok = False

    return all_ok


def check_esm():
    """Check ESM (protein language model) installation."""
    print("\nChecking ESM (for protein sequence folding)...")
    try:
        import esm
        print(f" ✓ ESM installed (version: {esm.__version__ if hasattr(esm, '__version__') else 'unknown'})")
        return True
    except ImportError:
        print(f" ⚠ ESM not installed (needed for protein sequence folding)")
        print(f" Install with: pip install fair-esm")
        return False


def check_diffdock_installation():
    """Check if DiffDock is properly installed/cloned."""
    print("\nChecking DiffDock installation...")

    # Look for key files
    key_files = [
        'inference.py',
        'default_inference_args.yaml',
        'environment.yml',
    ]

    found_files = []
    missing_files = []

    for filename in key_files:
        if os.path.exists(filename):
            found_files.append(filename)
        else:
            missing_files.append(filename)

    if found_files:
        print(f" ✓ Found DiffDock files in current directory:")
        for f in found_files:
            print(f" - {f}")
    else:
        print(f" ⚠ DiffDock files not found in current directory")
        print(f" Current directory: {os.getcwd()}")
        print(f" Make sure you're in the DiffDock repository root")

    # Check for model checkpoints
    model_dir = Path('./workdir/v1.1/score_model')
    confidence_dir = Path('./workdir/v1.1/confidence_model')

    if model_dir.exists() and confidence_dir.exists():
        print(f" ✓ Model checkpoints found")
    else:
        print(f" ⚠ Model checkpoints not found in ./workdir/v1.1/")
        print(f" Models will be downloaded on first run")

    return len(found_files) > 0


def print_installation_instructions():
    """Print installation instructions if setup is incomplete."""
    print("\n" + "="*80)
    print("Installation Instructions")
    print("="*80)

    print("""
If DiffDock is not installed, follow these steps:

1. Clone the repository:
   git clone https://github.com/gcorso/DiffDock.git
   cd DiffDock

2. Create conda environment:
   conda env create --file environment.yml
   conda activate diffdock

3. Verify installation:
   python setup_check.py

For Docker installation:
   docker pull rbgcsail/diffdock
   docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock
   micromamba activate diffdock

For more information, visit: https://github.com/gcorso/DiffDock
""")


def print_performance_notes(has_cuda):
    """Print performance notes based on available hardware."""
    print("\n" + "="*80)
    print("Performance Notes")
    print("="*80)

    if has_cuda:
        print("""
✓ GPU detected - DiffDock will run efficiently

Expected performance:
- First run: ~2-5 minutes (pre-computing SO(2)/SO(3) tables)
- Subsequent runs: ~10-60 seconds per complex (depending on settings)
- Batch processing: Highly efficient with GPU
""")
    else:
        print("""
⚠ No GPU detected - DiffDock will run on CPU

Expected performance:
- CPU inference is SIGNIFICANTLY slower than GPU
- Single complex: Several minutes to hours
- Batch processing: Not recommended on CPU

Recommendation: Use GPU for practical applications
- Cloud options: Google Colab, AWS, or other cloud GPU services
- Local: Install CUDA-capable GPU
""")


def main():
    parser = argparse.ArgumentParser(
        description='Check DiffDock environment setup',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Show detailed version information')

    args = parser.parse_args()

    print("="*80)
    print("DiffDock Environment Setup Checker")
    print("="*80)

    checks = []

    # Run all checks
    checks.append(("Python version", check_python_version()))

    pytorch_ok, has_cuda = check_pytorch()
    checks.append(("PyTorch", pytorch_ok))

    checks.append(("PyTorch Geometric", check_pytorch_geometric()))
    checks.append(("Core dependencies", check_core_dependencies()))
    checks.append(("ESM", check_esm()))
    checks.append(("DiffDock files", check_diffdock_installation()))

    # Summary
    print("\n" + "="*80)
    print("Summary")
    print("="*80)

    all_passed = all(result for _, result in checks)

    for check_name, result in checks:
        status = "✓ PASS" if result else "✗ FAIL"
        print(f" {status:8s} - {check_name}")

    if all_passed:
        print("\n✓ All checks passed! DiffDock is ready to use.")
        print_performance_notes(has_cuda)
        return 0
    else:
        print("\n✗ Some checks failed. Please install missing dependencies.")
        print_installation_instructions()
        return 1


if __name__ == '__main__':
    sys.exit(main())