Initial commit
This commit is contained in:
418
skills/medchem/scripts/filter_molecules.py
Normal file
418
skills/medchem/scripts/filter_molecules.py
Normal file
@@ -0,0 +1,418 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Batch molecular filtering using medchem library.
|
||||
|
||||
This script provides a production-ready workflow for filtering compound libraries
|
||||
using medchem rules, structural alerts, and custom constraints.
|
||||
|
||||
Usage:
|
||||
python filter_molecules.py input.csv --rules rule_of_five,rule_of_cns --alerts nibr --output filtered.csv
|
||||
python filter_molecules.py input.sdf --rules rule_of_drug --lilly --complexity 400 --output results.csv
|
||||
python filter_molecules.py smiles.txt --nibr --pains --n-jobs -1 --output clean.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import json
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
import datamol as dm
|
||||
import medchem as mc
|
||||
from rdkit import Chem
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
print(f"Error: Missing required package: {e}")
|
||||
print("Install dependencies: pip install medchem datamol pandas tqdm")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_molecules(input_file: Path, smiles_column: str = "smiles") -> Tuple[pd.DataFrame, List[Chem.Mol]]:
    """Load molecules and their metadata from a compound file.

    Supported formats:
    - CSV/TSV with a SMILES column
    - SDF files
    - Plain text files with one SMILES per line

    Returns:
        Tuple of (DataFrame with metadata, list of RDKit molecules).
    """
    suffix = input_file.suffix.lower()

    if suffix == ".sdf":
        print(f"Loading SDF file: {input_file}")
        mols = [m for m in Chem.SDMolSupplier(str(input_file)) if m is not None]

        # Turn each record's SDF properties into a row, tagged with a SMILES.
        rows = []
        for m in mols:
            record = m.GetPropsAsDict()
            record["smiles"] = Chem.MolToSmiles(m)
            rows.append(record)
        df = pd.DataFrame(rows)

    elif suffix in (".csv", ".tsv"):
        print(f"Loading CSV/TSV file: {input_file}")
        df = pd.read_csv(input_file, sep="\t" if suffix == ".tsv" else ",")

        if smiles_column not in df.columns:
            print(f"Error: Column '{smiles_column}' not found in file")
            print(f"Available columns: {', '.join(df.columns)}")
            sys.exit(1)

        print(f"Converting SMILES to molecules...")
        mols = [dm.to_mol(smi) for smi in tqdm(df[smiles_column], desc="Parsing")]

    elif suffix == ".txt":
        print(f"Loading text file: {input_file}")
        with open(input_file) as f:
            smiles_list = [ln.strip() for ln in f if ln.strip()]

        df = pd.DataFrame({"smiles": smiles_list})
        print(f"Converting SMILES to molecules...")
        mols = [dm.to_mol(smi) for smi in tqdm(smiles_list, desc="Parsing")]

    else:
        print(f"Error: Unsupported file format: {suffix}")
        print("Supported formats: .csv, .tsv, .sdf, .txt")
        sys.exit(1)

    # Drop rows whose SMILES failed to parse, keeping frame and mols aligned.
    keep = [i for i, m in enumerate(mols) if m is not None]
    if len(keep) < len(mols):
        n_invalid = len(mols) - len(keep)
        print(f"Warning: {n_invalid} invalid molecules removed")
        df = df.iloc[keep].reset_index(drop=True)
        mols = [mols[i] for i in keep]

    print(f"Loaded {len(mols)} valid molecules")
    return df, mols
|
||||
|
||||
|
||||
def apply_rule_filters(mols: List[Chem.Mol], rules: List[str], n_jobs: int) -> pd.DataFrame:
    """Apply medicinal chemistry rule filters.

    Args:
        mols: Molecules to evaluate.
        rules: Rule names understood by medchem (e.g. "rule_of_five").
        n_jobs: Parallelism forwarded to medchem.

    Returns:
        DataFrame with one column per rule plus a boolean
        ``passes_all_rules`` summary column.
    """
    print(f"\nApplying rule filters: {', '.join(rules)}")

    rfilter = mc.rules.RuleFilters(rule_list=rules)
    results = rfilter(mols=mols, n_jobs=n_jobs, progress=True)

    # Convert to DataFrame
    df_results = pd.DataFrame(results)

    # Summarize over boolean columns only: the RuleFilters output may carry
    # non-boolean metadata columns (e.g. SMILES strings), which a plain
    # all(axis=1) would fold into the verdict via truthiness.
    bool_cols = df_results.select_dtypes(include=bool).columns
    df_results["passes_all_rules"] = df_results[bool_cols].all(axis=1)

    return df_results
