423 lines
15 KiB
Python
Executable File
423 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Generate Kaplan-Meier Survival Curves for Clinical Decision Support Documents
|
|
|
|
This script creates publication-quality survival curves with:
|
|
- Kaplan-Meier survival estimates
|
|
- 95% confidence intervals
|
|
- Log-rank test statistics
|
|
- Hazard ratios with confidence intervals
|
|
- Number at risk tables
|
|
- Median survival annotations
|
|
|
|
Dependencies: lifelines, matplotlib, pandas, numpy
|
|
"""
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from lifelines import KaplanMeierFitter
|
|
from lifelines.statistics import logrank_test, multivariate_logrank_test
|
|
from lifelines import CoxPHFitter
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
|
|
def load_survival_data(filepath, time_col='time', event_col='event',
                       group_col='group'):
    """
    Load survival data from a CSV file and validate it.

    Expected columns (time/event/group names configurable via keywords):
        - patient_id: Unique patient identifier
        - time: Survival time (months or days)
        - event: Event indicator (1=event occurred, 0=censored)
        - group: Stratification variable (e.g., 'Biomarker+', 'Biomarker-')
        - Optional: Additional covariates for Cox regression

    Parameters:
        filepath: Path or buffer accepted by ``pandas.read_csv``
        time_col: Column holding survival time (default 'time')
        event_col: Column holding the event indicator (default 'event')
        group_col: Column holding the stratification variable (default 'group')

    Returns:
        pandas.DataFrame with ``event_col`` converted to bool

    Raises:
        ValueError: if a required column is missing, or if the event column
            contains anything other than 0/1 (or True/False), including
            missing values.
    """
    df = pd.read_csv(filepath)

    # Validate required columns
    required_cols = ['patient_id', time_col, event_col, group_col]
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    events = df[event_col]
    if events.dtype != bool:
        # A plain astype(bool) would turn every non-zero value (a string '0',
        # an R-style 1/2 coding, NaN) into "event occurred", so insist on 0/1.
        numeric = pd.to_numeric(events, errors='coerce')
        invalid = numeric.isna() | ~numeric.isin([0, 1])
        if invalid.any():
            bad_values = list(events[invalid].unique()[:5])
            raise ValueError(
                f"Column '{event_col}' must contain only 0/1 (or True/False); "
                f"found {bad_values}"
            )
        events = numeric

    df[event_col] = events.astype(bool)
    return df
|
|
|
|
|
|
def calculate_median_survival(kmf):
    """
    Return the median survival time of a fitted Kaplan-Meier model and its CI.

    The interval is the confidence interval of the median *time*: the times at
    which the lower and upper confidence bands of the survival function reach
    0.5. It is not the CI of the survival probability at the median.

    Parameters:
        kmf: A fitted ``lifelines.KaplanMeierFitter``

    Returns:
        (median, lower, upper) in the model's time units, or
        (None, None, None) when the median is not reached. A bound may be
        ``inf`` when its confidence band never drops to 0.5.
    """
    median = kmf.median_survival_time_
    if np.isinf(median):
        return None, None, None

    # Only needed once the median is actually reached.
    from lifelines.utils import median_survival_times

    # The CI columns are named after the label passed to fit() (e.g.
    # '<group>_lower_0.95'), so hard-coded 'KM_estimate_*' names fail.
    # The lower band reaches 0.5 first, so the smaller crossing time is the
    # lower bound; min/max avoids depending on column names or order.
    crossings = median_survival_times(kmf.confidence_interval_)
    values = crossings.iloc[0].to_numpy(dtype=float)
    lower_ci, upper_ci = float(np.min(values)), float(np.max(values))

    return median, lower_ci, upper_ci
|
|
|
|
|
|
def generate_kaplan_meier_plot(data, time_col='time', event_col='event',
                               group_col='group', output_path='survival_curve.pdf',
                               title='Kaplan-Meier Survival Curve',
                               xlabel='Time (months)', ylabel='Survival Probability'):
    """
    Generate Kaplan-Meier survival curves comparing groups and save them.

    One curve (with its confidence band) is drawn per unique value of
    ``group_col``. The plot is annotated with a log-rank p-value (two-sample
    test for 2 groups, multivariate test for more than 2, omitted for a single
    group) and with each group's median survival and its 95% CI. The figure is
    written to ``output_path`` and a PNG copy is written next to it.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for stratification
        output_path: Path to save figure
        title: Plot title
        xlabel: X-axis label (specify units)
        ylabel: Y-axis label

    Returns:
        (kmf_models, p_value): dict mapping each group to its fitted
        KaplanMeierFitter, and the log-rank p-value (NaN with fewer than two
        groups).
    """
    groups = data[group_col].unique()

    # Colorblind-friendly palette; cycles when there are more groups than colors
    colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161']
    box_style = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    kmf_models = {}
    median_survivals = {}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for i, group in enumerate(groups):
            group_data = data[data[group_col] == group]

            kmf = KaplanMeierFitter()
            kmf.fit(group_data[time_col], group_data[event_col], label=str(group))
            kmf.plot_survival_function(ax=ax, ci_show=True, color=colors[i % len(colors)],
                                       linewidth=2, alpha=0.8)

            kmf_models[group] = kmf
            median_survivals[group] = calculate_median_survival(kmf)

        # Log-rank test; a single group has nothing to compare against.
        p_value = np.nan
        if len(groups) == 2:
            first = data[data[group_col] == groups[0]]
            second = data[data[group_col] == groups[1]]
            results = logrank_test(first[time_col], second[time_col],
                                   first[event_col], second[event_col])
            p_value = results.p_value
            ax.text(0.02, 0.15, f'Log-rank test:\n{_kmplot_format_p(p_value)}',
                    transform=ax.transAxes, fontsize=10,
                    verticalalignment='top', bbox=box_style)
        elif len(groups) > 2:
            results = multivariate_logrank_test(data[time_col], data[group_col], data[event_col])
            p_value = results.p_value
            ax.text(0.02, 0.15,
                    f'Log-rank test:\n{_kmplot_format_p(p_value)}\n({len(groups)} groups)',
                    transform=ax.transAxes, fontsize=10,
                    verticalalignment='top', bbox=box_style)

        # Median survival annotations, stacked down from the top-right corner.
        # NOTE(review): the unit "months" is fixed here regardless of xlabel.
        y_pos = 0.95
        for group, (median, lower, upper) in median_survivals.items():
            ax.text(0.98, y_pos, _kmplot_format_median(group, median, lower, upper),
                    transform=ax.transAxes, fontsize=9, ha='right',
                    verticalalignment='top')
            y_pos -= 0.05

        # Formatting
        ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
        ax.legend(loc='lower left', frameon=True, fontsize=10)
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_ylim([0, 1.05])
        fig.tight_layout()

        # Save the requested format, plus a PNG for easy viewing
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Survival curve saved to: {output_path}")

        png_path = Path(output_path).with_suffix('.png')
        fig.savefig(png_path, dpi=300, bbox_inches='tight')
        print(f"PNG version saved to: {png_path}")
    finally:
        # Release the figure even when fitting or saving fails.
        plt.close(fig)

    return kmf_models, p_value


def _kmplot_format_p(p_value):
    """Format a p-value for the plot, avoiding a misleading 'p = 0.0000'."""
    if p_value < 0.0001:
        return 'p < 0.0001'
    return f'p = {p_value:.4f}'


def _kmplot_format_median(group, median, lower, upper):
    """Build the median-survival annotation text for one group."""
    if median is None:
        return f'{group}: Not reached'

    def bound(value):
        # An infinite/undefined bound means that CI band never reaches 0.5.
        if value is None or not np.isfinite(value):
            return 'NR'
        return f'{value:.1f}'

    return f'{group}: {median:.1f} months (95% CI {bound(lower)}-{bound(upper)})'
|
|
|
|
|
|
def generate_number_at_risk_table(data, time_col='time', event_col='event',
                                  group_col='group', time_points=None, interval=6):
    """
    Generate a number-at-risk table for survival analysis.

    A subject counts as at risk at time t when their recorded time is >= t,
    i.e. they had neither the event nor censoring before t. Rows with a
    missing time are never counted.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Unused; kept for signature parity with the other helpers
        group_col: Column name for stratification
        time_points: Time points for the table (if None, auto-generated as
            0, interval, 2*interval, ... up to the maximum observed time)
        interval: Spacing of auto-generated time points (default 6, i.e.
            every 6 months when time is in months)

    Returns:
        DataFrame indexed by time point with one integer column per group

    Raises:
        ValueError: if time points must be auto-generated and interval <= 0
    """
    if time_points is None:
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval}")
        max_time = data[time_col].max()
        if pd.isna(max_time):
            # Empty (or all-missing) data: only the baseline row is meaningful.
            time_points = np.array([0])
        else:
            time_points = np.arange(0, max_time + interval, interval)

    counts = {}
    for group in data[group_col].unique():
        times = np.sort(data.loc[data[group_col] == group, time_col].dropna().to_numpy())
        # #(times >= t) == n - #(times < t); searchsorted(side='left') gives the latter.
        counts[group] = len(times) - np.searchsorted(times, time_points, side='left')

    return pd.DataFrame(counts, index=time_points)
|
|
|
|
|
|
def calculate_hazard_ratio(data, time_col='time', event_col='event', group_col='group',
                           reference_group=None):
    """
    Calculate a two-group hazard ratio using Cox proportional hazards regression.

    The model has one binary covariate: 1 for the comparison group, 0 for the
    reference group. If more than two groups are present, only the reference
    group and the first other group (in order of appearance) are used. Rows
    from every remaining group are excluded, not pooled into the comparison arm.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for the grouping variable
        reference_group: Reference group for comparison (if None, uses the
            first group in order of appearance)

    Returns:
        (hr, ci_lower, ci_upper, p_value): hazard ratio of the comparison
        group vs. the reference group, its 95% CI bounds, and the p-value
        from the fitted model's summary ('p' column)

    Raises:
        ValueError: if fewer than two groups are present, or if
            reference_group is not one of the groups
    """
    groups = list(data[group_col].dropna().unique())

    if reference_group is None:
        if not groups:
            raise ValueError(f"No groups found in column '{group_col}'")
        reference_group = groups[0]
    elif reference_group not in groups:
        raise ValueError(f"Reference group {reference_group!r} not found in column '{group_col}'")

    others = [g for g in groups if g != reference_group]
    if not others:
        raise ValueError("Cox HR calculation requires at least two groups")
    comparison_group = others[0]

    if len(groups) > 2:
        print(f"Warning: Cox HR calculation assumes 2 groups. Using {reference_group!r} "
              f"(reference) vs {comparison_group!r}; other groups are excluded.")

    # Restrict to the two compared groups so extra groups don't leak into the contrast.
    subset = data[data[group_col].isin([reference_group, comparison_group])]
    data_cox = subset[[time_col, event_col]].copy()
    data_cox['group_binary'] = (subset[group_col] == comparison_group).astype(int)

    # Fit Cox model
    cph = CoxPHFitter()
    cph.fit(data_cox, duration_col=time_col, event_col=event_col)

    # Coefficients are log-hazards; exponentiate to get the HR and its CI.
    hr = np.exp(cph.params_['group_binary'])
    ci = np.exp(cph.confidence_intervals_.loc['group_binary'].values)
    p_value = cph.summary.loc['group_binary', 'p']

    return hr, ci[0], ci[1], p_value
|
|
|
|
|
|
def generate_report(data, output_dir, prefix='survival', time_col='time',
                    event_col='event', group_col='group',
                    title='Survival Analysis by Group', xlabel='Time (months)',
                    ylabel='Survival Probability'):
    """
    Generate a comprehensive survival analysis report.

    Creates, inside output_dir (created if needed):
        - {prefix}_kaplan_meier.pdf/.png: Kaplan-Meier curves
        - {prefix}_number_at_risk.csv: Number at risk table
        - {prefix}_statistics.txt: Statistical summary
        - {prefix}_latex_table.tex: LaTeX table code (booktabs rules and the
          [H] float specifier, so it needs the booktabs and float packages)

    The hazard ratio is computed and reported only when there are exactly two
    groups. It compares the second group (order of appearance) against the first.

    Parameters:
        data: DataFrame with survival data
        output_dir: Directory for the output files
        prefix: Filename prefix for every output file
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for the grouping variable
        title: Plot title
        xlabel: Plot x-axis label
        ylabel: Plot y-axis label
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Survival curves (PDF + PNG) and the log-rank p-value
    kmf_models, logrank_p = generate_kaplan_meier_plot(
        data,
        time_col=time_col,
        event_col=event_col,
        group_col=group_col,
        output_path=output_dir / f'{prefix}_kaplan_meier.pdf',
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
    )

    risk_table = generate_number_at_risk_table(
        data, time_col=time_col, event_col=event_col, group_col=group_col)
    risk_table.to_csv(output_dir / f'{prefix}_number_at_risk.csv')

    groups = data[group_col].unique()
    summaries = [
        _report_group_summary(data[data[group_col] == group], kmf_models[group],
                              time_col, event_col)
        for group in groups
    ]

    # An HR is a two-arm contrast; with any other group count it is omitted
    # rather than reported for an arbitrary or pooled pair of groups.
    hr_result = None
    if len(groups) == 2:
        hr_result = calculate_hazard_ratio(
            data, time_col=time_col, event_col=event_col, group_col=group_col,
            reference_group=groups[0])

    _report_write_statistics(output_dir / f'{prefix}_statistics.txt',
                             groups, summaries, logrank_p, hr_result)
    _report_write_latex(output_dir / f'{prefix}_latex_table.tex',
                        groups, summaries, logrank_p, hr_result)

    print(f"\nAnalysis complete! Files saved to {output_dir}/")
    print(f" - Survival curves: {prefix}_kaplan_meier.pdf/png")
    print(f" - Statistics: {prefix}_statistics.txt")
    print(f" - LaTeX table: {prefix}_latex_table.tex")
    print(f" - Risk table: {prefix}_number_at_risk.csv")


