# Ensemble Models for Survival Analysis

## Random Survival Forests

### Overview

Random Survival Forests extend the random forest algorithm to survival analysis with censored data. They build multiple decision trees on bootstrap samples and aggregate their predictions.

### How They Work

1. **Bootstrap Sampling**: Each tree is built on a different bootstrap sample of the training data
2. **Feature Randomness**: At each node, only a random subset of features is considered for splitting
3. **Survival Function Estimation**: At each terminal node, Kaplan-Meier and Nelson-Aalen estimators compute the survival and cumulative hazard functions
4. **Ensemble Aggregation**: Final predictions average these functions across all trees (see the sketch after this list)
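
A minimal sketch of steps 3 and 4, assuming nothing about the library's internals: a Kaplan-Meier estimate from the samples in one terminal node, with the forest prediction obtained by averaging per-tree curves.

```python
import numpy as np

def kaplan_meier(time, event):
    """Kaplan-Meier survival curve from the samples that fall into one terminal node."""
    event = event.astype(bool)
    event_times = np.unique(time[event])        # times with at least one observed event
    surv = np.empty(len(event_times), dtype=float)
    s = 1.0
    for i, t in enumerate(event_times):
        at_risk = np.sum(time >= t)             # samples still under observation at t
        deaths = np.sum((time == t) & event)
        s *= 1.0 - deaths / at_risk
        surv[i] = s
    return event_times, surv

# Ensemble aggregation (conceptually): evaluate every tree's node-level curve on a
# shared time grid and average them to obtain the forest's survival function.
```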

### When to Use

- Complex non-linear relationships between features and survival
- No assumptions about the functional form are needed
- Want robust predictions with minimal tuning
- Need feature importance estimates
- Have a sufficient sample size (typically n > 100)

### Key Parameters

- `n_estimators`: Number of trees (default: 100)
  - More trees = more stable predictions but slower
  - Typical range: 100-1000

- `max_depth`: Maximum depth of trees
  - Controls tree complexity
  - None = nodes are expanded until pure or until `min_samples_split` is reached

- `min_samples_split`: Minimum samples to split a node (default: 6)
  - Larger values = more regularization

- `min_samples_leaf`: Minimum samples at leaf nodes (default: 3)
  - Prevents overfitting to small groups

- `max_features`: Number of features to consider at each split
  - 'sqrt': sqrt(n_features) - a good default
  - 'log2': log2(n_features)
  - None: all features

- `n_jobs`: Number of parallel jobs (-1 uses all processors)

### Example Usage

```python
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import load_breast_cancer
from sksurv.preprocessing import OneHotEncoder

# Load data and encode the categorical columns as numeric dummy variables
X, y = load_breast_cancer()
X = OneHotEncoder().fit_transform(X)

# Fit Random Survival Forest
rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           max_features="sqrt",
                           n_jobs=-1,
                           random_state=42)
rsf.fit(X, y)

# Predict risk scores (higher = higher risk)
risk_scores = rsf.predict(X)

# Predict survival functions (one step function per sample)
surv_funcs = rsf.predict_survival_function(X)

# Predict cumulative hazard functions
chf_funcs = rsf.predict_cumulative_hazard_function(X)
```
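
The returned survival and cumulative hazard functions are step functions over the observed event times. A short sketch of evaluating and plotting a few of them (assuming matplotlib is installed):

```python
import matplotlib.pyplot as plt

# Plot predicted survival curves for the first five samples;
# fn.x holds the time grid and fn(t) evaluates the step function at t
for fn in surv_funcs[:5]:
    plt.step(fn.x, fn(fn.x), where="post")
plt.xlabel("Time")
plt.ylabel("Survival probability S(t)")
plt.ylim(0, 1)
plt.show()
```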

### Feature Importance

**Important**: The impurity-based feature importances built into tree ensembles are not reliable for survival data. Use permutation-based feature importance instead.

```python
from sklearn.inspection import permutation_importance
from sksurv.metrics import concordance_index_censored

# Scoring function: Harrell's concordance index.
# The structured array y stores the event indicator and time under
# dataset-specific field names, so look them up instead of hard-coding them.
def score_survival_model(model, X, y):
    prediction = model.predict(X)
    event_field, time_field = y.dtype.names
    result = concordance_index_censored(y[event_field], y[time_field], prediction)
    return result[0]

# Compute permutation importance
perm_importance = permutation_importance(
    rsf, X, y,
    n_repeats=10,
    random_state=42,
    scoring=score_survival_model
)

# Mean importance of each feature across the repeats
feature_importance = perm_importance.importances_mean
```
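
To rank features, pair the mean importances with the column names (this assumes `X` is a pandas DataFrame, as returned by the sksurv loaders):

```python
import pandas as pd

# Sort features by mean permutation importance, most important first
importances = pd.Series(perm_importance.importances_mean, index=X.columns)
print(importances.sort_values(ascending=False).head(10))
```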

## Gradient Boosting Survival Analysis

### Overview

Gradient boosting builds an ensemble by sequentially adding weak learners, each fitted to correct the errors of the current ensemble. The model is the additive expansion **f(x) = Σ_m β_m g(x; θ_m)**, where each g(·; θ_m) is a base learner and β_m its weight.
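
A minimal, generic sketch of this stagewise loop (illustrative only; `loss_gradient` and `fit_base_learner` are hypothetical placeholders, not sksurv internals):

```python
import numpy as np

def boost(X, y, loss_gradient, fit_base_learner, n_estimators, learning_rate):
    """Generic gradient boosting: each base learner is fit to the negative
    gradient (pseudo-residuals) of the loss at the current additive prediction."""
    learners = []
    F = np.zeros(X.shape[0])                    # current prediction f(x)
    for _ in range(n_estimators):
        residuals = -loss_gradient(y, F)        # pseudo-residuals
        g = fit_base_learner(X, residuals)      # g(x; theta_m)
        F = F + learning_rate * g.predict(X)    # beta_m folded into the learning rate
        learners.append(g)
    return learners
```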

### Model Types

#### GradientBoostingSurvivalAnalysis

Uses regression trees as base learners. Can capture complex non-linear relationships.

**When to Use:**
- Need to model complex non-linear relationships
- Want high predictive performance
- Have sufficient data to avoid overfitting
- Can tune hyperparameters carefully

#### ComponentwiseGradientBoostingSurvivalAnalysis

Uses component-wise least squares as base learners. Produces linear models with automatic feature selection.

**When to Use:**
- Want an interpretable linear model
- Need automatic feature selection (like the Lasso)
- Have high-dimensional data
- Prefer sparse models

### Loss Functions

#### Cox's Partial Likelihood (default)

Keeps the proportional hazards framework but replaces the linear predictor with the additive ensemble model.

**Appropriate for:**
- Standard survival analysis settings
- When proportional hazards is reasonable
- Most use cases

#### Accelerated Failure Time (AFT)

Assumes features accelerate or decelerate survival time by a constant factor. The loss is a censoring-weighted least squares on log time: **(1/n) Σ ω_i (log y_i - f(x_i))²**, where the ω_i are inverse probability of censoring weights (IPCW). See the sketch after the list below.

**Appropriate for:**
- The AFT framework is preferred over proportional hazards
- Want to model survival time directly
- Need to interpret effects on survival time
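
A brief sketch of fitting under this loss, reusing `X` and `y` from the earlier examples (`loss='ipcwls'` is the IPCW least-squares option listed under Key Parameters below):

```python
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# AFT-style objective: IPCW least squares on log survival time
aft_model = GradientBoostingSurvivalAnalysis(loss="ipcwls",
                                             learning_rate=0.05,
                                             n_estimators=200,
                                             random_state=42)
aft_model.fit(X, y)
```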

### Regularization Strategies

Three main techniques prevent overfitting:

1. **Learning Rate** (`learning_rate < 1`)
   - Shrinks the contribution of each base learner
   - Smaller values need more iterations but generalize better
   - Typical range: 0.01 - 0.1

2. **Dropout** (`dropout_rate > 0`)
   - Randomly drops previous learners during training
   - Forces learners to be more robust
   - Typical range: 0.01 - 0.2

3. **Subsampling** (`subsample < 1`)
   - Uses a random subset of the data for each iteration
   - Adds randomness and reduces overfitting
   - Typical range: 0.5 - 0.9

**Recommendation**: Combine a small learning rate with early stopping (see Early Stopping below) for best performance.

### Key Parameters

- `loss`: Loss function ('coxph' or 'ipcwls')
- `learning_rate`: Shrinks the contribution of each tree (default: 0.1)
- `n_estimators`: Number of boosting iterations (default: 100)
- `subsample`: Fraction of samples used in each iteration (default: 1.0)
- `dropout_rate`: Dropout rate for learners (default: 0.0)
- `max_depth`: Maximum depth of trees (default: 3)
- `min_samples_split`: Minimum samples to split a node (default: 2)
- `min_samples_leaf`: Minimum samples at a leaf (default: 1)
- `max_features`: Features to consider at each split

### Example Usage

```python
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sklearn.model_selection import train_test_split

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit gradient boosting model
gbs = GradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.05,
    n_estimators=200,
    subsample=0.8,
    dropout_rate=0.1,
    max_depth=3,
    random_state=42
)
gbs.fit(X_train, y_train)

# Predict risk scores
risk_scores = gbs.predict(X_test)

# Predict survival functions
surv_funcs = gbs.predict_survival_function(X_test)

# Predict cumulative hazard functions
chf_funcs = gbs.predict_cumulative_hazard_function(X_test)
```
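
One way to check the fit on the held-out split (a short sketch using Harrell's concordance index):

```python
from sksurv.metrics import concordance_index_censored

# Concordance index on the test set (0.5 = random ranking, 1.0 = perfect)
event_field, time_field = y_test.dtype.names
c_index = concordance_index_censored(y_test[event_field], y_test[time_field], risk_scores)[0]
print(f"Test c-index: {c_index:.3f}")
```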

### Early Stopping

Setting `n_iter_no_change` makes the model hold out `validation_fraction` of the training data internally and stop adding iterations once the validation loss no longer improves:

```python
# Fit with early stopping on an internal validation split
gbs = GradientBoostingSurvivalAnalysis(
    n_estimators=1000,
    learning_rate=0.01,
    max_depth=3,
    validation_fraction=0.2,
    n_iter_no_change=10,
    random_state=42
)
gbs.fit(X_train, y_train)

# Number of boosting iterations actually used
print(f"Used {gbs.n_estimators_} iterations")
```

### Hyperparameter Tuning

```python
from sklearn.model_selection import GridSearchCV

param_grid = {
    'learning_rate': [0.01, 0.05, 0.1],
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'subsample': [0.8, 1.0]
}

# The default score() of sksurv estimators is Harrell's concordance index, so no
# scoring argument is needed; to optimize the IPCW concordance index instead,
# wrap the estimator with sksurv.metrics.as_concordance_index_ipcw_scorer.
cv = GridSearchCV(
    GradientBoostingSurvivalAnalysis(random_state=42),
    param_grid,
    cv=5,
    n_jobs=-1
)
cv.fit(X, y)

best_model = cv.best_estimator_
```

## ComponentwiseGradientBoostingSurvivalAnalysis

### Overview

Uses component-wise least squares as base learners, producing sparse linear models with automatic feature selection similar to the Lasso.

### When to Use

- Want an interpretable linear model
- Need automatic feature selection
- Have high-dimensional data with many irrelevant features
- Prefer coefficient-based interpretation

### Example Usage

```python
from sksurv.ensemble import ComponentwiseGradientBoostingSurvivalAnalysis

# Fit componentwise boosting
cgbs = ComponentwiseGradientBoostingSurvivalAnalysis(
    loss='coxph',
    learning_rate=0.1,
    n_estimators=100
)
cgbs.fit(X, y)

# Get coefficients; features whose coefficient was never updated stay at zero
coef = cgbs.coef_
selected_features = [i for i, c in enumerate(coef) if c != 0]
```

## ExtraSurvivalTrees

Extremely randomized survival trees - similar to a Random Survival Forest, but with additional randomness in split selection.

### When to Use

- Want even more regularization than a Random Survival Forest
- Have limited data
- Need faster training

### Key Difference

Instead of searching for the best split of each selected feature, it randomly draws split points, adding more diversity to the ensemble.

```python
from sksurv.ensemble import ExtraSurvivalTrees

est = ExtraSurvivalTrees(n_estimators=100, random_state=42)
est.fit(X, y)
```

## Model Comparison

| Model | Complexity | Interpretability | Performance | Speed |
|-------|-----------|------------------|-------------|-------|
| Random Survival Forest | Medium | Low | High | Medium |
| GradientBoostingSurvivalAnalysis | High | Low | Highest | Slow |
| ComponentwiseGradientBoostingSurvivalAnalysis | Low | High | Medium | Fast |
| ExtraSurvivalTrees | Medium | Low | Medium-High | Fast |

**General Recommendations:**
- **Best overall performance**: GradientBoostingSurvivalAnalysis with tuning
- **Best balance**: RandomSurvivalForest
- **Best interpretability**: ComponentwiseGradientBoostingSurvivalAnalysis
- **Fastest training**: ExtraSurvivalTrees
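
A quick, illustrative sanity check of these rankings on a single train/test split (reusing the split from the gradient boosting example; actual rankings vary by dataset and tuning):

```python
from sksurv.ensemble import (RandomSurvivalForest, GradientBoostingSurvivalAnalysis,
                             ExtraSurvivalTrees)

# score() returns Harrell's concordance index on the held-out data
models = {
    "RandomSurvivalForest": RandomSurvivalForest(n_estimators=300, random_state=42, n_jobs=-1),
    "GradientBoostingSurvivalAnalysis": GradientBoostingSurvivalAnalysis(
        n_estimators=300, learning_rate=0.05, random_state=42),
    "ExtraSurvivalTrees": ExtraSurvivalTrees(n_estimators=300, random_state=42, n_jobs=-1),
}
for name, model in models.items():
    model.fit(X_train, y_train)
    print(f"{name}: {model.score(X_test, y_test):.3f}")
```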