|
||||
|
||||
|
||||
def apply_structural_alerts(mols: List[Chem.Mol], alert_type: str, n_jobs: int) -> pd.DataFrame:
    """Run one family of structural alert filters over the molecules.

    ``alert_type`` selects the filter: "common", "nibr", "lilly" or "pains".
    Raises ValueError for any other value.
    """
    print(f"\nApplying {alert_type} structural alerts...")

    if alert_type == "common":
        outcomes = mc.structural.CommonAlertsFilters()(mols=mols, n_jobs=n_jobs, progress=True)
        columns = {
            "has_common_alerts": [o["has_alerts"] for o in outcomes],
            "num_common_alerts": [o["num_alerts"] for o in outcomes],
            "common_alert_details": [
                ", ".join(o["alert_details"]) if o["alert_details"] else "" for o in outcomes
            ],
        }

    elif alert_type == "nibr":
        columns = {
            "passes_nibr": mc.structural.NIBRFilters()(mols=mols, n_jobs=n_jobs, progress=True),
        }

    elif alert_type == "lilly":
        outcomes = mc.structural.LillyDemeritsFilters()(mols=mols, n_jobs=n_jobs, progress=True)
        columns = {
            "lilly_demerits": [o["demerits"] for o in outcomes],
            "passes_lilly": [o["passes"] for o in outcomes],
            "lilly_patterns": [
                ", ".join(p["pattern"] for p in o["matched_patterns"]) for o in outcomes
            ],
        }

    elif alert_type == "pains":
        columns = {
            "passes_pains": [mc.rules.basic_rules.pains_filter(m) for m in tqdm(mols, desc="PAINS")],
        }

    else:
        raise ValueError(f"Unknown alert type: {alert_type}")

    return pd.DataFrame(columns)
|
||||
|
||||
|
||||
def apply_complexity_filter(mols: List[Chem.Mol], max_complexity: float, method: str = "bertz") -> pd.DataFrame:
    """Score molecular complexity and flag molecules within the threshold."""
    print(f"\nCalculating molecular complexity (method={method}, max={max_complexity})...")

    scores = []
    for mol in tqdm(mols, desc="Complexity"):
        scores.append(mc.complexity.calculate_complexity(mol, method=method))

    return pd.DataFrame({
        "complexity_score": scores,
        "passes_complexity": [s <= max_complexity for s in scores],
    })
|
||||
|
||||
|
||||
def apply_constraints(mols: List[Chem.Mol], constraints: Dict, n_jobs: int) -> pd.DataFrame:
    """Evaluate custom property constraints for every molecule."""
    print(f"\nApplying constraints: {constraints}")

    checker = mc.constraints.Constraints(**constraints)
    outcomes = checker(mols=mols, n_jobs=n_jobs, progress=True)

    passes = [entry["passes"] for entry in outcomes]
    violations = [", ".join(entry["violations"]) if entry["violations"] else "" for entry in outcomes]

    return pd.DataFrame({
        "passes_constraints": passes,
        "constraint_violations": violations,
    })
|
||||
|
||||
|
||||
def apply_chemical_groups(mols: List[Chem.Mol], groups: List[str]) -> pd.DataFrame:
    """Flag, per molecule, which of the requested chemical groups are present."""
    print(f"\nDetecting chemical groups: {', '.join(groups)}")

    detector = mc.groups.ChemicalGroup(groups=groups)
    matches = detector.get_all_matches(mols)

    # One boolean column per group; bool() collapses a match list to presence.
    return pd.DataFrame({
        f"has_{group}": [bool(m.get(group)) for m in matches]
        for group in groups
    })