def _report_is_finite(value):
    """Return True when value is a finite number (not None/NaN/inf)."""
    return value is not None and bool(np.isfinite(value))


def _report_p_value(p_value, digits, latex=False):
    """Format a p-value to `digits` places, as '<0.001'-style when it rounds to zero; 'NA' if missing."""
    if not _report_is_finite(p_value):
        return 'NA'
    floor = 10.0 ** -digits
    if p_value < floor:
        # A bare '<' in LaTeX text mode can render as a different glyph.
        return ('$<$' if latex else '<') + f'{floor:.{digits}f}'
    return f'{p_value:.{digits}f}'


def _report_latex_escape(text):
    """Escape LaTeX special characters in a label (e.g. a group name)."""
    replacements = {
        '\\': r'\textbackslash{}', '&': r'\&', '%': r'\%', '$': r'\$',
        '#': r'\#', '_': r'\_', '{': r'\{', '}': r'\}',
        '~': r'\textasciitilde{}', '^': r'\textasciicircum{}',
    }
    # Character-wise mapping so replacements are never escaped twice.
    return ''.join(replacements.get(ch, ch) for ch in str(text))


def _report_survival_at(kmf, t, follow_up):
    """KM survival at time t, or None when t lies beyond the group's follow-up."""
    # Past the last observation the KM curve is flat, so a "rate" there is an
    # extrapolation rather than an estimate.
    if not _report_is_finite(follow_up) or t > follow_up:
        return None
    try:
        return float(kmf.survival_function_at_times(t).values[0])
    except (KeyError, IndexError, ValueError, TypeError):
        return None


