279 lines
7.8 KiB
Python
Executable File
279 lines
7.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
DiffDock Environment Setup Checker

This script verifies that the DiffDock environment is properly configured
and all dependencies are available.

Usage:
    python setup_check.py
    python setup_check.py --verbose
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
import os
|
|
from pathlib import Path
|
|
|
|
|
|
def check_python_version():
    """Report the running interpreter version and whether it is supported.

    Prints a one-line verdict and returns ``True`` when the interpreter is
    Python 3 with a minor version of at least 8, ``False`` otherwise.
    """
    print("Checking Python version...")

    info = sys.version_info
    label = f"Python {info.major}.{info.minor}.{info.micro}"
    supported = info.major == 3 and info.minor >= 8

    if not supported:
        print(f" ✗ {label} (requires Python 3.8 or higher)")
        return False

    print(f" ✓ {label}")
    return True
|
|
|
|
|
|
def check_package(package_name, import_name=None, version_attr='__version__'):
    """Check whether a package can be imported and report its version.

    Args:
        package_name: Display name printed in the report line.
        import_name: Module name to import; defaults to ``package_name``.
        version_attr: Attribute that holds the version. May be a dotted path
            such as ``'rdBase.__version__'``; it is resolved one component
            at a time starting from the imported module.

    Returns:
        True if the import succeeded, False if it raised ImportError. A
        version that cannot be resolved is printed as ``unknown`` and does
        not fail the check.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    if import_name is None:
        import_name = package_name

    try:
        module = importlib.import_module(import_name)
    except ImportError:
        print(f" ✗ {package_name:20s} (not installed)")
        return False

    # A single getattr() with a dotted name never matches anything, so the
    # path is walked component by component instead.
    version = module
    for part in version_attr.split('.'):
        if part and not hasattr(version, part) and hasattr(version, '__path__'):
            # The component may be a submodule that the package does not
            # import eagerly; importing it binds it as an attribute.
            try:
                importlib.import_module(f"{version.__name__}.{part}")
            except ImportError:
                pass
        try:
            version = getattr(version, part)
        except AttributeError:
            version = 'unknown'
            break

    print(f" ✓ {package_name:20s} (version: {version})")
    return True
|
|
|
|
|
|
def check_pytorch():
    """Report whether PyTorch is importable and whether CUDA is usable.

    Returns:
        A ``(pytorch_installed, cuda_available)`` tuple of booleans.
    """
    print("\nChecking PyTorch...")

    try:
        import torch

        print(f" ✓ PyTorch version: {torch.__version__}")

        # CPU-only installs are still a pass; only the second flag differs.
        if not torch.cuda.is_available():
            print(" ⚠ CUDA not available (will run on CPU)")
            return True, False

        device_name = torch.cuda.get_device_name(0)
        print(f" ✓ CUDA available: {device_name}")
        print(f" - CUDA version: {torch.version.cuda}")
        print(f" - Number of GPUs: {torch.cuda.device_count()}")
        return True, True
    except ImportError:
        print(" ✗ PyTorch not installed")
        return False, False
|
|
|
|
|
|
def check_pytorch_geometric():
    """Check PyTorch Geometric and its companion extension packages.

    Returns:
        True only if every listed package could be imported.
    """
    print("\nChecking PyTorch Geometric...")

    # Display name -> importable module name, in report order.
    module_by_package = {
        'torch-geometric': 'torch_geometric',
        'torch-scatter': 'torch_scatter',
        'torch-sparse': 'torch_sparse',
        'torch-cluster': 'torch_cluster',
    }

    # Materialise the list first so every package is reported even after
    # an earlier one has failed.
    outcomes = [
        check_package(display_name, module_name)
        for display_name, module_name in module_by_package.items()
    ]
    return all(outcomes)
|
|
|
|
|
|
def check_core_dependencies():
    """Check the core third-party packages DiffDock depends on.

    Each entry is ``(display name, import name[, version attribute])``; the
    version attribute falls back to ``'__version__'`` when omitted.

    Returns:
        True only if every listed package could be imported.
    """
    print("\nChecking core dependencies...")

    specs = [
        ('numpy', 'numpy'),
        ('scipy', 'scipy'),
        ('pandas', 'pandas'),
        ('rdkit', 'rdkit', 'rdBase.__version__'),
        ('biopython', 'Bio', '__version__'),
        ('pytorch-lightning', 'pytorch_lightning'),
        ('PyYAML', 'yaml'),
    ]

    failures = 0
    for display_name, module_name, *extra in specs:
        attr = extra[0] if extra else '__version__'
        # Keep iterating after a failure so the report covers every package.
        if not check_package(display_name, module_name, attr):
            failures += 1

    return failures == 0
|
|
|
|
|
|
def check_esm():
    """Report whether the ESM protein language model package is importable.

    Returns:
        True if ``esm`` imports, False otherwise (an install hint is printed).
    """
    print("\nChecking ESM (for protein sequence folding)...")

    try:
        import esm

        esm_version = getattr(esm, '__version__', 'unknown')
        print(f" ✓ ESM installed (version: {esm_version})")
        return True
    except ImportError:
        print(" ⚠ ESM not installed (needed for protein sequence folding)")
        print(" Install with: pip install fair-esm")
        return False
|
|
|
|
|
|
def check_diffdock_installation():
    """Check whether the current directory looks like a DiffDock checkout.

    Looks for a few key repository files relative to the current working
    directory and for the v1.1 model checkpoint directories. Missing
    checkpoints only produce a warning.

    Returns:
        True if at least one of the key files exists, False otherwise.
    """
    print("\nChecking DiffDock installation...")

    # Look for key files
    key_files = [
        'inference.py',
        'default_inference_args.yaml',
        'environment.yml',
    ]

    found_files = [name for name in key_files if os.path.exists(name)]
    missing_files = [name for name in key_files if name not in found_files]

    if found_files:
        print(" ✓ Found DiffDock files in current directory:")
        for name in found_files:
            print(f" - {name}")
        # Surface a partial checkout rather than silently discarding the
        # list of absent files.
        if missing_files:
            print(" ⚠ Missing expected files:")
            for name in missing_files:
                print(f" - {name}")
    else:
        print(" ⚠ DiffDock files not found in current directory")
        print(f" Current directory: {os.getcwd()}")
        print(" Make sure you're in the DiffDock repository root")

    # Check for model checkpoints
    model_dir = Path('./workdir/v1.1/score_model')
    confidence_dir = Path('./workdir/v1.1/confidence_model')

    if model_dir.exists() and confidence_dir.exists():
        print(" ✓ Model checkpoints found")
    else:
        print(" ⚠ Model checkpoints not found in ./workdir/v1.1/")
        print(" Models will be downloaded on first run")

    return bool(found_files)
|
|
|
|
|
|
def print_installation_instructions():
    """Print a banner followed by step-by-step DiffDock install guidance."""
    rule = "=" * 80
    print(f"\n{rule}\nInstallation Instructions\n{rule}")

    # Leading and trailing empty entries reproduce the blank line that
    # opens and closes the instructions block.
    steps = (
        "",
        "If DiffDock is not installed, follow these steps:",
        "",
        "1. Clone the repository:",
        "git clone https://github.com/gcorso/DiffDock.git",
        "cd DiffDock",
        "",
        "2. Create conda environment:",
        "conda env create --file environment.yml",
        "conda activate diffdock",
        "",
        "3. Verify installation:",
        "python setup_check.py",
        "",
        "For Docker installation:",
        "docker pull rbgcsail/diffdock",
        "docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock",
        "micromamba activate diffdock",
        "",
        "For more information, visit: https://github.com/gcorso/DiffDock",
        "",
    )
    print("\n".join(steps))
|
|
|
|
|
|
def print_performance_notes(has_cuda):
    """Print a banner and performance expectations for the available hardware.

    Args:
        has_cuda: Truthy selects the GPU notes, falsy selects the CPU notes.
    """
    rule = "=" * 80
    print(f"\n{rule}\nPerformance Notes\n{rule}")

    gpu_notes = (
        "",
        "✓ GPU detected - DiffDock will run efficiently",
        "",
        "Expected performance:",
        "- First run: ~2-5 minutes (pre-computing SO(2)/SO(3) tables)",
        "- Subsequent runs: ~10-60 seconds per complex (depending on settings)",
        "- Batch processing: Highly efficient with GPU",
        "",
    )
    cpu_notes = (
        "",
        "⚠ No GPU detected - DiffDock will run on CPU",
        "",
        "Expected performance:",
        "- CPU inference is SIGNIFICANTLY slower than GPU",
        "- Single complex: Several minutes to hours",
        "- Batch processing: Not recommended on CPU",
        "",
        "Recommendation: Use GPU for practical applications",
        "- Cloud options: Google Colab, AWS, or other cloud GPU services",
        "- Local: Install CUDA-capable GPU",
        "",
    )

    print("\n".join(gpu_notes if has_cuda else cpu_notes))
|
|
|
|
|
|
def main():
    """Run every environment check, print a summary, and return an exit code.

    Command-line options:
        --verbose / -v: additionally print interpreter details (executable
            path, full version string, platform) before the checks run.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description='Check DiffDock environment setup',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Show detailed version information')

    args = parser.parse_args()

    print("="*80)
    print("DiffDock Environment Setup Checker")
    print("="*80)

    if args.verbose:
        # The flag was previously parsed but never consulted; honour it by
        # showing which interpreter and platform the checks run against.
        import platform

        print(f" Python executable: {sys.executable}")
        # sys.version can span several lines on some builds; collapse it.
        print(f" Python build: {' '.join(sys.version.split())}")
        print(f" Platform: {platform.platform()}")
        print()

    checks = []

    # Run all checks
    checks.append(("Python version", check_python_version()))

    # check_pytorch() also reports CUDA availability, which is only used
    # to pick the performance notes shown after a fully passing run.
    pytorch_ok, has_cuda = check_pytorch()
    checks.append(("PyTorch", pytorch_ok))

    checks.append(("PyTorch Geometric", check_pytorch_geometric()))
    checks.append(("Core dependencies", check_core_dependencies()))
    checks.append(("ESM", check_esm()))
    checks.append(("DiffDock files", check_diffdock_installation()))

    # Summary
    print("\n" + "="*80)
    print("Summary")
    print("="*80)

    all_passed = all(result for _, result in checks)

    for check_name, result in checks:
        status = "✓ PASS" if result else "✗ FAIL"
        print(f" {status:8s} - {check_name}")

    if all_passed:
        print("\n✓ All checks passed! DiffDock is ready to use.")
        print_performance_notes(has_cuda)
        return 0

    print("\n✗ Some checks failed. Please install missing dependencies.")
    print_installation_instructions()
    return 1
|
|
|
|
|
|
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|