|
||||
|
||||
|
||||
def generate_summary(df: pd.DataFrame, output_file: Path):
    """Write a plain-text filtering summary next to the output CSV.

    The report is saved as ``<output_stem>_summary.txt`` in the same
    directory as *output_file*.  The input frame is NOT modified (the
    combined pass verdict is computed locally rather than written back).

    Args:
        df: Combined results frame produced by the filtering pipeline.
        output_file: Path the results CSV was written to; only used to
            derive the summary file location.
    """
    summary_file = output_file.parent / f"{output_file.stem}_summary.txt"
    n_total = len(df)

    def pct(n: int) -> float:
        # Guard against an empty frame (avoids ZeroDivisionError when a
        # passes_*/alert column exists but there are no rows).
        return 100 * n / n_total if n_total else 0.0

    with open(summary_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("MEDCHEM FILTERING SUMMARY\n")
        f.write("=" * 80 + "\n\n")

        f.write(f"Total molecules processed: {n_total}\n\n")

        # Rule results
        rule_cols = [col for col in df.columns if col.startswith("rule_") or col == "passes_all_rules"]
        if rule_cols:
            f.write("RULE FILTERS:\n")
            f.write("-" * 40 + "\n")
            for col in rule_cols:
                if df[col].dtype == bool:
                    n_pass = df[col].sum()
                    f.write(f" {col}: {n_pass} passed ({pct(n_pass):.1f}%)\n")
            f.write("\n")

        # Structural alerts
        alert_cols = [col for col in df.columns
                      if "alert" in col.lower() or "nibr" in col.lower()
                      or "lilly" in col.lower() or "pains" in col.lower()]
        if alert_cols:
            f.write("STRUCTURAL ALERTS:\n")
            f.write("-" * 40 + "\n")
            if "has_common_alerts" in df.columns:
                n_clean = (~df["has_common_alerts"]).sum()
                f.write(f" No common alerts: {n_clean} ({pct(n_clean):.1f}%)\n")
            if "passes_nibr" in df.columns:
                n_pass = df["passes_nibr"].sum()
                f.write(f" Passes NIBR: {n_pass} ({pct(n_pass):.1f}%)\n")
            if "passes_lilly" in df.columns:
                n_pass = df["passes_lilly"].sum()
                f.write(f" Passes Lilly: {n_pass} ({pct(n_pass):.1f}%)\n")
                avg_demerits = df["lilly_demerits"].mean()
                f.write(f" Average Lilly demerits: {avg_demerits:.1f}\n")
            if "passes_pains" in df.columns:
                n_pass = df["passes_pains"].sum()
                f.write(f" Passes PAINS: {n_pass} ({pct(n_pass):.1f}%)\n")
            f.write("\n")

        # Complexity
        if "complexity_score" in df.columns:
            f.write("COMPLEXITY:\n")
            f.write("-" * 40 + "\n")
            avg_complexity = df["complexity_score"].mean()
            f.write(f" Average complexity: {avg_complexity:.1f}\n")
            if "passes_complexity" in df.columns:
                n_pass = df["passes_complexity"].sum()
                f.write(f" Within threshold: {n_pass} ({pct(n_pass):.1f}%)\n")
            f.write("\n")

        # Constraints
        if "passes_constraints" in df.columns:
            f.write("CONSTRAINTS:\n")
            f.write("-" * 40 + "\n")
            n_pass = df["passes_constraints"].sum()
            f.write(f" Passes all constraints: {n_pass} ({pct(n_pass):.1f}%)\n")
            f.write("\n")

        # Overall pass rate; computed locally so the caller's frame is
        # left untouched.
        pass_cols = [col for col in df.columns if col.startswith("passes_")]
        if pass_cols:
            n_pass = df[pass_cols].all(axis=1).sum()
            f.write("OVERALL:\n")
            f.write("-" * 40 + "\n")
            f.write(f" Molecules passing all filters: {n_pass} ({pct(n_pass):.1f}%)\n")

        f.write("\n" + "=" * 80 + "\n")

    print(f"\nSummary report saved to: {summary_file}")
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Construct the command-line interface for the filtering script."""
    parser = argparse.ArgumentParser(
        description="Batch molecular filtering using medchem",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Input/Output
    parser.add_argument("input", type=Path, help="Input file (CSV, TSV, SDF, or TXT)")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output CSV file")
    parser.add_argument("--smiles-column", default="smiles", help="Name of SMILES column (default: smiles)")

    # Rule filters
    parser.add_argument("--rules", help="Comma-separated list of rules (e.g., rule_of_five,rule_of_cns)")

    # Structural alerts
    parser.add_argument("--common-alerts", action="store_true", help="Apply common structural alerts")
    parser.add_argument("--nibr", action="store_true", help="Apply NIBR filters")
    parser.add_argument("--lilly", action="store_true", help="Apply Lilly demerits filter")
    parser.add_argument("--pains", action="store_true", help="Apply PAINS filter")

    # Complexity
    parser.add_argument("--complexity", type=float, help="Maximum complexity threshold")
    parser.add_argument("--complexity-method", default="bertz", choices=["bertz", "whitlock", "barone"],
                        help="Complexity calculation method")

    # Constraints
    parser.add_argument("--mw-range", help="Molecular weight range (e.g., 200,500)")
    parser.add_argument("--logp-range", help="LogP range (e.g., -2,5)")
    parser.add_argument("--tpsa-max", type=float, help="Maximum TPSA")
    parser.add_argument("--hbd-max", type=int, help="Maximum H-bond donors")
    parser.add_argument("--hba-max", type=int, help="Maximum H-bond acceptors")
    parser.add_argument("--rotatable-bonds-max", type=int, help="Maximum rotatable bonds")

    # Chemical groups
    parser.add_argument("--groups", help="Comma-separated chemical groups to detect")

    # Processing options
    parser.add_argument("--n-jobs", type=int, default=-1, help="Number of parallel jobs (-1 = all cores)")
    parser.add_argument("--no-summary", action="store_true", help="Don't generate summary report")
    parser.add_argument("--filter-output", action="store_true", help="Only output molecules passing all filters")

    return parser


def _collect_constraints(args: argparse.Namespace) -> Dict:
    """Build the constraint dict from the parsed CLI options.

    Numeric options use ``is not None`` checks so that legitimate zero
    values (e.g. ``--hbd-max 0``) are not silently dropped, which the
    previous truthiness checks did.
    """
    constraints: Dict = {}
    if args.mw_range:
        mw_min, mw_max = map(float, args.mw_range.split(","))
        constraints["mw_range"] = (mw_min, mw_max)
    if args.logp_range:
        logp_min, logp_max = map(float, args.logp_range.split(","))
        constraints["logp_range"] = (logp_min, logp_max)
    if args.tpsa_max is not None:
        constraints["tpsa_max"] = args.tpsa_max
    if args.hbd_max is not None:
        constraints["hbd_max"] = args.hbd_max
    if args.hba_max is not None:
        constraints["hba_max"] = args.hba_max
    if args.rotatable_bonds_max is not None:
        constraints["rotatable_bonds_max"] = args.rotatable_bonds_max
    return constraints


def main():
    """CLI entry point: load molecules, run the requested filters, save results."""
    args = _build_parser().parse_args()

    # Load molecules
    df, mols = load_molecules(args.input, args.smiles_column)

    # Each filter contributes a DataFrame; all are concatenated column-wise.
    result_dfs = [df]

    # Rules
    if args.rules:
        rule_list = [r.strip() for r in args.rules.split(",")]
        result_dfs.append(apply_rule_filters(mols, rule_list, args.n_jobs))

    # Structural alerts
    if args.common_alerts:
        result_dfs.append(apply_structural_alerts(mols, "common", args.n_jobs))
    if args.nibr:
        result_dfs.append(apply_structural_alerts(mols, "nibr", args.n_jobs))
    if args.lilly:
        result_dfs.append(apply_structural_alerts(mols, "lilly", args.n_jobs))
    if args.pains:
        result_dfs.append(apply_structural_alerts(mols, "pains", args.n_jobs))

    # Complexity ("is not None": a 0.0 threshold is strict but valid)
    if args.complexity is not None:
        result_dfs.append(apply_complexity_filter(mols, args.complexity, args.complexity_method))

    # Constraints
    constraints = _collect_constraints(args)
    if constraints:
        result_dfs.append(apply_constraints(mols, constraints, args.n_jobs))

    # Chemical groups
    if args.groups:
        group_list = [g.strip() for g in args.groups.split(",")]
        result_dfs.append(apply_chemical_groups(mols, group_list))

    # Combine results
    df_final = pd.concat(result_dfs, axis=1)

    # Filter output if requested
    if args.filter_output:
        pass_cols = [col for col in df_final.columns if col.startswith("passes_")]
        if pass_cols:
            df_final["passes_all"] = df_final[pass_cols].all(axis=1)
            df_final = df_final[df_final["passes_all"]]
            print(f"\nFiltered to {len(df_final)} molecules passing all filters")

    # Save results
    args.output.parent.mkdir(parents=True, exist_ok=True)
    df_final.to_csv(args.output, index=False)
    print(f"\nResults saved to: {args.output}")

    # Generate summary
    if not args.no_summary:
        generate_summary(df_final, args.output)

    print("\nDone!")
|
||||
|
||||
|
||||
# Run the CLI only when executed as a script, so the module can be
# imported (e.g. for testing) without side effects.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user