def _report_group_summary(group_data, kmf, time_col, event_col):
    """Collect N, event count, median and 12/24-month survival for one group."""
    follow_up = group_data[time_col].max()
    return {
        'n': len(group_data),
        'events': int(group_data[event_col].sum()),
        'median': kmf.median_survival_time_,
        'surv_12m': _report_survival_at(kmf, 12, follow_up),
        'surv_24m': _report_survival_at(kmf, 24, follow_up),
    }


def _report_write_statistics(path, groups, summaries, logrank_p, hr_result):
    """Write the plain-text statistical summary."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write("SURVIVAL ANALYSIS STATISTICAL SUMMARY\n")
        f.write("=" * 60 + "\n\n")

        for group, summary in zip(groups, summaries):
            f.write(f"Group: {group}\n")
            f.write(f" N = {summary['n']}\n")
            f.write(f" Events = {summary['events']}\n")
            if np.isinf(summary['median']):
                f.write(" Median survival: Not reached\n")
            else:
                f.write(f" Median survival: {summary['median']:.1f} months\n")
            if summary['surv_12m'] is not None:
                f.write(f" 12-month survival rate: {summary['surv_12m'] * 100:.1f}%\n")
            if summary['surv_24m'] is not None:
                f.write(f" 24-month survival rate: {summary['surv_24m'] * 100:.1f}%\n")
            f.write("\n")

        f.write("Log-Rank Test:\n")
        f.write(f" p-value = {_report_p_value(logrank_p, 4)}\n")
        if _report_is_finite(logrank_p):
            verdict = 'Significant' if logrank_p < 0.05 else 'Not significant'
            f.write(f" Interpretation: {verdict} difference in survival\n\n")
        else:
            f.write(" Interpretation: not available\n\n")

        if hr_result is not None:
            hr, ci_lower, ci_upper, hr_p = hr_result
            direction = 'reduction' if hr < 1 else 'increase'
            f.write(f"Hazard Ratio ({groups[1]} vs {groups[0]}):\n")
            f.write(f" HR = {hr:.2f} (95% CI {ci_lower:.2f}-{ci_upper:.2f})\n")
            f.write(f" p-value = {_report_p_value(hr_p, 4)}\n")
            # Magnitude only: an HR of 1.25 is a "25% increase", not "-25% increase".
            f.write(f" Interpretation: {groups[1]} has {abs(1 - hr) * 100:.0f}% "
                    f"{direction} in risk\n")


def _report_write_latex(path, groups, summaries, logrank_p, hr_result):
    """Write a booktabs LaTeX table: one column per group, then HR and p-value."""
    if hr_result is not None:
        hr, ci_lower, ci_upper, hr_p = hr_result
        hr_cell = f"{hr:.2f} ({ci_lower:.2f}-{ci_upper:.2f})"
        hr_p_cell = _report_p_value(hr_p, 3, latex=True)
    else:
        hr_cell = hr_p_cell = "--"

    header = (["\\textbf{Endpoint}"]
              + [f"\\textbf{{{_report_latex_escape(group)}}}" for group in groups]
              + ["\\textbf{HR (95\\% CI)}", "\\textbf{p-value}"])
    median_row = (["Median survival, months"]
                  + ["NR" if np.isinf(s['median']) else f"{s['median']:.1f}"
                     for s in summaries]
                  + [hr_cell, hr_p_cell])
    rate_row = (["12-month survival rate (\\%)"]
                + ["--" if s['surv_12m'] is None else f"{s['surv_12m'] * 100:.0f}\\%"
                   for s in summaries]
                + ["--", "--"])

    with open(path, 'w', encoding='utf-8') as f:
        f.write("% LaTeX table code for survival outcomes\n")
        f.write("% Requires \\usepackage{booktabs} and \\usepackage{float} (for [H])\n")
        f.write("\\begin{table}[H]\n")
        f.write("\\centering\n")
        f.write("\\small\n")
        f.write("\\begin{tabular}{l" + "c" * (len(groups) + 2) + "}\n")
        f.write("\\toprule\n")
        f.write(" & ".join(header) + " \\\\\n")
        f.write("\\midrule\n")
        f.write(" & ".join(median_row) + " \\\\\n")
        f.write(" & ".join(rate_row) + " \\\\\n")
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")
        f.write(f"\\caption{{Survival outcomes by group "
                f"(log-rank p={_report_p_value(logrank_p, 3, latex=True)})}}\n")
        f.write("\\end{table}\n")


def main():
    """Command-line entry point: load a survival CSV and write the full report."""
    parser = argparse.ArgumentParser(description='Generate Kaplan-Meier survival curves')
    parser.add_argument('input_file', type=str, help='CSV file with survival data')
    parser.add_argument('-o', '--output', type=str, default='survival_output',
                        help='Output directory (default: survival_output)')
    parser.add_argument('-t', '--title', type=str, default='Kaplan-Meier Survival Curve',
                        help='Plot title')
    parser.add_argument('-x', '--xlabel', type=str, default='Time (months)',
                        help='X-axis label')
    parser.add_argument('-y', '--ylabel', type=str, default='Survival Probability',
                        help='Y-axis label')
    parser.add_argument('--time-col', type=str, default='time',
                        help='Column name for time variable')
    parser.add_argument('--event-col', type=str, default='event',
                        help='Column name for event indicator')
    parser.add_argument('--group-col', type=str, default='group',
                        help='Column name for grouping variable')

    args = parser.parse_args()

    # Load data
    # NOTE(review): load_survival_data is called without column names, so it
    # still validates the default 'time'/'event'/'group' columns; custom
    # --*-col names are honoured by the analysis below.
    print(f"Loading data from {args.input_file}...")
    data = load_survival_data(args.input_file)
    print(f"Loaded {len(data)} patients")
    print(f"Groups: {data[args.group_col].value_counts().to_dict()}")

    # Generate analysis, forwarding every CLI option that affects the output
    generate_report(
        data,
        output_dir=args.output,
        prefix='survival',
        time_col=args.time_col,
        event_col=args.event_col,
        group_col=args.group_col,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # never on import.
    main()


# Example usage:
# python generate_survival_analysis.py survival_data.csv -o figures/ -t "PFS by PD-L1 Status"
#
# Input CSV format:
# patient_id,time,event,group
# PT001,12.3,1,PD-L1+
# PT002,8.5,1,PD-L1-
# PT003,18.2,0,PD-L1+
# ...
|
|
|