Initial commit
This commit is contained in:
422
skills/clinical-decision-support/scripts/generate_survival_analysis.py
Executable file
422
skills/clinical-decision-support/scripts/generate_survival_analysis.py
Executable file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate Kaplan-Meier Survival Curves for Clinical Decision Support Documents
|
||||
|
||||
This script creates publication-quality survival curves with:
|
||||
- Kaplan-Meier survival estimates
|
||||
- 95% confidence intervals
|
||||
- Log-rank test statistics
|
||||
- Hazard ratios with confidence intervals
|
||||
- Number at risk tables
|
||||
- Median survival annotations
|
||||
|
||||
Dependencies: lifelines, matplotlib, pandas, numpy
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from lifelines import KaplanMeierFitter
|
||||
from lifelines.statistics import logrank_test, multivariate_logrank_test
|
||||
from lifelines import CoxPHFitter
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_survival_data(filepath):
|
||||
"""
|
||||
Load survival data from CSV file.
|
||||
|
||||
Expected columns:
|
||||
- patient_id: Unique patient identifier
|
||||
- time: Survival time (months or days)
|
||||
- event: Event indicator (1=event occurred, 0=censored)
|
||||
- group: Stratification variable (e.g., 'Biomarker+', 'Biomarker-')
|
||||
- Optional: Additional covariates for Cox regression
|
||||
|
||||
Returns:
|
||||
pandas.DataFrame
|
||||
"""
|
||||
df = pd.read_csv(filepath)
|
||||
|
||||
# Validate required columns
|
||||
required_cols = ['patient_id', 'time', 'event', 'group']
|
||||
missing = [col for col in required_cols if col not in df.columns]
|
||||
if missing:
|
||||
raise ValueError(f"Missing required columns: {missing}")
|
||||
|
||||
# Convert event to boolean if needed
|
||||
df['event'] = df['event'].astype(bool)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def calculate_median_survival(kmf):
|
||||
"""Calculate median survival with 95% CI."""
|
||||
median = kmf.median_survival_time_
|
||||
ci = kmf.confidence_interval_survival_function_
|
||||
|
||||
# Find time when survival crosses 0.5
|
||||
if median == np.inf:
|
||||
return None, None, None
|
||||
|
||||
# Get CI at median
|
||||
idx = np.argmin(np.abs(kmf.survival_function_.index - median))
|
||||
lower_ci = ci.iloc[idx]['KM_estimate_lower_0.95']
|
||||
upper_ci = ci.iloc[idx]['KM_estimate_upper_0.95']
|
||||
|
||||
return median, lower_ci, upper_ci
|
||||
|
||||
|
||||
def generate_kaplan_meier_plot(data, time_col='time', event_col='event',
|
||||
group_col='group', output_path='survival_curve.pdf',
|
||||
title='Kaplan-Meier Survival Curve',
|
||||
xlabel='Time (months)', ylabel='Survival Probability'):
|
||||
"""
|
||||
Generate Kaplan-Meier survival curve comparing groups.
|
||||
|
||||
Parameters:
|
||||
data: DataFrame with survival data
|
||||
time_col: Column name for survival time
|
||||
event_col: Column name for event indicator
|
||||
group_col: Column name for stratification
|
||||
output_path: Path to save figure
|
||||
title: Plot title
|
||||
xlabel: X-axis label (specify units)
|
||||
ylabel: Y-axis label
|
||||
"""
|
||||
|
||||
# Create figure and axis
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# Get unique groups
|
||||
groups = data[group_col].unique()
|
||||
|
||||
# Colors for groups (colorblind-friendly)
|
||||
colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161']
|
||||
|
||||
kmf_models = {}
|
||||
median_survivals = {}
|
||||
|
||||
# Plot each group
|
||||
for i, group in enumerate(groups):
|
||||
group_data = data[data[group_col] == group]
|
||||
|
||||
# Fit Kaplan-Meier
|
||||
kmf = KaplanMeierFitter()
|
||||
kmf.fit(group_data[time_col], group_data[event_col], label=str(group))
|
||||
|
||||
# Plot survival curve
|
||||
kmf.plot_survival_function(ax=ax, ci_show=True, color=colors[i % len(colors)],
|
||||
linewidth=2, alpha=0.8)
|
||||
|
||||
# Store model
|
||||
kmf_models[group] = kmf
|
||||
|
||||
# Calculate median survival
|
||||
median, lower, upper = calculate_median_survival(kmf)
|
||||
median_survivals[group] = (median, lower, upper)
|
||||
|
||||
# Log-rank test
|
||||
if len(groups) == 2:
|
||||
group1_data = data[data[group_col] == groups[0]]
|
||||
group2_data = data[data[group_col] == groups[1]]
|
||||
|
||||
results = logrank_test(
|
||||
group1_data[time_col], group2_data[time_col],
|
||||
group1_data[event_col], group2_data[event_col]
|
||||
)
|
||||
|
||||
p_value = results.p_value
|
||||
test_statistic = results.test_statistic
|
||||
|
||||
# Add log-rank test result to plot
|
||||
ax.text(0.02, 0.15, f'Log-rank test:\np = {p_value:.4f}',
|
||||
transform=ax.transAxes, fontsize=10,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
||||
else:
|
||||
# Multivariate log-rank for >2 groups
|
||||
results = multivariate_logrank_test(data[time_col], data[group_col], data[event_col])
|
||||
p_value = results.p_value
|
||||
test_statistic = results.test_statistic
|
||||
|
||||
ax.text(0.02, 0.15, f'Log-rank test:\np = {p_value:.4f}\n({len(groups)} groups)',
|
||||
transform=ax.transAxes, fontsize=10,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
||||
|
||||
# Add median survival annotations
|
||||
y_pos = 0.95
|
||||
for group, (median, lower, upper) in median_survivals.items():
|
||||
if median is not None:
|
||||
ax.text(0.98, y_pos, f'{group}: {median:.1f} months (95% CI {lower:.1f}-{upper:.1f})',
|
||||
transform=ax.transAxes, fontsize=9, ha='right',
|
||||
verticalalignment='top')
|
||||
else:
|
||||
ax.text(0.98, y_pos, f'{group}: Not reached',
|
||||
transform=ax.transAxes, fontsize=9, ha='right',
|
||||
verticalalignment='top')
|
||||
y_pos -= 0.05
|
||||
|
||||
# Formatting
|
||||
ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
|
||||
ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
|
||||
ax.legend(loc='lower left', frameon=True, fontsize=10)
|
||||
ax.grid(True, alpha=0.3, linestyle='--')
|
||||
ax.set_ylim([0, 1.05])
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save figure
|
||||
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Survival curve saved to: {output_path}")
|
||||
|
||||
# Also save as PNG for easy viewing
|
||||
png_path = Path(output_path).with_suffix('.png')
|
||||
plt.savefig(png_path, dpi=300, bbox_inches='tight')
|
||||
print(f"PNG version saved to: {png_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
return kmf_models, p_value
|
||||
|
||||
|
||||
def generate_number_at_risk_table(data, time_col='time', event_col='event',
|
||||
group_col='group', time_points=None):
|
||||
"""
|
||||
Generate number at risk table for survival analysis.
|
||||
|
||||
Parameters:
|
||||
data: DataFrame with survival data
|
||||
time_points: List of time points for risk table (if None, auto-generate)
|
||||
|
||||
Returns:
|
||||
DataFrame with number at risk at each time point
|
||||
"""
|
||||
|
||||
if time_points is None:
|
||||
# Auto-generate time points (every 6 months up to max time)
|
||||
max_time = data[time_col].max()
|
||||
time_points = np.arange(0, max_time + 6, 6)
|
||||
|
||||
groups = data[group_col].unique()
|
||||
risk_table = pd.DataFrame(index=time_points, columns=groups)
|
||||
|
||||
for group in groups:
|
||||
group_data = data[data[group_col] == group]
|
||||
|
||||
for t in time_points:
|
||||
# Number at risk = patients who haven't had event and haven't been censored before time t
|
||||
at_risk = len(group_data[group_data[time_col] >= t])
|
||||
risk_table.loc[t, group] = at_risk
|
||||
|
||||
return risk_table
|
||||
|
||||
|
||||
def calculate_hazard_ratio(data, time_col='time', event_col='event', group_col='group',
|
||||
reference_group=None):
|
||||
"""
|
||||
Calculate hazard ratio using Cox proportional hazards regression.
|
||||
|
||||
Parameters:
|
||||
data: DataFrame
|
||||
reference_group: Reference group for comparison (if None, uses first group)
|
||||
|
||||
Returns:
|
||||
Hazard ratio, 95% CI, p-value
|
||||
"""
|
||||
|
||||
# Encode group as binary for Cox regression
|
||||
groups = data[group_col].unique()
|
||||
if len(groups) != 2:
|
||||
print("Warning: Cox HR calculation assumes 2 groups. Using first 2 groups.")
|
||||
groups = groups[:2]
|
||||
|
||||
if reference_group is None:
|
||||
reference_group = groups[0]
|
||||
|
||||
# Create binary indicator (1 for comparison group, 0 for reference)
|
||||
data_cox = data.copy()
|
||||
data_cox['group_binary'] = (data_cox[group_col] != reference_group).astype(int)
|
||||
|
||||
# Fit Cox model
|
||||
cph = CoxPHFitter()
|
||||
cph.fit(data_cox[[time_col, event_col, 'group_binary']],
|
||||
duration_col=time_col, event_col=event_col)
|
||||
|
||||
# Extract results
|
||||
hr = np.exp(cph.params_['group_binary'])
|
||||
ci = np.exp(cph.confidence_intervals_.loc['group_binary'].values)
|
||||
p_value = cph.summary.loc['group_binary', 'p']
|
||||
|
||||
return hr, ci[0], ci[1], p_value
|
||||
|
||||
|
||||
def generate_report(data, output_dir, prefix='survival'):
|
||||
"""
|
||||
Generate comprehensive survival analysis report.
|
||||
|
||||
Creates:
|
||||
- Kaplan-Meier curves (PDF and PNG)
|
||||
- Number at risk table (CSV)
|
||||
- Statistical summary (TXT)
|
||||
- LaTeX table code (TEX)
|
||||
"""
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate survival curve
|
||||
kmf_models, logrank_p = generate_kaplan_meier_plot(
|
||||
data,
|
||||
output_path=output_dir / f'{prefix}_kaplan_meier.pdf',
|
||||
title='Survival Analysis by Group'
|
||||
)
|
||||
|
||||
# Number at risk table
|
||||
risk_table = generate_number_at_risk_table(data)
|
||||
risk_table.to_csv(output_dir / f'{prefix}_number_at_risk.csv')
|
||||
|
||||
# Calculate hazard ratio
|
||||
hr, ci_lower, ci_upper, hr_p = calculate_hazard_ratio(data)
|
||||
|
||||
# Generate statistical summary
|
||||
with open(output_dir / f'{prefix}_statistics.txt', 'w') as f:
|
||||
f.write("SURVIVAL ANALYSIS STATISTICAL SUMMARY\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
groups = data['group'].unique()
|
||||
for group in groups:
|
||||
kmf = kmf_models[group]
|
||||
median = kmf.median_survival_time_
|
||||
|
||||
# Calculate survival rates at common time points
|
||||
try:
|
||||
surv_12m = kmf.survival_function_at_times(12).values[0]
|
||||
surv_24m = kmf.survival_function_at_times(24).values[0] if data['time'].max() >= 24 else None
|
||||
except:
|
||||
surv_12m = None
|
||||
surv_24m = None
|
||||
|
||||
f.write(f"Group: {group}\n")
|
||||
f.write(f" N = {len(data[data['group'] == group])}\n")
|
||||
f.write(f" Events = {data[data['group'] == group]['event'].sum()}\n")
|
||||
f.write(f" Median survival: {median:.1f} months\n" if median != np.inf else " Median survival: Not reached\n")
|
||||
if surv_12m is not None:
|
||||
f.write(f" 12-month survival rate: {surv_12m*100:.1f}%\n")
|
||||
if surv_24m is not None:
|
||||
f.write(f" 24-month survival rate: {surv_24m*100:.1f}%\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write(f"Log-Rank Test:\n")
|
||||
f.write(f" p-value = {logrank_p:.4f}\n")
|
||||
f.write(f" Interpretation: {'Significant' if logrank_p < 0.05 else 'Not significant'} difference in survival\n\n")
|
||||
|
||||
if len(groups) == 2:
|
||||
f.write(f"Hazard Ratio ({groups[1]} vs {groups[0]}):\n")
|
||||
f.write(f" HR = {hr:.2f} (95% CI {ci_lower:.2f}-{ci_upper:.2f})\n")
|
||||
f.write(f" p-value = {hr_p:.4f}\n")
|
||||
f.write(f" Interpretation: {groups[1]} has {((1-hr)*100):.0f}% {'reduction' if hr < 1 else 'increase'} in risk\n")
|
||||
|
||||
# Generate LaTeX table code
|
||||
with open(output_dir / f'{prefix}_latex_table.tex', 'w') as f:
|
||||
f.write("% LaTeX table code for survival outcomes\n")
|
||||
f.write("\\begin{table}[H]\n")
|
||||
f.write("\\centering\n")
|
||||
f.write("\\small\n")
|
||||
f.write("\\begin{tabular}{lcccc}\n")
|
||||
f.write("\\toprule\n")
|
||||
f.write("\\textbf{Endpoint} & \\textbf{Group A} & \\textbf{Group B} & \\textbf{HR (95\\% CI)} & \\textbf{p-value} \\\\\n")
|
||||
f.write("\\midrule\n")
|
||||
|
||||
# Add median survival row
|
||||
for i, group in enumerate(groups):
|
||||
kmf = kmf_models[group]
|
||||
median = kmf.median_survival_time_
|
||||
if i == 0:
|
||||
f.write(f"Median survival, months (95\\% CI) & ")
|
||||
if median != np.inf:
|
||||
f.write(f"{median:.1f} & ")
|
||||
else:
|
||||
f.write("NR & ")
|
||||
else:
|
||||
if median != np.inf:
|
||||
f.write(f"{median:.1f} & ")
|
||||
else:
|
||||
f.write("NR & ")
|
||||
|
||||
f.write(f"{hr:.2f} ({ci_lower:.2f}-{ci_upper:.2f}) & {hr_p:.3f} \\\\\n")
|
||||
|
||||
# Add 12-month survival rate
|
||||
f.write("12-month survival rate (\\%) & ")
|
||||
for group in groups:
|
||||
kmf = kmf_models[group]
|
||||
try:
|
||||
surv_12m = kmf.survival_function_at_times(12).values[0]
|
||||
f.write(f"{surv_12m*100:.0f}\\% & ")
|
||||
except:
|
||||
f.write("-- & ")
|
||||
f.write("-- & -- \\\\\n")
|
||||
|
||||
f.write("\\bottomrule\n")
|
||||
f.write("\\end{tabular}\n")
|
||||
f.write(f"\\caption{{Survival outcomes by group (log-rank p={logrank_p:.3f})}}\n")
|
||||
f.write("\\end{table}\n")
|
||||
|
||||
print(f"\nAnalysis complete! Files saved to {output_dir}/")
|
||||
print(f" - Survival curves: {prefix}_kaplan_meier.pdf/png")
|
||||
print(f" - Statistics: {prefix}_statistics.txt")
|
||||
print(f" - LaTeX table: {prefix}_latex_table.tex")
|
||||
print(f" - Risk table: {prefix}_number_at_risk.csv")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Generate Kaplan-Meier survival curves')
|
||||
parser.add_argument('input_file', type=str, help='CSV file with survival data')
|
||||
parser.add_argument('-o', '--output', type=str, default='survival_output',
|
||||
help='Output directory (default: survival_output)')
|
||||
parser.add_argument('-t', '--title', type=str, default='Kaplan-Meier Survival Curve',
|
||||
help='Plot title')
|
||||
parser.add_argument('-x', '--xlabel', type=str, default='Time (months)',
|
||||
help='X-axis label')
|
||||
parser.add_argument('-y', '--ylabel', type=str, default='Survival Probability',
|
||||
help='Y-axis label')
|
||||
parser.add_argument('--time-col', type=str, default='time',
|
||||
help='Column name for time variable')
|
||||
parser.add_argument('--event-col', type=str, default='event',
|
||||
help='Column name for event indicator')
|
||||
parser.add_argument('--group-col', type=str, default='group',
|
||||
help='Column name for grouping variable')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load data
|
||||
print(f"Loading data from {args.input_file}...")
|
||||
data = load_survival_data(args.input_file)
|
||||
print(f"Loaded {len(data)} patients")
|
||||
print(f"Groups: {data[args.group_col].value_counts().to_dict()}")
|
||||
|
||||
# Generate analysis
|
||||
generate_report(
|
||||
data,
|
||||
output_dir=args.output,
|
||||
prefix='survival'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
# Example usage:
|
||||
# python generate_survival_analysis.py survival_data.csv -o figures/ -t "PFS by PD-L1 Status"
|
||||
#
|
||||
# Input CSV format:
|
||||
# patient_id,time,event,group
|
||||
# PT001,12.3,1,PD-L1+
|
||||
# PT002,8.5,1,PD-L1-
|
||||
# PT003,18.2,0,PD-L1+
|
||||
# ...
|
||||
|
||||
Reference in New Issue
Block a user