Ensemble Models for Survival Analysis
Random Survival Forests
Overview
Random Survival Forests extend the random forest algorithm to survival analysis with censored data. They build multiple decision trees on bootstrap samples and aggregate predictions.
How They Work
- Bootstrap Sampling: Each tree is built on a different bootstrap sample of the training data
- Feature Randomness: At each node, only a random subset of features is considered for splitting
- Survival Function Estimation: At each terminal node, the Kaplan-Meier estimator gives the survival function and the Nelson-Aalen estimator gives the cumulative hazard function (a minimal sketch follows this list)
- Ensemble Aggregation: Final predictions average survival functions across all trees
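As a rough sketch of what happens inside a single terminal node, the snippet below applies the two nonparametric estimators from sksurv.nonparametric to a handful of made-up event indicators and follow-up times; a fitted forest computes these per node and averages the resulting curves across trees.
import numpy as np
from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator

# Toy stand-in for the samples that end up in one terminal node
event = np.array([True, False, True, True, False, True])  # True = event observed
time = np.array([5.0, 8.0, 12.0, 15.0, 20.0, 22.0])       # follow-up times

# Kaplan-Meier estimate of the survival function S(t) for this node
km_times, km_surv = kaplan_meier_estimator(event, time)

# Nelson-Aalen estimate of the cumulative hazard H(t) for this node
na_times, na_chf = nelson_aalen_estimator(event, time)

print(list(zip(km_times, km_surv)))
print(list(zip(na_times, na_chf)))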
When to Use
- Complex non-linear relationships between features and survival
- No assumptions about functional form needed
- Want robust predictions with minimal tuning
- Need feature importance estimates
- Have sufficient sample size (typically n > 100)
Key Parameters
- n_estimators: Number of trees (default: 100)
  - More trees = more stable predictions but slower training
  - Typical range: 100-1000
- max_depth: Maximum depth of trees
  - Controls tree complexity
  - None = nodes are expanded until pure or below min_samples_split
- min_samples_split: Minimum samples to split a node (default: 6)
  - Larger values = more regularization
- min_samples_leaf: Minimum samples at leaf nodes (default: 3)
  - Prevents overfitting to small groups
- max_features: Number of features to consider at each split
  - 'sqrt': sqrt(n_features) - good default
  - 'log2': log2(n_features)
  - None: all features
- n_jobs: Number of parallel jobs (-1 uses all processors)
Example Usage
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import load_breast_cancer
from sksurv.preprocessing import OneHotEncoder

# Load data; X is a DataFrame that contains categorical columns,
# so one-hot encode them before fitting
X, y = load_breast_cancer()
X = OneHotEncoder().fit_transform(X)

# Fit Random Survival Forest
rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           max_features="sqrt",
                           n_jobs=-1,
                           random_state=42)
rsf.fit(X, y)

# Predict risk scores (higher = higher risk)
risk_scores = rsf.predict(X)

# Predict survival functions (one step function per sample)
surv_funcs = rsf.predict_survival_function(X)

# Predict cumulative hazard functions
chf_funcs = rsf.predict_cumulative_hazard_function(X)
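The objects returned by predict_survival_function and predict_cumulative_hazard_function are step functions that can be evaluated at arbitrary time points. A small sketch of inspecting them with matplotlib (assuming the step function's x attribute holds its time grid):
import matplotlib.pyplot as plt

# Each entry of surv_funcs is a callable step function; plot the first five samples
for fn in surv_funcs[:5]:
    plt.step(fn.x, fn(fn.x), where="post")
plt.xlabel("time")
plt.ylabel("predicted survival probability")
plt.ylim(0, 1)
plt.show()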
Feature Importance
Important: Built-in feature importance based on split impurity is not reliable for survival data. Use permutation-based feature importance instead.
from sklearn.inspection import permutation_importance
from sksurv.metrics import concordance_index_censored

# Define a scoring function based on Harrell's concordance index;
# look up the structured array's field names rather than hard-coding them
def score_survival_model(model, X, y):
    prediction = model.predict(X)
    event_field, time_field = y.dtype.names
    result = concordance_index_censored(y[event_field], y[time_field], prediction)
    return result[0]

# Compute permutation importance
perm_importance = permutation_importance(
    rsf, X, y,
    n_repeats=10,
    random_state=42,
    scoring=score_survival_model
)

# Mean importance of each feature across permutation repeats
feature_importance = perm_importance.importances_mean
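To make the result easier to read, the mean importances can be paired with the column names (a sketch assuming X is still a pandas DataFrame):
import pandas as pd

# Rank features by mean permutation importance
importance = pd.Series(perm_importance.importances_mean, index=X.columns)
print(importance.sort_values(ascending=False).head(10))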
Gradient Boosting Survival Analysis
Overview
Gradient boosting builds an ensemble by sequentially adding weak learners, each one fitted to correct the errors of the learners added before it. The model is an additive expansion f(x) = Σ_{m=1}^{M} β_m g(x; θ_m), where g(·; θ_m) is the m-th base learner and β_m its weight.
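As a toy illustration of this additive form only (the base learners and weights below are hypothetical, not produced by any fitting routine):
# Hypothetical base learners g_m(x): simple threshold rules on two features
base_learners = [
    lambda x: 1.0 if x[0] > 2.0 else 0.0,
    lambda x: 1.0 if x[1] < 0.5 else 0.0,
    lambda x: 1.0 if x[0] + x[1] > 1.0 else 0.0,
]
betas = [0.4, 0.3, 0.2]  # hypothetical weights beta_m

def ensemble_score(x):
    # f(x) = sum over m of beta_m * g_m(x)
    return sum(beta * g(x) for beta, g in zip(betas, base_learners))

print(ensemble_score([3.0, 0.2]))  # all three rules fire, so the score is 0.9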
Model Types
GradientBoostingSurvivalAnalysis
Uses regression trees as base learners. Can capture complex non-linear relationships.
When to Use:
- Need to model complex non-linear relationships
- Want high predictive performance
- Have sufficient data to avoid overfitting
- Can tune hyperparameters carefully
ComponentwiseGradientBoostingSurvivalAnalysis
Uses component-wise least squares as base learners. Produces linear models with automatic feature selection.
When to Use:
- Want interpretable linear model
- Need automatic feature selection (like Lasso)
- Have high-dimensional data
- Prefer sparse models
Loss Functions
Cox's Partial Likelihood (default)
Maintains the proportional hazards framework but replaces the linear predictor of the Cox model with an additive ensemble.
Appropriate for:
- Standard survival analysis settings
- When proportional hazards is reasonable
- Most use cases
Accelerated Failure Time (AFT)
Assumes features accelerate or decelerate survival time by a constant factor. The loss is a censoring-weighted least squares on the log time scale: (1/n) Σ_i ω_i (log y_i − f(x_i))², where the weights ω_i are inverse-probability-of-censoring weights (hence the 'ipcwls' loss name); a minimal fitting sketch follows the list below.
Appropriate for:
- AFT framework preferred over proportional hazards
- Want to model time directly
- Need to interpret effects on survival time
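A minimal sketch of fitting with this loss (reusing the X_train/y_train split introduced in the examples below):
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# 'ipcwls' = inverse probability of censoring weighted least squares (the AFT loss);
# the ensemble then models log survival time instead of a Cox risk score
gbs_aft = GradientBoostingSurvivalAnalysis(loss="ipcwls",
                                           learning_rate=0.05,
                                           n_estimators=200,
                                           random_state=42)
gbs_aft.fit(X_train, y_train)
time_scale_pred = gbs_aft.predict(X_test)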
Regularization Strategies
Three main techniques prevent overfitting:
- Learning Rate (learning_rate < 1)
  - Shrinks the contribution of each base learner
  - Smaller values need more iterations but generalize better
  - Typical range: 0.01 - 0.1
- Dropout (dropout_rate > 0)
  - Randomly drops previous learners during training
  - Forces learners to be more robust
  - Typical range: 0.01 - 0.2
- Subsampling (subsample < 1)
  - Uses a random subset of the data for each iteration
  - Adds randomness and reduces overfitting
  - Typical range: 0.5 - 0.9
Recommendation: Combine small learning rate with early stopping for best performance.
Key Parameters
- loss: Loss function ('coxph' or 'ipcwls')
- learning_rate: Shrinks contribution of each tree (default: 0.1)
- n_estimators: Number of boosting iterations (default: 100)
- subsample: Fraction of samples for each iteration (default: 1.0)
- dropout_rate: Dropout rate for learners (default: 0.0)
- max_depth: Maximum depth of trees (default: 3)
- min_samples_split: Minimum samples to split a node (default: 2)
- min_samples_leaf: Minimum samples at a leaf (default: 1)
- max_features: Features to consider at each split
Example Usage
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sklearn.model_selection import train_test_split

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit gradient boosting model
gbs = GradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.05,
    n_estimators=200,
    subsample=0.8,
    dropout_rate=0.1,
    max_depth=3,
    random_state=42
)
gbs.fit(X_train, y_train)

# Predict risk scores
risk_scores = gbs.predict(X_test)

# Predict survival functions
surv_funcs = gbs.predict_survival_function(X_test)

# Predict cumulative hazard functions
chf_funcs = gbs.predict_cumulative_hazard_function(X_test)
Early Stopping
Hold out part of the training data and stop boosting once the validation loss stops improving. With validation_fraction and n_iter_no_change set, the estimator reserves the validation split internally, so no manual split is needed:
# Fit with early stopping: 20% of the training data is held out internally,
# and boosting stops after 10 rounds without improvement
gbs = GradientBoostingSurvivalAnalysis(
    n_estimators=1000,
    learning_rate=0.01,
    max_depth=3,
    validation_fraction=0.2,
    n_iter_no_change=10,
    random_state=42
)
gbs.fit(X_train, y_train)

# Number of boosting iterations actually used
print(f"Used {gbs.n_estimators_} iterations")
Hyperparameter Tuning
from sklearn.model_selection import GridSearchCV

param_grid = {
    'learning_rate': [0.01, 0.05, 0.1],
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'subsample': [0.8, 1.0]
}

# scikit-survival estimators use Harrell's concordance index as their default
# score, so GridSearchCV needs no custom scoring argument here
cv = GridSearchCV(
    GradientBoostingSurvivalAnalysis(),
    param_grid,
    cv=5,
    n_jobs=-1
)
cv.fit(X_train, y_train)
best_model = cv.best_estimator_
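If the IPCW concordance index is the metric of interest, it can still be computed on the held-out test set after tuning (a sketch using concordance_index_ipcw; the training outcomes are needed to estimate the censoring distribution):
from sksurv.metrics import concordance_index_ipcw

# Score the tuned model on the held-out test set
risk = best_model.predict(X_test)
cindex_ipcw = concordance_index_ipcw(y_train, y_test, risk)[0]
print(f"IPCW concordance index: {cindex_ipcw:.3f}")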
ComponentwiseGradientBoostingSurvivalAnalysis
Overview
Uses component-wise least squares, producing sparse linear models with automatic feature selection similar to Lasso.
When to Use
- Want interpretable linear model
- Need automatic feature selection
- Have high-dimensional data with many irrelevant features
- Prefer coefficient-based interpretation
Example Usage
from sksurv.ensemble import ComponentwiseGradientBoostingSurvivalAnalysis

# Fit componentwise boosting
cgbs = ComponentwiseGradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.1,
    n_estimators=100
)
cgbs.fit(X, y)

# Coefficients are sparse: features never selected keep a coefficient of exactly zero
coef = cgbs.coef_
selected_features = [i for i, c in enumerate(coef) if c != 0]
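To read the coefficients by name, they can be lined up with the column labels; this sketch assumes the first entry of coef_ is the intercept, as in the library's boosting examples, and that X is a pandas DataFrame:
import pandas as pd

# Map coefficients to names (first entry assumed to be the intercept)
coef_named = pd.Series(cgbs.coef_, index=["Intercept"] + list(X.columns))
nonzero = coef_named[coef_named != 0]
print(nonzero.sort_values(key=abs, ascending=False))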
ExtraSurvivalTrees
Extremely randomized survival trees - similar to Random Survival Forest but with additional randomness in split selection.
When to Use
- Want even more regularization than Random Survival Forest
- Have limited data
- Need faster training
Key Difference
Instead of searching for the best threshold for each candidate feature, a threshold is drawn at random for each one and the best of these random splits is kept, which adds more diversity to the ensemble.
from sksurv.ensemble import ExtraSurvivalTrees

# Same interface as RandomSurvivalForest, but split thresholds are drawn at random
est = ExtraSurvivalTrees(n_estimators=100, random_state=42)
est.fit(X, y)
Model Comparison
| Model | Complexity | Interpretability | Performance | Speed |
|---|---|---|---|---|
| Random Survival Forest | Medium | Low | High | Medium |
| GradientBoostingSurvivalAnalysis | High | Low | Highest | Slow |
| ComponentwiseGradientBoostingSurvivalAnalysis | Low | High | Medium | Fast |
| ExtraSurvivalTrees | Medium | Low | Medium-High | Fast |
General Recommendations:
- Best overall performance: GradientBoostingSurvivalAnalysis with tuning
- Best balance: RandomSurvivalForest
- Best interpretability: ComponentwiseGradientBoostingSurvivalAnalysis
- Fastest training: ExtraSurvivalTrees