# Ensemble Models for Survival Analysis
## Random Survival Forests
### Overview
Random Survival Forests extend the random forest algorithm to survival analysis with censored data. They build multiple decision trees on bootstrap samples and aggregate predictions.
### How They Work
1. **Bootstrap Sampling**: Each tree is built on a different bootstrap sample of the training data
2. **Feature Randomness**: At each node, only a random subset of features is considered for splitting
3. **Survival Function Estimation**: At each terminal node, the Kaplan-Meier estimator gives the survival function and the Nelson-Aalen estimator gives the cumulative hazard function
4. **Ensemble Aggregation**: Final predictions average the survival and cumulative hazard functions across all trees (see the formula below)
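Concretely, the ensemble estimate is the average of the per-tree estimates: for a forest of B trees, the predicted cumulative hazard for a sample x is **Ĥ(t | x) = (1/B) Σ_b Ĥ_b(t | x)**, where Ĥ_b(t | x) is the Nelson-Aalen estimate of the terminal node that x falls into in tree b; the ensemble survival function is obtained analogously from the per-tree Kaplan-Meier estimates.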
### When to Use
- Complex non-linear relationships between features and survival
- No assumptions about functional form needed
- Want robust predictions with minimal tuning
- Need feature importance estimates
- Have sufficient sample size (typically n > 100)
### Key Parameters
- `n_estimators`: Number of trees (default: 100)
  - More trees = more stable predictions but slower
  - Typical range: 100-1000
- `max_depth`: Maximum depth of trees
  - Controls tree complexity
  - None: nodes are expanded until pure or until they contain fewer than `min_samples_split` samples
- `min_samples_split`: Minimum samples to split a node (default: 6)
  - Larger values = more regularization
- `min_samples_leaf`: Minimum samples at leaf nodes (default: 3)
  - Prevents overfitting to small groups
- `max_features`: Number of features to consider at each split
  - 'sqrt': sqrt(n_features) - good default
  - 'log2': log2(n_features)
  - None: all features
- `n_jobs`: Number of parallel jobs (-1 uses all processors)
### Example Usage
```python
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import load_breast_cancer
from sksurv.preprocessing import OneHotEncoder

# Load data; the breast cancer dataset contains categorical columns,
# so one-hot encode them before fitting the tree-based model
X, y = load_breast_cancer()
X = OneHotEncoder().fit_transform(X)

# Fit Random Survival Forest
rsf = RandomSurvivalForest(
    n_estimators=1000,
    min_samples_split=10,
    min_samples_leaf=15,
    max_features="sqrt",
    n_jobs=-1,
    random_state=42,
)
rsf.fit(X, y)
# Predict risk scores
risk_scores = rsf.predict(X)
# Predict survival functions
surv_funcs = rsf.predict_survival_function(X)
# Predict cumulative hazard functions
chf_funcs = rsf.predict_cumulative_hazard_function(X)
```
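The survival functions returned above are step functions; a minimal sketch of evaluating and plotting them (assuming `matplotlib` is installed; the sample selection and labels are illustrative):
```python
import matplotlib.pyplot as plt

# Each step function exposes its time grid (.x) and values (.y)
for i, fn in enumerate(surv_funcs[:5]):
    plt.step(fn.x, fn.y, where="post", label=f"sample {i}")

plt.xlabel("Time")
plt.ylabel("Predicted survival probability")
plt.legend()
plt.show()
```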
### Feature Importance
**Important**: Built-in feature importance based on split impurity is not reliable for survival data. Use permutation-based feature importance instead.
```python
from sklearn.inspection import permutation_importance
from sksurv.metrics import concordance_index_censored

# Field names of the structured outcome array
# (e.g. "e.tdm"/"t.tdm" for the breast cancer dataset)
event_field, time_field = y.dtype.names

# Define scoring function: Harrell's concordance index of the risk scores
def score_survival_model(model, X, y):
    prediction = model.predict(X)
    result = concordance_index_censored(y[event_field], y[time_field], prediction)
    return result[0]

# Compute permutation importance
perm_importance = permutation_importance(
    rsf, X, y,
    n_repeats=10,
    random_state=42,
    scoring=score_survival_model,
)

# Mean importance of each feature across repeats
feature_importance = perm_importance.importances_mean
```
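To inspect the result, the mean importances can be ranked against the column names; this sketch assumes `X` is a pandas DataFrame, as returned by the scikit-survival data loaders:
```python
import pandas as pd

# Rank features by mean permutation importance, with its standard deviation
importances = pd.DataFrame(
    {
        "importance_mean": perm_importance.importances_mean,
        "importance_std": perm_importance.importances_std,
    },
    index=X.columns,
).sort_values("importance_mean", ascending=False)

print(importances.head(10))
```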
## Gradient Boosting Survival Analysis
### Overview
Gradient boosting builds an ensemble by sequentially adding weak learners that correct errors of previous learners. The model is: **f(x) = Σ β_m g(x; θ_m)**
### Model Types
#### GradientBoostingSurvivalAnalysis
Uses regression trees as base learners. Can capture complex non-linear relationships.
**When to Use:**
- Need to model complex non-linear relationships
- Want high predictive performance
- Have sufficient data to avoid overfitting
- Can tune hyperparameters carefully
#### ComponentwiseGradientBoostingSurvivalAnalysis
Uses component-wise least squares as base learners. Produces linear models with automatic feature selection.
**When to Use:**
- Want interpretable linear model
- Need automatic feature selection (like Lasso)
- Have high-dimensional data
- Prefer sparse models
### Loss Functions
#### Cox's Partial Likelihood (default)
Maintains the proportional hazards framework but replaces the linear predictor with an additive ensemble model.
**Appropriate for:**
- Standard survival analysis settings
- When proportional hazards is reasonable
- Most use cases
#### Accelerated Failure Time (AFT)
Assumes features accelerate or decelerate survival time by a constant factor. Loss function: **(1/n) Σ ω_i (log y_i - f(x_i))²**
**Appropriate for:**
- AFT framework preferred over proportional hazards
- Want to model time directly
- Need to interpret effects on survival time (see the sketch below)
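A sketch of fitting with the AFT-style loss: only the `loss` argument changes (`'ipcwls'`, the IPCW-weighted least squares loss above); the other parameter values here are illustrative:
```python
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# 'ipcwls' models (log) survival time directly instead of using
# Cox's partial likelihood; censoring is handled by IPCW weights
aft_model = GradientBoostingSurvivalAnalysis(
    loss="ipcwls",
    learning_rate=0.05,
    n_estimators=200,
    max_depth=3,
    random_state=42,
)
aft_model.fit(X, y)
```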
### Regularization Strategies
Three main techniques prevent overfitting:
1. **Learning Rate** (`learning_rate < 1`)
   - Shrinks contribution of each base learner
   - Smaller values need more iterations but better generalization
   - Typical range: 0.01 - 0.1
2. **Dropout** (`dropout_rate > 0`)
   - Randomly drops previous learners during training
   - Forces learners to be more robust
   - Typical range: 0.01 - 0.2
3. **Subsampling** (`subsample < 1`)
   - Uses random subset of data for each iteration
   - Adds randomness and reduces overfitting
   - Typical range: 0.5 - 0.9
**Recommendation**: Combine small learning rate with early stopping for best performance.
### Key Parameters
- `loss`: Loss function ('coxph' or 'ipcwls')
- `learning_rate`: Shrinks contribution of each tree (default: 0.1)
- `n_estimators`: Number of boosting iterations (default: 100)
- `subsample`: Fraction of samples for each iteration (default: 1.0)
- `dropout_rate`: Dropout rate for learners (default: 0.0)
- `max_depth`: Maximum depth of trees (default: 3)
- `min_samples_split`: Minimum samples to split node (default: 2)
- `min_samples_leaf`: Minimum samples at leaf (default: 1)
- `max_features`: Features to consider at each split
### Example Usage
```python
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sklearn.model_selection import train_test_split
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Fit gradient boosting model
gbs = GradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.05,
    n_estimators=200,
    subsample=0.8,
    dropout_rate=0.1,
    max_depth=3,
    random_state=42
)
gbs.fit(X_train, y_train)
# Predict risk scores
risk_scores = gbs.predict(X_test)
# Predict survival functions
surv_funcs = gbs.predict_survival_function(X_test)
# Predict cumulative hazard functions
chf_funcs = gbs.predict_cumulative_hazard_function(X_test)
```
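To check discrimination on the held-out data, the predicted risk scores can be scored with Harrell's concordance index (field names are read from the structured outcome array, as they differ between datasets):
```python
from sksurv.metrics import concordance_index_censored

# Concordance index of the risk scores on the test set
event_field, time_field = y_test.dtype.names
c_index = concordance_index_censored(
    y_test[event_field], y_test[time_field], risk_scores
)[0]
print(f"Test concordance index: {c_index:.3f}")
```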
### Early Stopping
Use a validation set to prevent overfitting; `validation_fraction` splits one off internally:
```python
# validation_fraction holds out part of the training data internally;
# boosting stops once the validation loss has not improved for
# n_iter_no_change consecutive iterations
gbs = GradientBoostingSurvivalAnalysis(
    n_estimators=1000,
    learning_rate=0.01,
    max_depth=3,
    validation_fraction=0.2,
    n_iter_no_change=10,
    random_state=42
)
gbs.fit(X_train, y_train)

# Number of boosting iterations actually used
print(f"Used {gbs.n_estimators_} iterations")
```
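To see where the model stopped improving, the per-iteration training loss can be plotted; a minimal sketch, assuming the fitted model exposes the `train_score_` attribute (as in scikit-learn's gradient boosting) and that `matplotlib` is installed:
```python
import matplotlib.pyplot as plt

# train_score_[i] is the training loss after boosting iteration i
plt.plot(gbs.train_score_)
plt.xlabel("Boosting iteration")
plt.ylabel("Training loss")
plt.show()
```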
### Hyperparameter Tuning
```python
from sklearn.model_selection import GridSearchCV
param_grid = {
    'learning_rate': [0.01, 0.05, 0.1],
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'subsample': [0.8, 1.0]
}

# No scoring argument is needed: the default score method of
# scikit-survival estimators is Harrell's concordance index
cv = GridSearchCV(
    GradientBoostingSurvivalAnalysis(),
    param_grid,
    cv=5,
    n_jobs=-1
)
cv.fit(X, y)
best_model = cv.best_estimator_
```
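If an IPCW-based concordance index is preferred as the selection criterion, scikit-survival provides scorer wrappers such as `as_concordance_index_ipcw_scorer`; a sketch, noting that the wrapped estimator's parameters must be addressed with the `estimator__` prefix:
```python
from sksurv.metrics import as_concordance_index_ipcw_scorer

# Reuse the grid above, prefixing keys for the wrapped estimator
param_grid_ipcw = {f"estimator__{key}": values for key, values in param_grid.items()}

cv_ipcw = GridSearchCV(
    as_concordance_index_ipcw_scorer(GradientBoostingSurvivalAnalysis()),
    param_grid_ipcw,
    cv=5,
    n_jobs=-1,
)
cv_ipcw.fit(X, y)
```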
## ComponentwiseGradientBoostingSurvivalAnalysis
### Overview
Uses component-wise least squares, producing sparse linear models with automatic feature selection similar to Lasso.
### When to Use
- Want interpretable linear model
- Need automatic feature selection
- Have high-dimensional data with many irrelevant features
- Prefer coefficient-based interpretation
### Example Usage
```python
from sksurv.ensemble import ComponentwiseGradientBoostingSurvivalAnalysis
# Fit componentwise boosting
cgbs = ComponentwiseGradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.1,
    n_estimators=100
)
cgbs.fit(X, y)

# Get coefficients; depending on the scikit-survival version, coef_ may
# include an intercept term as its first entry, so check its length
# against the number of features before mapping indices to columns
coef = cgbs.coef_
selected_features = [i for i, c in enumerate(coef) if c != 0]
```
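Because each boosting iteration updates the coefficient of a single feature, the number of iterations acts like an inverse penalty: more iterations select more features. A small sketch of this behavior (the iteration counts are illustrative):
```python
import numpy as np

# Sparsity of the fitted model as a function of the number of iterations
for n in (25, 50, 100, 200):
    model = ComponentwiseGradientBoostingSurvivalAnalysis(
        loss="coxph", learning_rate=0.1, n_estimators=n
    ).fit(X, y)
    print(f"n_estimators={n}: {np.sum(model.coef_ != 0)} nonzero coefficients")
```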
## ExtraSurvivalTrees
Extremely randomized survival trees - similar to Random Survival Forest but with additional randomness in split selection.
### When to Use
- Want even more regularization than Random Survival Forest
- Have limited data
- Need faster training
### Key Difference
Instead of finding the best split for selected features, it randomly selects split points, adding more diversity to the ensemble.
```python
from sksurv.ensemble import ExtraSurvivalTrees
est = ExtraSurvivalTrees(n_estimators=100, random_state=42)
est.fit(X, y)
```
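Both estimators expose a `score` method that returns Harrell's concordance index, so a quick comparison against the Random Survival Forest fitted earlier could look like this (scored in-sample here; use a held-out split in practice):
```python
# .score returns Harrell's concordance index for scikit-survival estimators
print("Random Survival Forest:", rsf.score(X, y))
print("Extra Survival Trees:  ", est.score(X, y))
```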
## Model Comparison
| Model | Complexity | Interpretability | Performance | Speed |
|-------|-----------|------------------|-------------|-------|
| Random Survival Forest | Medium | Low | High | Medium |
| GradientBoostingSurvivalAnalysis | High | Low | Highest | Slow |
| ComponentwiseGradientBoostingSurvivalAnalysis | Low | High | Medium | Fast |
| ExtraSurvivalTrees | Medium | Low | Medium-High | Fast |
**General Recommendations:**
- **Best overall performance**: GradientBoostingSurvivalAnalysis with tuning
- **Best balance**: RandomSurvivalForest
- **Best interpretability**: ComponentwiseGradientBoostingSurvivalAnalysis
- **Fastest training**: ExtraSurvivalTrees