Initial commit
skills/clinical-decision-support/scripts/biomarker_classifier.py (Executable file, 384 lines)
@@ -0,0 +1,384 @@
#!/usr/bin/env python3
"""
Biomarker-Based Patient Stratification and Classification

Performs patient stratification based on biomarker profiles with:
- Binary classification (biomarker+/-)
- Multi-class molecular subtypes
- Continuous biomarker scoring
- Correlation with clinical outcomes

Dependencies: pandas, numpy, scipy, scikit-learn (optional for clustering)
"""

import pandas as pd
import numpy as np
from scipy import stats
import argparse
from pathlib import Path


def classify_binary_biomarker(data, biomarker_col, threshold,
                              above_label='Biomarker+', below_label='Biomarker-'):
    """
    Binary classification based on a biomarker threshold.

    Parameters:
        data: DataFrame
        biomarker_col: Column name for biomarker values
        threshold: Cut-point value
        above_label: Label for values >= threshold
        below_label: Label for values < threshold

    Returns:
        DataFrame with added 'biomarker_class' column
    """
    data = data.copy()
    data['biomarker_class'] = data[biomarker_col].apply(
        lambda x: above_label if x >= threshold else below_label
    )

    return data
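
# Illustrative usage sketch (hypothetical 'tmb' column; not part of the CLI below):
#   df = pd.DataFrame({'tmb': [3.2, 12.5, 9.8]})
#   df = classify_binary_biomarker(df, 'tmb', threshold=10,
#                                  above_label='TMB-high', below_label='TMB-low')
#   list(df['biomarker_class'])  ->  ['TMB-low', 'TMB-high', 'TMB-low']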


def classify_pd_l1_tps(data, pd_l1_col='pd_l1_tps'):
    """
    Classify PD-L1 Tumor Proportion Score into clinical categories.

    Categories:
        - Negative: <1%
        - Low: 1-49%
        - High: >=50%

    Returns:
        DataFrame with 'pd_l1_category' column
    """
    data = data.copy()

    def categorize(tps):
        if tps < 1:
            return 'PD-L1 Negative (<1%)'
        elif tps < 50:
            return 'PD-L1 Low (1-49%)'
        else:
            return 'PD-L1 High (≥50%)'

    data['pd_l1_category'] = data[pd_l1_col].apply(categorize)

    # Distribution
    print("\nPD-L1 TPS Distribution:")
    print(data['pd_l1_category'].value_counts())

    return data
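
# Boundary behavior of categorize(), assuming TPS is given in percent (illustrative):
#   0.9 -> 'PD-L1 Negative (<1%)'
#   1.0 and 49.9 -> 'PD-L1 Low (1-49%)'
#   50.0 -> 'PD-L1 High (≥50%)'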


def classify_her2_status(data, ihc_col='her2_ihc', fish_col='her2_fish'):
    """
    Classify HER2 status based on IHC and FISH results (ASCO/CAP guidelines).

    IHC Scores: 0, 1+, 2+, 3+
    FISH: Positive, Negative (if IHC 2+)

    Classification:
        - HER2-positive: IHC 3+ OR IHC 2+/FISH+
        - HER2-negative: IHC 0/1+ OR IHC 2+/FISH-
        - HER2-low: IHC 1+ or IHC 2+/FISH- (subset of HER2-negative)

    Returns:
        DataFrame with 'her2_status' and 'her2_low' columns
    """
    data = data.copy()

    def classify_her2(row):
        ihc = row[ihc_col]
        fish = row.get(fish_col, None)

        if ihc == '3+':
            status = 'HER2-positive'
            her2_low = False
        elif ihc == '2+':
            if fish == 'Positive':
                status = 'HER2-positive'
                her2_low = False
            elif fish == 'Negative':
                status = 'HER2-negative'
                her2_low = True  # HER2-low
            else:
                status = 'HER2-equivocal (FISH needed)'
                her2_low = False
        elif ihc == '1+':
            status = 'HER2-negative'
            her2_low = True  # HER2-low
        else:  # IHC 0
            status = 'HER2-negative'
            her2_low = False

        return pd.Series({'her2_status': status, 'her2_low': her2_low})

    data[['her2_status', 'her2_low']] = data.apply(classify_her2, axis=1)

    print("\nHER2 Status Distribution:")
    print(data['her2_status'].value_counts())
    print(f"\nHER2-low (IHC 1+ or 2+/FISH-): {data['her2_low'].sum()} patients")

    return data
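
# Illustrative check of the IHC/FISH mapping above (hypothetical two-patient frame):
#   df = pd.DataFrame({'her2_ihc': ['2+', '2+'], 'her2_fish': ['Positive', 'Negative']})
#   classify_her2_status(df)
#   # her2_status -> ['HER2-positive', 'HER2-negative']; her2_low -> [False, True]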


def classify_breast_cancer_subtype(data, er_col='er_positive', pr_col='pr_positive',
                                   her2_col='her2_positive'):
    """
    Classify breast cancer into molecular subtypes.

    Subtypes:
        - HR+/HER2-: Luminal (ER+ and/or PR+, HER2-)
        - HER2+: Any HER2-positive (regardless of HR status)
        - Triple-negative: ER-, PR-, HER2-

    Returns:
        DataFrame with 'bc_subtype' column
    """
    data = data.copy()

    def get_subtype(row):
        er = row[er_col]
        pr = row[pr_col]
        her2 = row[her2_col]

        if her2:
            if er or pr:
                return 'HR+/HER2+ (Luminal B HER2+)'
            else:
                return 'HR-/HER2+ (HER2-enriched)'
        elif er or pr:
            return 'HR+/HER2- (Luminal)'
        else:
            return 'Triple-Negative'

    data['bc_subtype'] = data.apply(get_subtype, axis=1)

    print("\nBreast Cancer Subtype Distribution:")
    print(data['bc_subtype'].value_counts())

    return data


def correlate_biomarker_outcome(data, biomarker_col, outcome_col, biomarker_type='binary'):
    """
    Assess correlation between biomarker and clinical outcome.

    Parameters:
        biomarker_col: Biomarker variable
        outcome_col: Outcome variable
        biomarker_type: 'binary', 'categorical', 'continuous'

    Returns:
        p-value from the appropriate statistical test
    """
    print(f"\nCorrelation Analysis: {biomarker_col} vs {outcome_col}")
    print("=" * 60)

    # Remove missing data
    analysis_data = data[[biomarker_col, outcome_col]].dropna()

    if biomarker_type in ('binary', 'categorical'):
        # Cross-tabulation
        contingency = pd.crosstab(analysis_data[biomarker_col], analysis_data[outcome_col])
        print("\nContingency Table:")
        print(contingency)

        # Chi-square test
        chi2, p_value, dof, expected = stats.chi2_contingency(contingency)

        print("\nChi-square test:")
        print(f"  χ² = {chi2:.2f}, df = {dof}, p = {p_value:.4f}")

        # Odds ratio if 2x2 table
        if contingency.shape == (2, 2):
            a, b = contingency.iloc[0, :]
            c, d = contingency.iloc[1, :]
            or_value = (a * d) / (b * c) if b * c > 0 else np.inf

            # Confidence interval for OR (log method)
            log_or = np.log(or_value)
            se_log_or = np.sqrt(1/a + 1/b + 1/c + 1/d)
            ci_lower = np.exp(log_or - 1.96 * se_log_or)
            ci_upper = np.exp(log_or + 1.96 * se_log_or)

            print(f"\nOdds Ratio: {or_value:.2f} (95% CI {ci_lower:.2f}-{ci_upper:.2f})")

    elif biomarker_type == 'continuous':
        # Correlation coefficient
        r, p_value = stats.pearsonr(analysis_data[biomarker_col], analysis_data[outcome_col])

        print("\nPearson correlation:")
        print(f"  r = {r:.3f}, p = {p_value:.4f}")

        # Also report Spearman for robustness
        rho, p_spearman = stats.spearmanr(analysis_data[biomarker_col], analysis_data[outcome_col])
        print("Spearman correlation:")
        print(f"  ρ = {rho:.3f}, p = {p_spearman:.4f}")

    else:
        raise ValueError("biomarker_type must be 'binary', 'categorical', or 'continuous'")

    return p_value
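
# Worked 2x2 odds-ratio example with hypothetical counts (layout as in the code above):
#   contingency:   col0  col1
#     row0 (a, b):   30    10
#     row1 (c, d):   15    25
#   OR = (a*d)/(b*c) = (30*25)/(10*15) = 5.0; which comparison this represents
#   depends on the row/column ordering pd.crosstab produces.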


def stratify_cohort_report(data, stratification_var, output_dir='stratification_report'):
    """
    Generate a comprehensive stratification report.

    Parameters:
        data: DataFrame with patient data
        stratification_var: Column name for stratification
        output_dir: Output directory for reports
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nCOHORT STRATIFICATION REPORT")
    print("=" * 60)
    print(f"Stratification Variable: {stratification_var}")
    print(f"Total Patients: {len(data)}")

    # Group distribution
    distribution = data[stratification_var].value_counts()
    print("\nGroup Distribution:")
    for group, count in distribution.items():
        pct = count / len(data) * 100
        print(f"  {group}: {count} ({pct:.1f}%)")

    # Save distribution
    distribution.to_csv(output_dir / 'group_distribution.csv')

    # Compare baseline characteristics across groups
    print(f"\nBaseline Characteristics by {stratification_var}:")

    results = []

    # Continuous variables
    continuous_vars = data.select_dtypes(include=[np.number]).columns.tolist()
    continuous_vars = [v for v in continuous_vars if v != stratification_var]

    for var in continuous_vars[:5]:  # Limit to first 5 for demo
        print(f"\n{var}:")
        for group in distribution.index:
            group_data = data[data[stratification_var] == group][var].dropna()
            print(f"  {group}: median {group_data.median():.1f} "
                  f"[IQR {group_data.quantile(0.25):.1f}-{group_data.quantile(0.75):.1f}]")

        # Statistical test (two-group comparisons only)
        if len(distribution) == 2:
            groups_list = distribution.index.tolist()
            g1 = data[data[stratification_var] == groups_list[0]][var].dropna()
            g2 = data[data[stratification_var] == groups_list[1]][var].dropna()
            _, p_value = stats.mannwhitneyu(g1, g2, alternative='two-sided')
            print(f"  p-value: {p_value:.4f}")

            results.append({
                'Variable': var,
                'Test': 'Mann-Whitney U',
                'p_value': p_value,
                'Significant': 'Yes' if p_value < 0.05 else 'No'
            })

    # Save results
    if results:
        df_results = pd.DataFrame(results)
        df_results.to_csv(output_dir / 'statistical_comparisons.csv', index=False)
        print(f"\nStatistical comparison results saved to: {output_dir}/statistical_comparisons.csv")

    print(f"\nStratification report complete! Files saved to {output_dir}/")


def main():
    parser = argparse.ArgumentParser(description='Biomarker-based patient classification')
    parser.add_argument('input_file', type=str, nargs='?', default=None,
                        help='CSV file with patient and biomarker data')
    parser.add_argument('-b', '--biomarker', type=str, default=None,
                        help='Biomarker column name for stratification')
    parser.add_argument('-t', '--threshold', type=float, default=None,
                        help='Threshold for binary classification')
    parser.add_argument('-o', '--output-dir', type=str, default='stratification',
                        help='Output directory')
    parser.add_argument('--example', action='store_true',
                        help='Run with example data')

    args = parser.parse_args()

    # Example data if requested
    if args.example or args.input_file is None:
        print("Generating example dataset...")
        np.random.seed(42)
        n = 80

        data = pd.DataFrame({
            'patient_id': [f'PT{i:03d}' for i in range(1, n + 1)],
            'age': np.random.normal(62, 10, n),
            'sex': np.random.choice(['Male', 'Female'], n),
            'pd_l1_tps': np.random.exponential(20, n),  # Exponential distribution for PD-L1
            'tmb': np.random.exponential(8, n),  # Mutations per Mb
            'her2_ihc': np.random.choice(['0', '1+', '2+', '3+'], n, p=[0.6, 0.2, 0.15, 0.05]),
            'response': np.random.choice(['Yes', 'No'], n, p=[0.4, 0.6]),
        })

        # Simulate correlation: higher PD-L1 -> better response
        data.loc[data['pd_l1_tps'] >= 50, 'response'] = np.random.choice(
            ['Yes', 'No'], (data['pd_l1_tps'] >= 50).sum(), p=[0.65, 0.35])
    else:
        print(f"Loading data from {args.input_file}...")
        data = pd.read_csv(args.input_file)

    print(f"Dataset: {len(data)} patients")
    print(f"Columns: {list(data.columns)}")

    # PD-L1 classification (requires the column to be present)
    if 'pd_l1_tps' in data.columns:
        data = classify_pd_l1_tps(data, 'pd_l1_tps')

        # Correlate with response if available
        if 'response' in data.columns:
            correlate_biomarker_outcome(data, 'pd_l1_category', 'response',
                                        biomarker_type='categorical')

    # HER2 classification if columns present
    if 'her2_ihc' in data.columns:
        if 'her2_fish' not in data.columns:
            # Add placeholder FISH for IHC 2+
            data['her2_fish'] = np.nan
        data = classify_her2_status(data, 'her2_ihc', 'her2_fish')

    # Generic binary classification if threshold provided
    if args.biomarker and args.threshold is not None:
        print(f"\nBinary classification: {args.biomarker} with threshold {args.threshold}")
        data = classify_binary_biomarker(data, args.biomarker, args.threshold)
        print(data['biomarker_class'].value_counts())

    # Generate stratification report
    if args.biomarker:
        stratify_cohort_report(data, args.biomarker, output_dir=args.output_dir)
    elif 'pd_l1_category' in data.columns:
        stratify_cohort_report(data, 'pd_l1_category', output_dir=args.output_dir)

    # Save classified data (ensure the output directory exists)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    output_path = out_dir / 'classified_data.csv'
    data.to_csv(output_path, index=False)
    print(f"\nClassified data saved to: {output_path}")


if __name__ == '__main__':
    main()


# Example usage:
#   python biomarker_classifier.py data.csv -b pd_l1_tps -t 50 -o classification/
#   python biomarker_classifier.py --example
#
# Input CSV format:
#   patient_id,pd_l1_tps,tmb,her2_ihc,response,pfs_months,event
#   PT001,55.5,12.3,1+,Yes,14.2,1
#   PT002,8.2,5.1,0,No,6.5,1
#   ...

skills/clinical-decision-support/scripts/build_decision_tree.py (Executable file, 447 lines)
@@ -0,0 +1,447 @@
#!/usr/bin/env python3
"""
Build Clinical Decision Tree Flowcharts in TikZ Format

Generates LaTeX/TikZ code for clinical decision algorithms from
simple text or JSON descriptions.

Dependencies: standard library only (argparse, json, pathlib)
"""

import argparse
from pathlib import Path
import json


class DecisionNode:
    """Represents a decision point in the clinical algorithm."""

    def __init__(self, question, yes_path=None, no_path=None, node_id=None):
        self.question = question
        self.yes_path = yes_path
        self.no_path = no_path
        self.node_id = node_id or self._generate_id(question)

    def _generate_id(self, text):
        """Generate clean node ID from text."""
        return ''.join(c for c in text if c.isalnum())[:15].lower()


class ActionNode:
    """Represents an action/outcome in the clinical algorithm."""

    def __init__(self, action, urgency='routine', node_id=None):
        self.action = action
        self.urgency = urgency  # 'urgent', 'semiurgent', 'routine'
        self.node_id = node_id or self._generate_id(action)

    def _generate_id(self, text):
        return ''.join(c for c in text if c.isalnum())[:15].lower()


def generate_tikz_header():
    """Generate TikZ preamble with style definitions."""
    tikz = """\\documentclass[10pt]{article}
\\usepackage[margin=0.5in, landscape]{geometry}
\\usepackage{tikz}
\\usetikzlibrary{shapes,arrows,positioning}
\\usepackage{xcolor}

% Color definitions
\\definecolor{urgentred}{RGB}{220,20,60}
\\definecolor{actiongreen}{RGB}{0,153,76}
\\definecolor{decisionyellow}{RGB}{255,193,7}
\\definecolor{routineblue}{RGB}{100,181,246}
\\definecolor{headerblue}{RGB}{0,102,204}

% TikZ styles
\\tikzstyle{startstop} = [rectangle, rounded corners=8pt, minimum width=3cm, minimum height=1cm,
    text centered, draw=black, fill=headerblue!20, font=\\small\\bfseries]
\\tikzstyle{decision} = [diamond, minimum width=3cm, minimum height=1.2cm, text centered,
    draw=black, fill=decisionyellow!40, font=\\small, aspect=2, inner sep=0pt,
    text width=3.5cm]
\\tikzstyle{process} = [rectangle, rounded corners=4pt, minimum width=3.5cm, minimum height=0.9cm,
    text centered, draw=black, fill=actiongreen!20, font=\\small]
\\tikzstyle{urgent} = [rectangle, rounded corners=4pt, minimum width=3.5cm, minimum height=0.9cm,
    text centered, draw=urgentred, line width=1.5pt, fill=urgentred!15,
    font=\\small\\bfseries]
\\tikzstyle{routine} = [rectangle, rounded corners=4pt, minimum width=3.5cm, minimum height=0.9cm,
    text centered, draw=black, fill=routineblue!20, font=\\small]
\\tikzstyle{arrow} = [thick,->,>=stealth]
\\tikzstyle{urgentarrow} = [ultra thick,->,>=stealth,color=urgentred]

\\begin{document}

\\begin{center}
{\\Large\\bfseries Clinical Decision Algorithm}\\\\[10pt]
{\\large [TITLE TO BE SPECIFIED]}
\\end{center}

\\vspace{10pt}

\\begin{tikzpicture}[node distance=2.2cm and 3.5cm, auto]

"""

    return tikz


def generate_tikz_footer():
    """Generate TikZ closing code."""
    tikz = """
\\end{tikzpicture}

\\end{document}
"""

    return tikz


def simple_algorithm_to_tikz(algorithm_text, output_file='algorithm.tex'):
    """
    Convert simple text-based algorithm to TikZ flowchart.

    Input format (simple question-action pairs):
        START: Chief complaint
        Q1: High-risk criteria present? -> YES: Immediate action (URGENT) | NO: Continue
        Q2: Risk score >= 3? -> YES: Admit ICU | NO: Outpatient management (ROUTINE)
        END: Final outcome

    Parameters:
        algorithm_text: Multi-line string with algorithm
        output_file: Path to save .tex file
    """
    tikz_code = generate_tikz_header()

    # Parse algorithm text
    lines = [line.strip() for line in algorithm_text.strip().split('\n') if line.strip()]

    node_defs = []
    arrow_defs = []

    previous_node = None
    node_counter = 0

    for line in lines:
        if line.startswith('START:'):
            # Start node
            text = line.replace('START:', '').strip()
            node_id = 'start'
            node_defs.append(f"\\node [startstop] ({node_id}) {{{text}}};")
            previous_node = node_id
            node_counter += 1

        elif line.startswith('END:'):
            # End node
            text = line.replace('END:', '').strip()
            node_id = 'end'

            # Position relative to previous
            if previous_node:
                node_defs.append(f"\\node [startstop, below=of {previous_node}] ({node_id}) {{{text}}};")
                arrow_defs.append(f"\\draw [arrow] ({previous_node}) -- ({node_id});")

        elif line.startswith('Q'):
            # Decision node
            parts = line.split(':', 1)
            if len(parts) < 2:
                continue

            question_part = parts[1].split('->')[0].strip()
            node_id = f'q{node_counter}'

            # Add decision node
            if previous_node:
                node_defs.append(f"\\node [decision, below=of {previous_node}] ({node_id}) {{{question_part}}};")
                arrow_defs.append(f"\\draw [arrow] ({previous_node}) -- ({node_id});")
            else:
                node_defs.append(f"\\node [decision] ({node_id}) {{{question_part}}};")

            # Parse YES and NO branches
            if '->' in line:
                branches = line.split('->')[1].split('|')

                for branch in branches:
                    branch = branch.strip()

                    if branch.startswith('YES:'):
                        yes_action = branch.replace('YES:', '').strip()
                        yes_id = f'yes{node_counter}'

                        # Check urgency
                        if '(URGENT)' in yes_action:
                            style = 'urgent'
                            yes_action = yes_action.replace('(URGENT)', '').strip()
                            arrow_style = 'urgentarrow'
                        elif '(ROUTINE)' in yes_action:
                            style = 'routine'
                            yes_action = yes_action.replace('(ROUTINE)', '').strip()
                            arrow_style = 'arrow'
                        else:
                            style = 'process'
                            arrow_style = 'arrow'

                        node_defs.append(f"\\node [{style}, left=of {node_id}] ({yes_id}) {{{yes_action}}};")
                        arrow_defs.append(f"\\draw [{arrow_style}] ({node_id}) -- node[above] {{Yes}} ({yes_id});")

                    elif branch.startswith('NO:'):
                        no_action = branch.replace('NO:', '').strip()
                        no_id = f'no{node_counter}'

                        # Check urgency
                        if '(URGENT)' in no_action:
                            style = 'urgent'
                            no_action = no_action.replace('(URGENT)', '').strip()
                            arrow_style = 'urgentarrow'
                        elif '(ROUTINE)' in no_action:
                            style = 'routine'
                            no_action = no_action.replace('(ROUTINE)', '').strip()
                            arrow_style = 'arrow'
                        else:
                            style = 'process'
                            arrow_style = 'arrow'

                        node_defs.append(f"\\node [{style}, right=of {node_id}] ({no_id}) {{{no_action}}};")
                        arrow_defs.append(f"\\draw [{arrow_style}] ({node_id}) -- node[above] {{No}} ({no_id});")

            previous_node = node_id
            node_counter += 1

    # Add all nodes and arrows to TikZ
    tikz_code += '\n'.join(node_defs) + '\n\n'
    tikz_code += '% Arrows\n'
    tikz_code += '\n'.join(arrow_defs) + '\n'

    tikz_code += generate_tikz_footer()

    # Save to file
    with open(output_file, 'w') as f:
        f.write(tikz_code)

    print(f"TikZ flowchart saved to: {output_file}")
    print(f"Compile with: pdflatex {output_file}")

    return tikz_code
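
# Illustrative call using the documented text format (file name hypothetical):
#   algo = """
#   START: Chest pain
#   Q1: STEMI criteria? -> YES: Activate cath lab (URGENT) | NO: Continue
#   END: Risk stratification
#   """
#   simple_algorithm_to_tikz(algo, output_file='chest_pain.tex')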


def json_to_tikz(json_file, output_file='algorithm.tex'):
    """
    Convert JSON decision tree specification to TikZ flowchart.

    JSON format:
        {
          "title": "Algorithm Title",
          "nodes": {
            "start": {"type": "start", "text": "Patient presentation"},
            "q1": {"type": "decision", "text": "Criteria met?", "yes": "action1", "no": "q2"},
            "action1": {"type": "action", "text": "Immediate intervention", "urgency": "urgent"},
            "q2": {"type": "decision", "text": "Score >= 3?", "yes": "action2", "no": "action3"},
            "action2": {"type": "action", "text": "Admit ICU"},
            "action3": {"type": "action", "text": "Outpatient", "urgency": "routine"}
          },
          "start_node": "start"
        }
    """
    with open(json_file, 'r') as f:
        spec = json.load(f)

    tikz_code = generate_tikz_header()

    # Replace title
    title = spec.get('title', 'Clinical Decision Algorithm')
    tikz_code = tikz_code.replace('[TITLE TO BE SPECIFIED]', title)

    nodes = spec['nodes']

    # Generate nodes (simplified layout - vertical)
    node_defs = []
    arrow_defs = []

    def add_node(node_id, position_rel=None):
        """Emit one node definition plus its outgoing decision arrows."""
        if node_id not in nodes:
            return

        node = nodes[node_id]
        node_type = node['type']
        text = node['text']

        # Determine TikZ style
        if node_type in ('start', 'end'):
            style = 'startstop'
        elif node_type == 'decision':
            style = 'decision'
        elif node_type == 'action':
            urgency = node.get('urgency', 'normal')
            if urgency == 'urgent':
                style = 'urgent'
            elif urgency == 'routine':
                style = 'routine'
            else:
                style = 'process'
        else:
            style = 'process'

        # Position node
        if position_rel:
            node_defs.append(f"\\node [{style}, {position_rel}] ({node_id}) {{{text}}};")
        else:
            node_defs.append(f"\\node [{style}] ({node_id}) {{{text}}};")

        # Add arrows for decision nodes
        if node_type == 'decision':
            for answer, target in (('Yes', node.get('yes')), ('No', node.get('no'))):
                if target:
                    # Determine arrow style based on target urgency
                    target_node = nodes.get(target, {})
                    arrow_style = 'urgentarrow' if target_node.get('urgency') == 'urgent' else 'arrow'
                    arrow_defs.append(f"\\draw [{arrow_style}] ({node_id}) -| node[near start, above] {{{answer}}} ({target});")

    # Simple layout: chain nodes vertically so they do not overlap
    # (manual positioning in the JSON works better for complex trees)
    previous = None
    for node_id in nodes.keys():
        add_node(node_id, f"below=of {previous}" if previous else None)
        previous = node_id

    tikz_code += '\n'.join(node_defs) + '\n\n'
    tikz_code += '% Arrows\n'
    tikz_code += '\n'.join(arrow_defs) + '\n'

    tikz_code += generate_tikz_footer()

    # Save
    with open(output_file, 'w') as f:
        f.write(tikz_code)

    print(f"TikZ flowchart saved to: {output_file}")
    return tikz_code


def create_example_json():
    """Create example JSON specification for testing."""
    example = {
        "title": "Acute Chest Pain Management Algorithm",
        "nodes": {
            "start": {
                "type": "start",
                "text": "Patient with\\nchest pain"
            },
            "q1": {
                "type": "decision",
                "text": "STEMI\\ncriteria?",
                "yes": "stemi_action",
                "no": "q2"
            },
            "stemi_action": {
                "type": "action",
                "text": "Activate cath lab\\nAspirin, heparin\\nPrimary PCI",
                "urgency": "urgent"
            },
            "q2": {
                "type": "decision",
                "text": "High-risk\\nfeatures?",
                "yes": "admit",
                "no": "q3"
            },
            "admit": {
                "type": "action",
                "text": "Admit CCU\\nSerial troponins\\nEarly angiography"
            },
            "q3": {
                "type": "decision",
                "text": "TIMI\\nscore 0-1?",
                "yes": "lowrisk",
                "no": "moderate"
            },
            "lowrisk": {
                "type": "action",
                "text": "Observe 6-12h\\nStress test\\nOutpatient f/u",
                "urgency": "routine"
            },
            "moderate": {
                "type": "action",
                "text": "Admit telemetry\\nMedical management\\nRisk stratification"
            }
        },
        "start_node": "start"
    }

    return example


def main():
    parser = argparse.ArgumentParser(description='Build clinical decision tree flowcharts')
    parser.add_argument('-i', '--input', type=str, default=None,
                        help='Input file (JSON format)')
    parser.add_argument('-o', '--output', type=str, default='clinical_algorithm.tex',
                        help='Output .tex file')
    parser.add_argument('--example', action='store_true',
                        help='Generate example algorithm')
    parser.add_argument('--text', type=str, default=None,
                        help='Simple text algorithm (see format in docstring)')

    args = parser.parse_args()

    if args.example:
        print("Generating example algorithm...")
        example_spec = create_example_json()

        # Save example JSON
        with open('example_algorithm.json', 'w') as f:
            json.dump(example_spec, f, indent=2)
        print("Example JSON saved to: example_algorithm.json")

        # Generate TikZ from example
        json_to_tikz('example_algorithm.json', args.output)

    elif args.text:
        print("Generating algorithm from text...")
        simple_algorithm_to_tikz(args.text, args.output)

    elif args.input:
        print(f"Generating algorithm from {args.input}...")
        if args.input.endswith('.json'):
            json_to_tikz(args.input, args.output)
        else:
            with open(args.input, 'r') as f:
                text = f.read()
            simple_algorithm_to_tikz(text, args.output)

    else:
        print("No input provided. Use --example to generate an example, --text for simple text, or -i for JSON input.")
        print("\nSimple text format:")
        print("START: Patient presentation")
        print("Q1: Criteria met? -> YES: Action (URGENT) | NO: Continue")
        print("Q2: Score >= 3? -> YES: Admit | NO: Outpatient (ROUTINE)")
        print("END: Follow-up")


if __name__ == '__main__':
    main()


# Example usage:
#   python build_decision_tree.py --example
#   python build_decision_tree.py -i algorithm_spec.json -o my_algorithm.tex
#
# Then compile:
#   pdflatex clinical_algorithm.tex

skills/clinical-decision-support/scripts/create_cohort_tables.py (Executable file, 524 lines)
@@ -0,0 +1,524 @@
#!/usr/bin/env python3
"""
Generate Clinical Cohort Tables for Baseline Characteristics and Outcomes

Creates publication-ready tables with:
- Baseline demographics (Table 1 style)
- Efficacy outcomes
- Safety/adverse events
- Statistical comparisons between groups

Dependencies: pandas, numpy, scipy
"""

import pandas as pd
import numpy as np
from scipy import stats
from pathlib import Path
import argparse


def calculate_p_value(data, variable, group_col='group', var_type='categorical'):
    """
    Calculate appropriate p-value for group comparison.

    Parameters:
        data: DataFrame
        variable: Column name to compare
        group_col: Grouping variable
        var_type: 'categorical', 'continuous_normal', 'continuous_nonnormal'

    Returns:
        p-value (float)
    """
    groups = data[group_col].unique()

    if len(groups) != 2:
        return np.nan  # Only handle 2-group comparisons

    group1_data = data[data[group_col] == groups[0]][variable].dropna()
    group2_data = data[data[group_col] == groups[1]][variable].dropna()

    if var_type == 'categorical':
        # Chi-square or Fisher's exact test
        contingency = pd.crosstab(data[variable], data[group_col])

        # Fall back to Fisher's exact when any observed cell count is < 5
        # (a conservative proxy for the usual expected-count rule)
        if contingency.min().min() < 5:
            # Fisher's exact (2x2 only)
            if contingency.shape == (2, 2):
                _, p_value = stats.fisher_exact(contingency)
            else:
                # Use chi-square but note the limitation for sparse r x c tables
                _, p_value, _, _ = stats.chi2_contingency(contingency)
        else:
            _, p_value, _, _ = stats.chi2_contingency(contingency)

    elif var_type == 'continuous_normal':
        # Welch's t-test (does not assume equal variances)
        _, p_value = stats.ttest_ind(group1_data, group2_data, equal_var=False)

    elif var_type == 'continuous_nonnormal':
        # Mann-Whitney U test
        _, p_value = stats.mannwhitneyu(group1_data, group2_data, alternative='two-sided')

    else:
        raise ValueError("var_type must be 'categorical', 'continuous_normal', or 'continuous_nonnormal'")

    return p_value
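
# Illustrative call (hypothetical frame with a two-level 'group' column):
#   df = pd.DataFrame({'group': ['A'] * 30 + ['B'] * 30,
#                      'sex': np.random.choice(['Male', 'Female'], 60)})
#   p = calculate_p_value(df, 'sex', group_col='group', var_type='categorical')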


def format_continuous_variable(data, variable, group_col, distribution='normal'):
    """
    Format continuous variable for table display.

    Returns:
        Dictionary with formatted strings for each group and p-value
    """
    groups = data[group_col].unique()
    results = {}

    for group in groups:
        group_data = data[data[group_col] == group][variable].dropna()

        if distribution == 'normal':
            # Mean ± SD
            mean = group_data.mean()
            std = group_data.std()
            results[group] = f"{mean:.1f} ± {std:.1f}"
        else:
            # Median [IQR]
            median = group_data.median()
            q1 = group_data.quantile(0.25)
            q3 = group_data.quantile(0.75)
            results[group] = f"{median:.1f} [{q1:.1f}-{q3:.1f}]"

    # Calculate p-value ("—" when not computable, e.g., more than 2 groups)
    var_type = 'continuous_normal' if distribution == 'normal' else 'continuous_nonnormal'
    p_value = calculate_p_value(data, variable, group_col, var_type)
    results['p_value'] = "—" if np.isnan(p_value) else ("<0.001" if p_value < 0.001 else f"{p_value:.3f}")

    return results


def format_categorical_variable(data, variable, group_col):
    """
    Format categorical variable for table display.

    Returns:
        List of dictionaries for each category with counts and percentages
    """
    groups = data[group_col].unique()
    categories = data[variable].dropna().unique()

    results = []

    for category in categories:
        row = {'category': category}

        for group in groups:
            group_data = data[data[group_col] == group]
            count = (group_data[variable] == category).sum()
            total = group_data[variable].notna().sum()
            percentage = (count / total * 100) if total > 0 else 0
            row[group] = f"{count} ({percentage:.0f}%)"

        results.append(row)

    # Calculate the p-value for the categorical variable as a whole,
    # reported on the first category row only
    if results:
        p_value = calculate_p_value(data, variable, group_col, 'categorical')
        results[0]['p_value'] = "—" if np.isnan(p_value) else ("<0.001" if p_value < 0.001 else f"{p_value:.3f}")

    return results


def generate_baseline_table(data, group_col='group', output_file='table1_baseline.csv'):
    """
    Generate Table 1: Baseline characteristics.

    Customize the variables list for your specific cohort.
    """
    groups = data[group_col].unique()

    # Initialize results list
    table_rows = []

    # Header row
    header = {
        'Characteristic': 'Characteristic',
        **{group: f"{group} (n={len(data[data[group_col] == group])})" for group in groups},
        'p_value': 'p-value'
    }
    table_rows.append(header)

    # Age (continuous)
    if 'age' in data.columns:
        age_results = format_continuous_variable(data, 'age', group_col, distribution='nonnormal')
        row = {'Characteristic': 'Age, years (median [IQR])'}
        for group in groups:
            row[group] = age_results[group]
        row['p_value'] = age_results['p_value']
        table_rows.append(row)

    # Sex (categorical)
    if 'sex' in data.columns:
        table_rows.append({'Characteristic': 'Sex, n (%)', **{g: '' for g in groups}, 'p_value': ''})
        sex_results = format_categorical_variable(data, 'sex', group_col)
        for sex_row in sex_results:
            row = {'Characteristic': f"  {sex_row['category']}"}
            for group in groups:
                row[group] = sex_row[group]
            row['p_value'] = sex_row.get('p_value', '')
            table_rows.append(row)

    # ECOG Performance Status (categorical)
    if 'ecog_ps' in data.columns:
        table_rows.append({'Characteristic': 'ECOG PS, n (%)', **{g: '' for g in groups}, 'p_value': ''})
        ecog_results = format_categorical_variable(data, 'ecog_ps', group_col)
        for ecog_row in ecog_results:
            row = {'Characteristic': f"  {ecog_row['category']}"}
            for group in groups:
                row[group] = ecog_row[group]
            row['p_value'] = ecog_row.get('p_value', '')
            table_rows.append(row)

    # Convert to DataFrame and save
    df_table = pd.DataFrame(table_rows)
    df_table.to_csv(output_file, index=False)
    print(f"Baseline characteristics table saved to: {output_file}")

    return df_table


def generate_efficacy_table(data, group_col='group', output_file='table2_efficacy.csv'):
    """
    Generate efficacy outcomes table.

    Expected columns:
        - best_response: CR, PR, SD, PD
        - Additional binary outcomes (response, disease_control, etc.)
    """
    groups = data[group_col].unique()
    table_rows = []

    # Header
    header = {
        'Outcome': 'Outcome',
        **{group: f"{group} (n={len(data[data[group_col] == group])})" for group in groups},
        'p_value': 'p-value'
    }
    table_rows.append(header)

    # Objective Response Rate (ORR = CR + PR)
    if 'best_response' in data.columns:
        orr_row = {'Outcome': 'ORR, n (%) [95% CI]'}
        for group in groups:
            group_data = data[data[group_col] == group]
            cr_pr = ((group_data['best_response'] == 'CR') | (group_data['best_response'] == 'PR')).sum()
            total = len(group_data)
            orr = cr_pr / total * 100

            # Calculate exact binomial CI (Clopper-Pearson)
            ci_lower, ci_upper = _binomial_ci(cr_pr, total)
            orr_row[group] = f"{cr_pr} ({orr:.0f}%) [{ci_lower:.0f}-{ci_upper:.0f}]"

        # P-value for ORR difference
        contingency = pd.crosstab(
            data['best_response'].isin(['CR', 'PR']),
            data[group_col]
        )
        _, p_value, _, _ = stats.chi2_contingency(contingency)
        orr_row['p_value'] = f"{p_value:.3f}" if p_value >= 0.001 else "<0.001"
        table_rows.append(orr_row)

        # Individual response categories
        for response in ['CR', 'PR', 'SD', 'PD']:
            row = {'Outcome': f"  {response}"}
            for group in groups:
                group_data = data[data[group_col] == group]
                count = (group_data['best_response'] == response).sum()
                total = len(group_data)
                pct = count / total * 100
                row[group] = f"{count} ({pct:.0f}%)"
            row['p_value'] = ''
            table_rows.append(row)

    # Disease Control Rate (DCR = CR + PR + SD)
    if 'best_response' in data.columns:
        dcr_row = {'Outcome': 'DCR, n (%) [95% CI]'}
        for group in groups:
            group_data = data[data[group_col] == group]
            dcr_count = group_data['best_response'].isin(['CR', 'PR', 'SD']).sum()
            total = len(group_data)
            dcr = dcr_count / total * 100
            ci_lower, ci_upper = _binomial_ci(dcr_count, total)
            dcr_row[group] = f"{dcr_count} ({dcr:.0f}%) [{ci_lower:.0f}-{ci_upper:.0f}]"

        # P-value
        contingency = pd.crosstab(
            data['best_response'].isin(['CR', 'PR', 'SD']),
            data[group_col]
        )
        _, p_value, _, _ = stats.chi2_contingency(contingency)
        dcr_row['p_value'] = f"{p_value:.3f}" if p_value >= 0.001 else "<0.001"
        table_rows.append(dcr_row)

    # Save table
    df_table = pd.DataFrame(table_rows)
    df_table.to_csv(output_file, index=False)
    print(f"Efficacy table saved to: {output_file}")

    return df_table


def generate_safety_table(data, ae_columns, group_col='group', output_file='table3_safety.csv'):
    """
    Generate adverse events table.

    Parameters:
        data: DataFrame with AE data
        ae_columns: List of AE column names (each should have values 0-5 for CTCAE grades)
        group_col: Grouping variable
        output_file: Output CSV path
    """
    groups = data[group_col].unique()
    table_rows = []

    # Header row labelling the per-group "Any Grade" and "Grade 3-4" columns
    header = {
        'Adverse Event': 'Adverse Event',
        **{f'{group}_any': 'Any Grade' for group in groups},
        **{f'{group}_g34': 'Grade 3-4' for group in groups}
    }
    table_rows.append(header)

    for ae in ae_columns:
        if ae not in data.columns:
            continue

        row = {'Adverse Event': ae.replace('_', ' ').title()}

        for group in groups:
            group_data = data[data[group_col] == group][ae].dropna()
            total = len(group_data)

            # Any grade (Grade 1-5)
            any_grade = (group_data > 0).sum()
            any_pct = any_grade / total * 100 if total > 0 else 0
            row[f'{group}_any'] = f"{any_grade} ({any_pct:.0f}%)"

            # Grade 3-4
            grade_34 = (group_data >= 3).sum()
            g34_pct = grade_34 / total * 100 if total > 0 else 0
            row[f'{group}_g34'] = f"{grade_34} ({g34_pct:.0f}%)"

        table_rows.append(row)

    # Save table
    df_table = pd.DataFrame(table_rows)
    df_table.to_csv(output_file, index=False)
    print(f"Safety table saved to: {output_file}")

    return df_table


def generate_latex_table(df, caption, label='table'):
    """
    Convert DataFrame to LaTeX table code.

    Returns:
        String with LaTeX table code
    """
    latex_code = "\\begin{table}[H]\n"
    latex_code += "\\centering\n"
    latex_code += "\\small\n"
    latex_code += "\\begin{tabular}{" + "l" * len(df.columns) + "}\n"
    latex_code += "\\toprule\n"

    # Header
    header_row = " & ".join([f"\\textbf{{{col}}}" for col in df.columns])
    latex_code += header_row + " \\\\\n"
    latex_code += "\\midrule\n"

    # Data rows
    for _, row in df.iterrows():
        # Handle indentation for subcategories (lines starting with spaces)
        first_col = str(row.iloc[0])
        if first_col.startswith(' '):
            first_col = '\\quad ' + first_col.strip()

        data_row = [first_col] + [str(val) if pd.notna(val) else '—' for val in row.iloc[1:]]
        latex_code += " & ".join(data_row) + " \\\\\n"

    latex_code += "\\bottomrule\n"
    latex_code += "\\end{tabular}\n"
    latex_code += f"\\caption{{{caption}}}\n"
    latex_code += f"\\label{{tab:{label}}}\n"
    latex_code += "\\end{table}\n"

    return latex_code


def _binomial_ci(successes, trials, confidence=0.95):
    """
    Calculate exact binomial confidence interval (Clopper-Pearson method).

    Returns:
        Lower and upper bounds as percentages
    """
    if trials == 0:
        return 0.0, 0.0

    alpha = 1 - confidence

    # Use beta distribution quantiles
    from scipy.stats import beta

    if successes == 0:
        lower = 0.0
    else:
        lower = beta.ppf(alpha / 2, successes, trials - successes + 1)

    if successes == trials:
        upper = 1.0
    else:
        upper = beta.ppf(1 - alpha / 2, successes + 1, trials - successes)

    return lower * 100, upper * 100
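
# Worked example (values approximate): 8 responders of 20 gives roughly a 19-64% CI.
#   lo, hi = _binomial_ci(8, 20)   # lo ≈ 19.1, hi ≈ 63.9 (percent)
# Edge cases: _binomial_ci(0, 20) pins the lower bound at 0; _binomial_ci(20, 20) the upper at 100.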


def create_example_data():
    """Create example dataset for testing."""
    np.random.seed(42)
    n = 100

    data = pd.DataFrame({
        'patient_id': [f'PT{i:03d}' for i in range(1, n + 1)],
        'group': np.random.choice(['Biomarker+', 'Biomarker-'], n),
        'age': np.random.normal(62, 10, n),
        'sex': np.random.choice(['Male', 'Female'], n),
        'ecog_ps': np.random.choice(['0-1', '2'], n, p=[0.8, 0.2]),
        'stage': np.random.choice(['III', 'IV'], n, p=[0.3, 0.7]),
        'best_response': np.random.choice(['CR', 'PR', 'SD', 'PD'], n, p=[0.05, 0.35, 0.40, 0.20]),
        'fatigue_grade': np.random.choice([0, 1, 2, 3], n, p=[0.3, 0.4, 0.2, 0.1]),
        'nausea_grade': np.random.choice([0, 1, 2, 3], n, p=[0.4, 0.35, 0.20, 0.05]),
        'neutropenia_grade': np.random.choice([0, 1, 2, 3, 4], n, p=[0.5, 0.2, 0.15, 0.10, 0.05]),
    })

    return data


def main():
    parser = argparse.ArgumentParser(description='Generate clinical cohort tables')
    parser.add_argument('input_file', type=str, nargs='?', default=None,
                        help='CSV file with cohort data (if not provided, uses example data)')
    parser.add_argument('-o', '--output-dir', type=str, default='tables',
                        help='Output directory (default: tables)')
    parser.add_argument('--group-col', type=str, default='group',
                        help='Column name for grouping variable')
    parser.add_argument('--example', action='store_true',
                        help='Generate tables using example data')

    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load or create data
    if args.example or args.input_file is None:
        print("Generating example dataset...")
        data = create_example_data()
    else:
        print(f"Loading data from {args.input_file}...")
        data = pd.read_csv(args.input_file)

    print(f"Dataset: {len(data)} patients, {len(data[args.group_col].unique())} groups")
    print(f"Groups: {data[args.group_col].value_counts().to_dict()}")

    # Generate Table 1: Baseline characteristics
    print("\nGenerating baseline characteristics table...")
    baseline_table = generate_baseline_table(
        data,
        group_col=args.group_col,
        output_file=output_dir / 'table1_baseline.csv'
    )

    # Generate LaTeX code for baseline table
    latex_code = generate_latex_table(
        baseline_table,
        caption="Baseline patient demographics and clinical characteristics",
        label="baseline"
    )
    with open(output_dir / 'table1_baseline.tex', 'w') as f:
        f.write(latex_code)
    print(f"LaTeX code saved to: {output_dir}/table1_baseline.tex")

    # Generate Table 2: Efficacy outcomes
    if 'best_response' in data.columns:
        print("\nGenerating efficacy outcomes table...")
        efficacy_table = generate_efficacy_table(
            data,
            group_col=args.group_col,
            output_file=output_dir / 'table2_efficacy.csv'
        )

        latex_code = generate_latex_table(
            efficacy_table,
            caption="Treatment efficacy outcomes by group",
            label="efficacy"
        )
        with open(output_dir / 'table2_efficacy.tex', 'w') as f:
            f.write(latex_code)

    # Generate Table 3: Safety (identify AE columns)
    ae_columns = [col for col in data.columns if col.endswith('_grade')]
    if ae_columns:
        print("\nGenerating safety table...")
        safety_table = generate_safety_table(
            data,
            ae_columns=ae_columns,
            group_col=args.group_col,
            output_file=output_dir / 'table3_safety.csv'
        )

        latex_code = generate_latex_table(
            safety_table,
            caption="Treatment-emergent adverse events by group (CTCAE v5.0)",
            label="safety"
        )
        with open(output_dir / 'table3_safety.tex', 'w') as f:
            f.write(latex_code)

    print(f"\nAll tables generated successfully in {output_dir}/")
    print("Files created:")
    print("  - table1_baseline.csv / .tex")
    print("  - table2_efficacy.csv / .tex (if response data available)")
    print("  - table3_safety.csv / .tex (if AE data available)")


if __name__ == '__main__':
    main()


# Example usage:
#   python create_cohort_tables.py cohort_data.csv -o tables/
#   python create_cohort_tables.py --example   # Generate example tables
#
# Input CSV format:
#   patient_id,group,age,sex,ecog_ps,stage,best_response,fatigue_grade,nausea_grade,...
#   PT001,Biomarker+,65,Male,0-1,IV,PR,1,0,...
#   PT002,Biomarker-,58,Female,0-1,III,SD,2,1,...
#   ...

skills/clinical-decision-support/scripts/generate_survival_analysis.py (Executable file, 422 lines)
@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Generate Kaplan-Meier Survival Curves for Clinical Decision Support Documents

This script creates publication-quality survival curves with:
- Kaplan-Meier survival estimates
- 95% confidence intervals
- Log-rank test statistics
- Hazard ratios with confidence intervals
- Number at risk tables
- Median survival annotations

Dependencies: lifelines, matplotlib, pandas, numpy
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
import argparse
from pathlib import Path


def load_survival_data(filepath):
    """
    Load survival data from CSV file.

    Expected columns:
        - patient_id: Unique patient identifier
        - time: Survival time (months or days)
        - event: Event indicator (1=event occurred, 0=censored)
        - group: Stratification variable (e.g., 'Biomarker+', 'Biomarker-')
        - Optional: Additional covariates for Cox regression

    Returns:
        pandas.DataFrame
    """
    df = pd.read_csv(filepath)

    # Validate required columns
    required_cols = ['patient_id', 'time', 'event', 'group']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Convert event to boolean if needed
    df['event'] = df['event'].astype(bool)

    return df


def calculate_median_survival(kmf):
    """Calculate median survival with 95% CI."""
    from lifelines.utils import median_survival_times

    median = kmf.median_survival_time_

    # Median never reached within follow-up
    if median == np.inf:
        return None, None, None

    # CI for the median survival time, derived from the CI of the survival curve
    # (lifelines utility; returns a one-row frame with lower/upper bounds in time units)
    median_ci = median_survival_times(kmf.confidence_interval_)
    lower_ci = median_ci.iloc[0, 0]
    upper_ci = median_ci.iloc[0, 1]

    return median, lower_ci, upper_ci


def generate_kaplan_meier_plot(data, time_col='time', event_col='event',
                               group_col='group', output_path='survival_curve.pdf',
                               title='Kaplan-Meier Survival Curve',
                               xlabel='Time (months)', ylabel='Survival Probability'):
    """
    Generate Kaplan-Meier survival curve comparing groups.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for stratification
        output_path: Path to save figure
        title: Plot title
        xlabel: X-axis label (specify units)
        ylabel: Y-axis label
    """
    # Create figure and axis
    fig, ax = plt.subplots(figsize=(10, 6))

    # Get unique groups
    groups = data[group_col].unique()

    # Colors for groups (colorblind-friendly)
    colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161']

    kmf_models = {}
    median_survivals = {}

    # Plot each group
    for i, group in enumerate(groups):
        group_data = data[data[group_col] == group]

        # Fit Kaplan-Meier
        kmf = KaplanMeierFitter()
        kmf.fit(group_data[time_col], group_data[event_col], label=str(group))

        # Plot survival curve
        kmf.plot_survival_function(ax=ax, ci_show=True, color=colors[i % len(colors)],
                                   linewidth=2, alpha=0.8)

        # Store model
        kmf_models[group] = kmf

        # Calculate median survival
        median, lower, upper = calculate_median_survival(kmf)
        median_survivals[group] = (median, lower, upper)

    # Log-rank test
    if len(groups) == 2:
        group1_data = data[data[group_col] == groups[0]]
        group2_data = data[data[group_col] == groups[1]]

        results = logrank_test(
            group1_data[time_col], group2_data[time_col],
            group1_data[event_col], group2_data[event_col]
        )

        p_value = results.p_value
        test_statistic = results.test_statistic

        # Add log-rank test result to plot
        ax.text(0.02, 0.15, f'Log-rank test:\np = {p_value:.4f}',
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    else:
        # Multivariate log-rank for >2 groups
        results = multivariate_logrank_test(data[time_col], data[group_col], data[event_col])
        p_value = results.p_value
        test_statistic = results.test_statistic

        ax.text(0.02, 0.15, f'Log-rank test:\np = {p_value:.4f}\n({len(groups)} groups)',
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Add median survival annotations
    y_pos = 0.95
    for group, (median, lower, upper) in median_survivals.items():
        if median is not None:
            ax.text(0.98, y_pos, f'{group}: {median:.1f} months (95% CI {lower:.1f}-{upper:.1f})',
                    transform=ax.transAxes, fontsize=9, ha='right',
                    verticalalignment='top')
        else:
            ax.text(0.98, y_pos, f'{group}: Not reached',
                    transform=ax.transAxes, fontsize=9, ha='right',
                    verticalalignment='top')
        y_pos -= 0.05

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
    ax.legend(loc='lower left', frameon=True, fontsize=10)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim([0, 1.05])

    plt.tight_layout()

    # Save figure
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Survival curve saved to: {output_path}")

    # Also save as PNG for easy viewing
    png_path = Path(output_path).with_suffix('.png')
    plt.savefig(png_path, dpi=300, bbox_inches='tight')
    print(f"PNG version saved to: {png_path}")

    plt.close()

    return kmf_models, p_value


def generate_number_at_risk_table(data, time_col='time', event_col='event',
                                  group_col='group', time_points=None):
    """
    Generate number at risk table for survival analysis.

    Parameters:
        data: DataFrame with survival data
        time_points: List of time points for risk table (if None, auto-generate)

    Returns:
        DataFrame with number at risk at each time point
    """
    if time_points is None:
        # Auto-generate time points (every 6 months up to max time)
        max_time = data[time_col].max()
        time_points = np.arange(0, max_time + 6, 6)

    groups = data[group_col].unique()
    risk_table = pd.DataFrame(index=time_points, columns=groups)

    for group in groups:
        group_data = data[data[group_col] == group]

        for t in time_points:
            # Number at risk = patients who haven't had the event and haven't been censored before time t
            at_risk = len(group_data[group_data[time_col] >= t])
            risk_table.loc[t, group] = at_risk

    return risk_table


def calculate_hazard_ratio(data, time_col='time', event_col='event', group_col='group',
                           reference_group=None):
    """
    Calculate hazard ratio using Cox proportional hazards regression.

    Parameters:
        data: DataFrame
        reference_group: Reference group for comparison (if None, uses first group)

    Returns:
        Hazard ratio, 95% CI bounds, p-value
    """
    # Encode group as binary for Cox regression
    groups = data[group_col].unique()
    if len(groups) != 2:
        print("Warning: Cox HR calculation assumes 2 groups. Using first 2 groups.")
        groups = groups[:2]
        data = data[data[group_col].isin(groups)]

    if reference_group is None:
        reference_group = groups[0]

    # Create binary indicator (1 for comparison group, 0 for reference)
    data_cox = data.copy()
    data_cox['group_binary'] = (data_cox[group_col] != reference_group).astype(int)

    # Fit Cox model
    cph = CoxPHFitter()
    cph.fit(data_cox[[time_col, event_col, 'group_binary']],
            duration_col=time_col, event_col=event_col)

    # Extract results
    hr = np.exp(cph.params_['group_binary'])
    ci = np.exp(cph.confidence_intervals_.loc['group_binary'].values)
    p_value = cph.summary.loc['group_binary', 'p']

    return hr, ci[0], ci[1], p_value
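
# Illustrative call on synthetic data (hypothetical values; exponential event times):
#   rng = np.random.default_rng(0)
#   df = pd.DataFrame({'time': rng.exponential(12, 100),
#                      'event': rng.integers(0, 2, 100).astype(bool),
#                      'group': np.repeat(['A', 'B'], 50)})
#   hr, lo, hi, p = calculate_hazard_ratio(df, reference_group='A')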
def generate_report(data, output_dir, prefix='survival'):
|
||||
"""
|
||||
Generate comprehensive survival analysis report.
|
||||
|
||||
Creates:
|
||||
- Kaplan-Meier curves (PDF and PNG)
|
||||
- Number at risk table (CSV)
|
||||
- Statistical summary (TXT)
|
||||
- LaTeX table code (TEX)
|
||||
"""
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate survival curve
|
||||
kmf_models, logrank_p = generate_kaplan_meier_plot(
|
||||
data,
|
||||
output_path=output_dir / f'{prefix}_kaplan_meier.pdf',
|
||||
title='Survival Analysis by Group'
|
||||
)
|
||||
|
||||
# Number at risk table
|
||||
risk_table = generate_number_at_risk_table(data)
|
||||
risk_table.to_csv(output_dir / f'{prefix}_number_at_risk.csv')
|
||||
|
||||
# Calculate hazard ratio
|
||||
hr, ci_lower, ci_upper, hr_p = calculate_hazard_ratio(data)
|
||||
|
||||
# Generate statistical summary
|
||||
with open(output_dir / f'{prefix}_statistics.txt', 'w') as f:
|
||||
f.write("SURVIVAL ANALYSIS STATISTICAL SUMMARY\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
groups = data['group'].unique()
|
||||
for group in groups:
|
||||
kmf = kmf_models[group]
|
||||
median = kmf.median_survival_time_
|
||||
|
||||
# Calculate survival rates at common time points
|
||||
try:
|
||||
surv_12m = kmf.survival_function_at_times(12).values[0]
|
||||
surv_24m = kmf.survival_function_at_times(24).values[0] if data['time'].max() >= 24 else None
|
||||
except:
|
||||
surv_12m = None
|
||||
surv_24m = None
|
||||
|
||||
f.write(f"Group: {group}\n")
|
||||
f.write(f" N = {len(data[data['group'] == group])}\n")
|
||||
f.write(f" Events = {data[data['group'] == group]['event'].sum()}\n")
|
||||
f.write(f" Median survival: {median:.1f} months\n" if median != np.inf else " Median survival: Not reached\n")
|
||||
if surv_12m is not None:
|
||||
f.write(f" 12-month survival rate: {surv_12m*100:.1f}%\n")
|
||||
if surv_24m is not None:
|
||||
f.write(f" 24-month survival rate: {surv_24m*100:.1f}%\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write(f"Log-Rank Test:\n")
|
||||
f.write(f" p-value = {logrank_p:.4f}\n")
|
||||
f.write(f" Interpretation: {'Significant' if logrank_p < 0.05 else 'Not significant'} difference in survival\n\n")
|
||||
|
||||
if len(groups) == 2:
|
||||
f.write(f"Hazard Ratio ({groups[1]} vs {groups[0]}):\n")
|
||||
f.write(f" HR = {hr:.2f} (95% CI {ci_lower:.2f}-{ci_upper:.2f})\n")
|
||||
f.write(f" p-value = {hr_p:.4f}\n")
|
||||
f.write(f" Interpretation: {groups[1]} has {((1-hr)*100):.0f}% {'reduction' if hr < 1 else 'increase'} in risk\n")

    # Generate LaTeX table code
    with open(output_dir / f'{prefix}_latex_table.tex', 'w') as f:
        f.write("% LaTeX table code for survival outcomes\n")
        f.write("\\begin{table}[H]\n")
        f.write("\\centering\n")
        f.write("\\small\n")
        f.write("\\begin{tabular}{lcccc}\n")
        f.write("\\toprule\n")
        f.write("\\textbf{Endpoint} & \\textbf{Group A} & \\textbf{Group B} & \\textbf{HR (95\\% CI)} & \\textbf{p-value} \\\\\n")
        f.write("\\midrule\n")

        # Median survival row: one cell per group, then HR and p-value
        f.write("Median survival, months & ")
        for group in groups:
            kmf = kmf_models[group]
            median = kmf.median_survival_time_
            if median != np.inf:
                f.write(f"{median:.1f} & ")
            else:
                f.write("NR & ")
        f.write(f"{hr:.2f} ({ci_lower:.2f}-{ci_upper:.2f}) & {hr_p:.3f} \\\\\n")

        # Add 12-month survival rate
        f.write("12-month survival rate (\\%) & ")
        for group in groups:
            kmf = kmf_models[group]
            try:
                surv_12m = kmf.survival_function_at_times(12).values[0]
                f.write(f"{surv_12m*100:.0f}\\% & ")
            except Exception:
                f.write("-- & ")
        f.write("-- & -- \\\\\n")

        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")
        f.write(f"\\caption{{Survival outcomes by group (log-rank p={logrank_p:.3f})}}\n")
        f.write("\\end{table}\n")

    print(f"\nAnalysis complete! Files saved to {output_dir}/")
    print(f"  - Survival curves: {prefix}_kaplan_meier.pdf/png")
    print(f"  - Statistics: {prefix}_statistics.txt")
    print(f"  - LaTeX table: {prefix}_latex_table.tex")
    print(f"  - Risk table: {prefix}_number_at_risk.csv")


def main():
    parser = argparse.ArgumentParser(description='Generate Kaplan-Meier survival curves')
    parser.add_argument('input_file', type=str, help='CSV file with survival data')
    parser.add_argument('-o', '--output', type=str, default='survival_output',
                        help='Output directory (default: survival_output)')
    parser.add_argument('-t', '--title', type=str, default='Kaplan-Meier Survival Curve',
                        help='Plot title')
    parser.add_argument('-x', '--xlabel', type=str, default='Time (months)',
                        help='X-axis label')
    parser.add_argument('-y', '--ylabel', type=str, default='Survival Probability',
                        help='Y-axis label')
    parser.add_argument('--time-col', type=str, default='time',
                        help='Column name for time variable')
    parser.add_argument('--event-col', type=str, default='event',
                        help='Column name for event indicator')
    parser.add_argument('--group-col', type=str, default='group',
                        help='Column name for grouping variable')

    args = parser.parse_args()

    # Load data
    print(f"Loading data from {args.input_file}...")
    data = load_survival_data(args.input_file)
    print(f"Loaded {len(data)} patients")
    print(f"Groups: {data[args.group_col].value_counts().to_dict()}")

    # Generate analysis
    # Note: generate_report currently assumes the default column names
    # ('time', 'event', 'group') and its own plot title; the label and
    # column options above are not yet threaded through to it.
    generate_report(
        data,
        output_dir=args.output,
        prefix='survival'
    )


if __name__ == '__main__':
    main()


# Example usage:
# python generate_survival_analysis.py survival_data.csv -o figures/ -t "PFS by PD-L1 Status"
#
# Input CSV format:
# patient_id,time,event,group
# PT001,12.3,1,PD-L1+
# PT002,8.5,1,PD-L1-
# PT003,18.2,0,PD-L1+
# ...
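#
# Hedged smoke-test sketch: generate a tiny synthetic CSV in the format
# above (file name and sizes are illustrative; uses the pd/np imports from
# this script):
#
#   rng = np.random.default_rng(42)
#   pd.DataFrame({
#       'patient_id': [f'PT{i:03d}' for i in range(1, 41)],
#       'time': rng.exponential(15, 40).round(1),
#       'event': rng.integers(0, 2, 40),
#       'group': ['PD-L1+'] * 20 + ['PD-L1-'] * 20,
#   }).to_csv('survival_data.csv', index=False)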

335
skills/clinical-decision-support/scripts/validate_cds_document.py
Executable file
@@ -0,0 +1,335 @@
#!/usr/bin/env python3
"""
Validate Clinical Decision Support Documents for Quality and Completeness

Checks for:
- Evidence citations for all recommendations
- Statistical reporting completeness
- Biomarker nomenclature consistency
- Required sections present
- HIPAA de-identification
- GRADE recommendation format

Dependencies: None (pure Python)
"""

import re
import sys
import argparse
from pathlib import Path


class CDSValidator:
    """Validator for clinical decision support documents."""

    def __init__(self, filepath):
        self.filepath = Path(filepath)  # Path, so .suffix works in checks below
        with open(self.filepath, 'r', encoding='utf-8', errors='ignore') as f:
            self.content = f.read()

        self.errors = []
        self.warnings = []
        self.info = []

    def validate_all(self):
        """Run all validation checks."""

        print(f"Validating: {self.filepath}")
        print("=" * 70)

        self.check_required_sections()
        self.check_evidence_citations()
        self.check_recommendation_grading()
        self.check_statistical_reporting()
        self.check_hipaa_identifiers()
        self.check_biomarker_nomenclature()

        return self.generate_report()
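
    # Hedged usage sketch (path is illustrative):
    #   v = CDSValidator('cohort_analysis.tex')
    #   ok = v.validate_all()   # prints a console report; False if errors found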

    def check_required_sections(self):
        """Check if required sections are present."""

        # Cohort analysis required sections
        cohort_sections = [
            'cohort characteristics',
            'biomarker',
            'outcomes',
            'statistical analysis',
            'clinical implications',
            'references'
        ]

        # Treatment recommendation required sections
        rec_sections = [
            'evidence',
            'recommendation',
            'monitoring',
            'references'
        ]

        content_lower = self.content.lower()

        # Infer the document type from keywords
        is_cohort = 'cohort' in content_lower
        is_recommendation = 'recommendation' in content_lower

        if is_cohort:
            missing = [sec for sec in cohort_sections if sec not in content_lower]
            if missing:
                self.warnings.append(f"Cohort analysis may be missing sections: {', '.join(missing)}")
            else:
                self.info.append("All cohort analysis sections present")

        if is_recommendation:
            missing = [sec for sec in rec_sections if sec not in content_lower]
            if missing:
                self.errors.append(f"Recommendation document missing required sections: {', '.join(missing)}")
            else:
                self.info.append("All recommendation sections present")

    def check_evidence_citations(self):
        """Check that recommendations have citations."""

        # Find recommendation statements (keyword through end of paragraph)
        rec_pattern = r'(recommend|should|prefer|suggest|consider)(.*?)(?:\n\n|\Z)'
        recommendations = re.findall(rec_pattern, self.content, re.IGNORECASE | re.DOTALL)

        # Find citations
        citation_patterns = [
            r'\[\d+\]',         # Numbered citations [1]
            r'\(.*?\d{4}\)',    # Author year (Smith 2020)
            r'et al\.',         # Et al citations
            r'NCCN|ASCO|ESMO',  # Guideline references
        ]

        uncited_recommendations = []

        for _, rec_text in recommendations:
            has_citation = any(re.search(pattern, rec_text) for pattern in citation_patterns)

            if not has_citation:
                snippet = rec_text[:60].strip() + '...'
                uncited_recommendations.append(snippet)

        if uncited_recommendations:
            self.warnings.append(f"Found {len(uncited_recommendations)} recommendations without citations")
            for rec in uncited_recommendations[:3]:  # Show first 3
                self.warnings.append(f"  - {rec}")
        else:
            self.info.append(f"All {len(recommendations)} recommendations have citations")

    def check_recommendation_grading(self):
        """Check for GRADE-style recommendation strength."""

        # Look for GRADE notation (1A, 1B, 2A, 2B, 2C); re.IGNORECASE already
        # covers both 'GRADE' and 'Grade' spellings
        grade_pattern = r'GRADE\s*[12][A-C]|\(?\s*[12][A-C]\s*\)?'
        grades = re.findall(grade_pattern, self.content, re.IGNORECASE)

        # Look for strong/conditional language
        strong_pattern = r'(strong|we recommend|should)'
        conditional_pattern = r'(conditional|weak|we suggest|may consider|could consider)'

        strong_count = len(re.findall(strong_pattern, self.content, re.IGNORECASE))
        conditional_count = len(re.findall(conditional_pattern, self.content, re.IGNORECASE))

        if grades:
            self.info.append(f"Found {len(grades)} GRADE-style recommendations")
        else:
            self.warnings.append("No GRADE-style recommendation grading found (1A, 1B, 2A, etc.)")

        if strong_count > 0 or conditional_count > 0:
            self.info.append(f"Recommendation language: {strong_count} strong, {conditional_count} conditional")
        else:
            self.warnings.append("No clear recommendation strength language (strong/conditional) found")

    def check_statistical_reporting(self):
        """Check for proper statistical reporting."""

        # Check for p-values
        p_values = re.findall(r'p\s*[=<>]\s*[\d.]+', self.content, re.IGNORECASE)

        # Check for confidence intervals
        ci_pattern = r'95%\s*CI|confidence interval'
        cis = re.findall(ci_pattern, self.content, re.IGNORECASE)

        # Check for hazard ratios
        hr_pattern = r'HR\s*[=:]\s*[\d.]+'
        hrs = re.findall(hr_pattern, self.content)

        # Check for sample sizes
        n_pattern = r'n\s*=\s*\d+'
        sample_sizes = re.findall(n_pattern, self.content, re.IGNORECASE)

        if not p_values:
            self.warnings.append("No p-values found - statistical significance not reported")
        else:
            self.info.append(f"Found {len(p_values)} p-values")

        if hrs and not cis:
            self.warnings.append("Hazard ratios reported without confidence intervals")

        if not sample_sizes:
            self.warnings.append("Sample sizes (n=X) not clearly reported")

        # Check for a common statistical error: 'p = 0.00' (with no further
        # digits) should be reported as p < 0.001. A plain substring test
        # would also flag legitimate values like p = 0.003, so anchor the
        # match with a negative lookahead.
        if re.search(r'p\s*=\s*0\.00+(?!\d)', self.content, re.IGNORECASE):
            self.warnings.append("Found p=0.00 (should report as p<0.001 instead)")

    def check_hipaa_identifiers(self):
        """Check for potential HIPAA identifiers."""

        # 18 HIPAA identifiers (simplified check for common ones)
        identifiers = {
            'Names': r'Dr\.\s+[A-Z][a-z]+|Patient:\s*[A-Z][a-z]+',
            'Specific dates': r'\d{1,2}/\d{1,2}/\d{4}',  # MM/DD/YYYY
            'Phone numbers': r'\d{3}[-.]?\d{3}[-.]?\d{4}',
            'Email addresses': r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
            'SSN': r'\d{3}-\d{2}-\d{4}',
            'MRN': r'MRN\s*:?\s*\d+',
        }

        found_identifiers = []

        for identifier_type, pattern in identifiers.items():
            matches = re.findall(pattern, self.content)
            if matches:
                found_identifiers.append(f"{identifier_type}: {len(matches)} instance(s)")

        if found_identifiers:
            self.errors.append("Potential HIPAA identifiers detected:")
            for identifier in found_identifiers:
                self.errors.append(f"  - {identifier}")
            self.errors.append("  ** Ensure proper de-identification before distribution **")
        else:
            self.info.append("No obvious HIPAA identifiers detected (basic check only)")

    def check_biomarker_nomenclature(self):
        """Check for consistent biomarker nomenclature."""

        # Common biomarker naming issues
        issues = []

        # Gene symbols should be italicized in LaTeX (\textit{} or \emph{});
        # protein names are set upright (e.g., HER2 protein vs ERBB2 gene,
        # both valid), so this check only applies to .tex documents.
        gene_names = ['EGFR', 'ALK', 'ROS1', 'BRAF', 'KRAS', 'HER2', 'TP53', 'BRCA1', 'BRCA2']
        if self.filepath.suffix == '.tex':
            for gene in gene_names:
                if gene in self.content:
                    if f'\\textit{{{gene}}}' not in self.content and f'\\emph{{{gene}}}' not in self.content:
                        issues.append(f"{gene} should be italicized in LaTeX (\\textit{{{gene}}})")

        # Check for mutation nomenclature (HGVS protein format):
        # 'p.L858R' matches, a bare 'L858R' does not
        hgvs_pattern = r'p\.[A-Z]\d+[A-Z]'
        hgvs_mutations = re.findall(hgvs_pattern, self.content)

        if hgvs_mutations:
            self.info.append(f"Found {len(hgvs_mutations)} HGVS protein nomenclature (e.g., p.L858R)")

        # Warn about non-specific mutation descriptions
        if 'EGFR mutation' in self.content and 'exon' not in self.content.lower():
            self.warnings.append("EGFR mutation mentioned - specify exon/variant (e.g., exon 19 deletion)")

        if issues:
            self.warnings.extend(issues)

    def generate_report(self):
        """Generate validation report."""

        print("\n" + "=" * 70)
        print("VALIDATION REPORT")
        print("=" * 70)

        if self.errors:
            print(f"\n❌ ERRORS ({len(self.errors)}):")
            for error in self.errors:
                print(f"  {error}")

        if self.warnings:
            print(f"\n⚠️  WARNINGS ({len(self.warnings)}):")
            for warning in self.warnings:
                print(f"  {warning}")

        if self.info:
            print(f"\n✓ PASSED CHECKS ({len(self.info)}):")
            for info in self.info:
                print(f"  {info}")

        # Overall status
        print("\n" + "=" * 70)
        if self.errors:
            print("STATUS: ❌ VALIDATION FAILED - Address errors before distribution")
            return False
        elif self.warnings:
            print("STATUS: ⚠️  VALIDATION PASSED WITH WARNINGS - Review recommended")
            return True
        else:
            print("STATUS: ✓ VALIDATION PASSED - Document meets quality standards")
            return True

    def save_report(self, output_file):
        """Save validation report to file."""

        # Local import; used only for the report timestamp (the original
        # wrote the working directory here, which was almost certainly a bug)
        from datetime import datetime

        with open(output_file, 'w') as f:
            f.write("CLINICAL DECISION SUPPORT DOCUMENT VALIDATION REPORT\n")
            f.write("=" * 70 + "\n")
            f.write(f"Document: {self.filepath}\n")
            f.write(f"Validated: {datetime.now():%Y-%m-%d %H:%M}\n\n")

            if self.errors:
                f.write(f"ERRORS ({len(self.errors)}):\n")
                for error in self.errors:
                    f.write(f"  - {error}\n")
                f.write("\n")

            if self.warnings:
                f.write(f"WARNINGS ({len(self.warnings)}):\n")
                for warning in self.warnings:
                    f.write(f"  - {warning}\n")
                f.write("\n")

            if self.info:
                f.write(f"PASSED CHECKS ({len(self.info)}):\n")
                for info in self.info:
                    f.write(f"  - {info}\n")

        print(f"\nValidation report saved to: {output_file}")


def main():
    parser = argparse.ArgumentParser(description='Validate clinical decision support documents')
    parser.add_argument('input_file', type=str, help='Document to validate (.tex, .md, .txt)')
    parser.add_argument('-o', '--output', type=str, default=None,
                        help='Save validation report to file')
    parser.add_argument('--strict', action='store_true',
                        help='Treat warnings as errors')

    args = parser.parse_args()

    # Validate
    validator = CDSValidator(args.input_file)
    passed = validator.validate_all()

    # Save report if requested
    if args.output:
        validator.save_report(args.output)

    # Exit code: non-zero on errors, or on warnings under --strict
    if args.strict and validator.warnings:
        sys.exit(1)
    sys.exit(0 if passed else 1)


if __name__ == '__main__':
    main()


# Example usage:
# python validate_cds_document.py cohort_analysis.tex
# python validate_cds_document.py treatment_recommendations.tex -o validation_report.txt
# python validate_cds_document.py document.tex --strict  # Warnings cause failure
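#
# Hedged programmatic sketch (path is illustrative):
#   from validate_cds_document import CDSValidator
#   v = CDSValidator('cohort_analysis.tex')
#   if not v.validate_all():
#       v.save_report('validation_report.txt')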