# PyHealth Training, Evaluation, and Interpretability
## Overview
PyHealth provides comprehensive tools for training models, evaluating predictions, ensuring model reliability, and interpreting results for clinical applications.
## Trainer Class
### Core Functionality
The `Trainer` class manages the complete model training and evaluation workflow with PyTorch integration.
**Initialization:**
```python
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,     # PyHealth or PyTorch model
    device="cuda",   # or "cpu"
)
```
### Training
**train() method**
Trains models with comprehensive monitoring and checkpointing.
**Parameters:**
- `train_dataloader`: Training data loader
- `val_dataloader`: Validation data loader (optional)
- `test_dataloader`: Test data loader (optional)
- `epochs`: Number of training epochs
- `optimizer`: Optimizer instance or class
- `learning_rate`: Learning rate (default: 1e-3)
- `weight_decay`: L2 regularization (default: 0)
- `max_grad_norm`: Gradient clipping threshold
- `monitor`: Metric to monitor (e.g., "pr_auc_score")
- `monitor_criterion`: "max" or "min"
- `save_path`: Checkpoint save directory
**Usage:**
```python
import torch

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    max_grad_norm=5.0,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints"
)
```
**Training Features:**
1. **Automatic Checkpointing**: Saves best model based on monitored metric
2. **Early Stopping**: Stops training when the monitored validation metric stops improving
3. **Gradient Clipping**: Prevents exploding gradients
4. **Progress Tracking**: Displays training progress and metrics
5. **Multi-GPU Support**: Automatic device placement
### Inference
**inference() method**
Runs the model over a dataloader and returns predictions and labels.
**Parameters:**
- `dataloader`: Data loader for inference
- `additional_outputs`: List of additional outputs to return
- `return_patient_ids`: Return patient identifiers
**Usage:**
```python
predictions = trainer.inference(
    dataloader=test_loader,
    additional_outputs=["attention_weights", "embeddings"],
    return_patient_ids=True
)
```
**Returns:**
- `y_pred`: Model predictions
- `y_true`: Ground truth labels
- `patient_ids`: Patient identifiers (if requested)
- Additional outputs (if specified)
### Evaluation
**evaluate() method**
Computes comprehensive evaluation metrics.
**Parameters:**
- `dataloader`: Data loader for evaluation
- `metrics`: List of metric functions
**Usage:**
```python
results = trainer.evaluate(
    dataloader=test_loader,
    metrics=["accuracy", "pr_auc_score", "roc_auc_score", "f1_score"]
)
print(results)
# Output: {'accuracy': 0.85, 'pr_auc_score': 0.78, 'roc_auc_score': 0.82, 'f1_score': 0.73}
```
### Checkpoint Management
**save() method**
```python
trainer.save("./models/best_model.pt")
```
**load() method**
```python
trainer.load("./models/best_model.pt")
```
## Evaluation Metrics
### Binary Classification Metrics
**Available Metrics:**
- `accuracy`: Overall accuracy
- `precision`: Positive predictive value
- `recall`: Sensitivity/True positive rate
- `f1_score`: F1 score (harmonic mean of precision and recall)
- `roc_auc_score`: Area under ROC curve
- `pr_auc_score`: Area under precision-recall curve
- `cohen_kappa`: Inter-rater reliability
**Usage:**
```python
from pyhealth.metrics import binary_metrics_fn
# Comprehensive binary metrics
metrics = binary_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "f1_score", "pr_auc_score", "roc_auc_score"]
)
```
**Threshold Selection:**
```python
import numpy as np
from sklearn.metrics import f1_score

# Default threshold: 0.5
y_pred_binary = (y_pred > 0.5).astype(int)

# Optimal threshold by F1
thresholds = np.arange(0.1, 0.9, 0.05)
f1_scores = [f1_score(y_true, (y_pred > t).astype(int)) for t in thresholds]
optimal_threshold = thresholds[np.argmax(f1_scores)]
```
**Best Practices:**
- **Use AUROC**: Overall model discrimination
- **Use AUPRC**: Especially for imbalanced classes
- **Use F1**: Balance precision and recall
- **Report confidence intervals**: Bootstrap resampling
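As noted in the last point, confidence intervals can be obtained with bootstrap resampling. A minimal sketch for an AUROC interval, assuming `y_true` and `y_prob` are NumPy arrays of binary labels and predicted probabilities:
```python
import numpy as np
from sklearn.metrics import roc_auc_score

def bootstrap_auroc_ci(y_true, y_prob, n_boot=1000, alpha=0.05, seed=0):
    """Percentile-bootstrap confidence interval for AUROC."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)      # resample with replacement
        if len(np.unique(y_true[idx])) < 2:   # AUROC needs both classes present
            continue
        scores.append(roc_auc_score(y_true[idx], y_prob[idx]))
    lower, upper = np.percentile(scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return np.mean(scores), (lower, upper)

auroc, (lo, hi) = bootstrap_auroc_ci(y_true, y_prob)
print(f"AUROC {auroc:.3f} (95% CI {lo:.3f}-{hi:.3f})")
```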
### Multi-Class Classification Metrics
**Available Metrics:**
- `accuracy`: Overall accuracy
- `macro_f1`: Unweighted mean F1 across classes
- `micro_f1`: Global F1 (total TP, FP, FN)
- `weighted_f1`: Weighted mean F1 by class frequency
- `cohen_kappa`: Multi-class kappa
**Usage:**
```python
from pyhealth.metrics import multiclass_metrics_fn
metrics = multiclass_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "macro_f1", "weighted_f1"]
)
```
**Per-Class Metrics:**
```python
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred,
                            target_names=["Wake", "N1", "N2", "N3", "REM"]))
```
**Confusion Matrix:**
```python
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
```
### Multi-Label Classification Metrics
**Available Metrics:**
- `jaccard_score`: Intersection over union
- `hamming_loss`: Fraction of incorrect labels
- `example_f1`: F1 per example (micro average)
- `label_f1`: F1 per label (macro average)
**Usage:**
```python
from pyhealth.metrics import multilabel_metrics_fn
# y_pred: [n_samples, n_labels] binary matrix
metrics = multilabel_metrics_fn(
    y_true=label_matrix,
    y_pred=pred_matrix,
    metrics=["jaccard_score", "example_f1", "label_f1"]
)
```
**Drug Recommendation Metrics:**
```python
# Jaccard similarity (intersection/union) between true and predicted drug sets
jaccard = len(set(true_drugs) & set(pred_drugs)) / len(set(true_drugs) | set(pred_drugs))

# Precision@k: precision over the top-k highest-scoring predictions
# (y_pred is an array of per-drug scores; y_true a set of true drug indices)
def precision_at_k(y_true, y_pred, k=10):
    top_k_pred = y_pred.argsort()[-k:]  # indices of the k highest-scoring labels
    return len(set(y_true) & set(top_k_pred)) / k
```
### Regression Metrics
**Available Metrics:**
- `mean_absolute_error`: Average absolute error
- `mean_squared_error`: Average squared error
- `root_mean_squared_error`: RMSE
- `r2_score`: Coefficient of determination
**Usage:**
```python
from pyhealth.metrics import regression_metrics_fn
metrics = regression_metrics_fn(
    y_true=true_values,
    y_pred=predictions,
    metrics=["mae", "rmse", "r2"]
)
```
**Percentage Error Metrics:**
```python
import numpy as np

# Mean Absolute Percentage Error (assumes y_true contains no zeros)
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Median Absolute Percentage Error (robust to outliers)
medape = np.median(np.abs((y_true - y_pred) / y_true)) * 100
```
### Fairness Metrics
**Purpose:** Assess model bias across demographic groups
**Available Metrics:**
- `demographic_parity`: Equal positive prediction rates
- `equalized_odds`: Equal TPR and FPR across groups
- `equal_opportunity`: Equal TPR across groups
- `predictive_parity`: Equal PPV across groups
**Usage:**
```python
from pyhealth.metrics import fairness_metrics_fn
fairness_results = fairness_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    sensitive_attributes=demographics,  # e.g., race, gender
    metrics=["demographic_parity", "equalized_odds"]
)
```
**Example:**
```python
from sklearn.metrics import recall_score

# Evaluate fairness across gender (demographics is assumed to be a NumPy array of group labels)
male_mask = (demographics == "male")
female_mask = (demographics == "female")
male_tpr = recall_score(y_true[male_mask], y_pred[male_mask])
female_tpr = recall_score(y_true[female_mask], y_pred[female_mask])
tpr_disparity = abs(male_tpr - female_tpr)
print(f"TPR disparity: {tpr_disparity:.3f}")
```
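Demographic parity can be checked with the same masking pattern by comparing positive prediction rates rather than TPRs; a minimal sketch, assuming `y_pred` holds binary predictions and `demographics` is a NumPy array of group labels:
```python
import numpy as np

# Positive prediction rate per group; demographic parity asks for these to be similar
rates = {
    group: y_pred[demographics == group].mean()
    for group in np.unique(demographics)
}
parity_gap = max(rates.values()) - min(rates.values())
print(f"Positive rate by group: {rates}, gap: {parity_gap:.3f}")
```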
## Calibration and Uncertainty Quantification
### Model Calibration
**Purpose:** Ensure predicted probabilities match actual frequencies
**Calibration Plot:**
```python
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_true, y_prob, n_bins=10
)
plt.plot(mean_predicted_value, fraction_of_positives, marker='o', label='Model')
plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
plt.show()
```
**Expected Calibration Error (ECE):**
```python
import numpy as np

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Compute the Expected Calibration Error (ECE)."""
    bins = np.linspace(0, 1, n_bins + 1)
    # Clip so that probabilities equal to 1.0 fall into the last bin
    bin_indices = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)
    ece = 0.0
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.sum() > 0:
            bin_accuracy = y_true[mask].mean()
            bin_confidence = y_prob[mask].mean()
            ece += mask.sum() / len(y_true) * abs(bin_accuracy - bin_confidence)
    return ece
```
**Calibration Methods:**
1. **Platt Scaling**: Logistic regression on validation predictions
```python
from sklearn.linear_model import LogisticRegression
calibrator = LogisticRegression()
calibrator.fit(val_predictions.reshape(-1, 1), val_labels)
calibrated_probs = calibrator.predict_proba(test_predictions.reshape(-1, 1))[:, 1]
```
2. **Isotonic Regression**: Non-parametric calibration
```python
from sklearn.isotonic import IsotonicRegression
calibrator = IsotonicRegression(out_of_bounds='clip')
calibrator.fit(val_predictions, val_labels)
calibrated_probs = calibrator.predict(test_predictions)
```
3. **Temperature Scaling**: Scale logits before softmax
```python
import torch.nn.functional as F
from scipy.optimize import minimize

def find_temperature(logits, labels):
    """Find the temperature that minimizes validation NLL."""
    def nll(temp):
        # cross_entropy expects (scaled) logits, not softmax probabilities
        scaled_logits = logits / float(temp[0])
        return F.cross_entropy(scaled_logits, labels).item()
    result = minimize(nll, x0=[1.0], method='BFGS')
    return result.x[0]

temperature = find_temperature(val_logits, val_labels)
calibrated_logits = test_logits / temperature
```
### Uncertainty Quantification
**Conformal Prediction:**
Provide prediction sets with guaranteed coverage.
**Usage:**
```python
import numpy as np
from pyhealth.metrics import prediction_set_metrics_fn

# Calibrate on the validation set: nonconformity score = 1 - probability of the true class
scores = 1 - val_predictions[np.arange(len(val_labels)), val_labels]
quantile_level = np.quantile(scores, 0.9)  # 90% target coverage

# Generate prediction sets on the test set
prediction_sets = test_predictions > (1 - quantile_level)

# Evaluate
metrics = prediction_set_metrics_fn(
    y_true=test_labels,
    prediction_sets=prediction_sets,
    metrics=["coverage", "average_size"]
)
```
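The quantile above is the plain empirical 0.9 quantile; split conformal prediction usually applies a finite-sample correction so the coverage guarantee holds exactly. A hedged sketch of that adjustment, reusing `scores` and `test_predictions` from the block above:
```python
import numpy as np

alpha = 0.1                       # target 90% coverage
n = len(scores)                   # number of calibration examples
# Finite-sample corrected quantile level for split conformal prediction
q_level = min(np.ceil((n + 1) * (1 - alpha)) / n, 1.0)
q_hat = np.quantile(scores, q_level, method="higher")  # use interpolation="higher" on older NumPy
prediction_sets = test_predictions >= (1 - q_hat)
```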
**Monte Carlo Dropout:**
Estimate uncertainty through dropout at inference.
```python
import torch

def predict_with_uncertainty(model, dataloader, num_samples=20):
    """Predict with uncertainty estimates using Monte Carlo dropout."""
    model.train()  # keep dropout layers active at inference time
    predictions = []
    for _ in range(num_samples):
        batch_preds = []
        for batch in dataloader:
            with torch.no_grad():
                # assumes the forward pass returns a tensor of predicted probabilities
                output = model(batch)
            batch_preds.append(output)
        predictions.append(torch.cat(batch_preds))
    predictions = torch.stack(predictions)
    mean_pred = predictions.mean(dim=0)
    std_pred = predictions.std(dim=0)  # spread across passes = uncertainty
    return mean_pred, std_pred
```
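A short usage example: the returned standard deviation can flag low-confidence predictions for manual review (the 0.15 cutoff below is an arbitrary illustration, not a recommendation):
```python
mean_pred, std_pred = predict_with_uncertainty(model, test_loader, num_samples=20)

# Flag predictions whose MC-dropout spread exceeds the chosen cutoff
high_uncertainty = std_pred.squeeze() > 0.15
print(f"{int(high_uncertainty.sum())} / {len(std_pred)} predictions flagged for review")
```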
**Ensemble Uncertainty:**
```python
import numpy as np

# Train multiple models with different random seeds (train_model is a placeholder)
models = [train_model(seed=i) for i in range(5)]

# Predict with the ensemble
ensemble_preds = []
for model in models:
    pred = model.predict(test_data)
    ensemble_preds.append(pred)

mean_pred = np.mean(ensemble_preds, axis=0)
std_pred = np.std(ensemble_preds, axis=0)  # disagreement across members = uncertainty
```
## Interpretability
### Attention Visualization
**For Transformer and RETAIN models:**
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Get attention weights during inference
outputs = trainer.inference(
    test_loader,
    additional_outputs=["attention_weights"]
)
attention = outputs["attention_weights"]

# Visualize attention for one sample
sample_idx = 0
sample_attention = attention[sample_idx]  # [seq_length, seq_length]
sns.heatmap(sample_attention, cmap='viridis')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights')
plt.show()
```
**RETAIN Interpretation:**
```python
# RETAIN provides visit-level and feature-level attention
visit_attention = outputs["visit_attention"] # Which visits are important
feature_attention = outputs["feature_attention"] # Which features are important
# Find most influential visit
most_important_visit = visit_attention[sample_idx].argmax()
# Find most influential features in that visit
important_features = feature_attention[sample_idx, most_important_visit].argsort()[-10:]
```
### Feature Importance
**Permutation Importance:**
```python
from sklearn.inspection import permutation_importance

# Note: permutation_importance expects a scikit-learn-style estimator exposing
# predict/predict_proba; wrap a PyTorch model accordingly before using it here.
result = permutation_importance(
    model, X_test, y_test,
    n_repeats=10,
    scoring='roc_auc'
)

# Sort features by importance
indices = result.importances_mean.argsort()[::-1]
for i in indices[:10]:
    print(f"{feature_names[i]}: {result.importances_mean[i]:.3f}")
```
**SHAP Values:**
```python
import shap
# Create explainer
explainer = shap.DeepExplainer(model, train_data)
# Compute SHAP values
shap_values = explainer.shap_values(test_data)
# Visualize
shap.summary_plot(shap_values, test_data, feature_names=feature_names)
```
### ChEFER (Clinical Health Event Feature Extraction and Ranking)
**PyHealth's Interpretability Tool:**
```python
from pyhealth.explain import ChEFER
explainer = ChEFER(model=model, dataset=test_dataset)
# Get feature importance for prediction
importance_scores = explainer.explain(
    patient_id="patient_123",
    visit_id="visit_456"
)
# Visualize top features
explainer.plot_importance(importance_scores, top_k=20)
```
## Complete Training Pipeline Example
```python
import torch

from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Load and prepare data
dataset = MIMIC4Dataset(root="/path/to/mimic4")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Split data by patient
train_data, val_data, test_data = split_by_patient(
    sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42
)

# 3. Create data loaders
train_loader = get_dataloader(train_data, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=64, shuffle=False)

# 4. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "procedures", "medications"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 5. Train model
trainer = Trainer(model=model, device="cuda")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints/mortality_model"
)

# 6. Evaluate on test set
test_results = trainer.evaluate(
    test_loader,
    metrics=["accuracy", "precision", "recall", "f1_score",
             "roc_auc_score", "pr_auc_score"]
)
print("Test Results:")
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

# 7. Get predictions for analysis
predictions = trainer.inference(test_loader, return_patient_ids=True)
y_pred, y_true, patient_ids = predictions

# 8. Calibration analysis
from sklearn.calibration import calibration_curve
fraction_pos, mean_pred = calibration_curve(y_true, y_pred, n_bins=10)
ece = expected_calibration_error(y_true, y_pred)
print(f"Expected Calibration Error: {ece:.4f}")

# 9. Save final model
trainer.save("./models/mortality_transformer_final.pt")
```
## Best Practices
### Training
1. **Monitor multiple metrics**: Track both loss and task-specific metrics
2. **Use validation set**: Prevent overfitting with early stopping
3. **Gradient clipping**: Stabilize training (max_grad_norm=5.0)
4. **Learning rate scheduling**: Reduce the learning rate when the monitored metric plateaus (see the sketch after this list)
5. **Checkpoint best model**: Save based on validation performance
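For item 4, a minimal reduce-on-plateau sketch using the generic `torch.optim.lr_scheduler` API inside a manual training loop; `train_one_epoch`, `evaluate_auprc`, and `num_epochs` are hypothetical placeholders, and this is not tied to a specific PyHealth `Trainer` option:
```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3
)

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)   # hypothetical helper
    val_auprc = evaluate_auprc(model, val_loader)     # hypothetical helper
    scheduler.step(val_auprc)  # halve the LR when validation AUPRC stops improving
```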
### Evaluation
1. **Use task-appropriate metrics**: AUROC/AUPRC for binary, macro-F1 for imbalanced multi-class
2. **Report confidence intervals**: Bootstrap or cross-validation
3. **Stratified evaluation**: Report metrics by patient subgroup (see the sketch after this list)
4. **Clinical metrics**: Include clinically relevant thresholds
5. **Fairness assessment**: Evaluate across demographic groups
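For item 3, a minimal sketch of subgroup-stratified evaluation, assuming `subgroups` is a NumPy array aligned with `y_true` and `y_prob` (e.g., age band or admission type):
```python
import numpy as np
from sklearn.metrics import roc_auc_score

for group in np.unique(subgroups):
    mask = subgroups == group
    if len(np.unique(y_true[mask])) < 2:
        print(f"{group}: only one class present, AUROC undefined")
        continue
    auroc = roc_auc_score(y_true[mask], y_prob[mask])
    print(f"{group}: AUROC = {auroc:.3f} (n = {mask.sum()})")
```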
### Deployment
1. **Calibrate predictions**: Ensure probabilities are reliable
2. **Quantify uncertainty**: Provide confidence estimates
3. **Monitor performance**: Track metrics in production
4. **Handle distribution shift**: Detect when the input or prediction distribution changes (see the sketch after this list)
5. **Interpretability**: Provide explanations for predictions
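For item 4, one simple monitoring check is to compare the distribution of production prediction scores against a fixed reference window; a minimal sketch using a two-sample Kolmogorov-Smirnov test, where `reference_scores` and `production_scores` are assumed arrays of predicted probabilities and the 0.05 threshold is an arbitrary example:
```python
from scipy.stats import ks_2samp

# reference_scores: predictions from the validation/launch period
# production_scores: predictions from the current monitoring window
stat, p_value = ks_2samp(reference_scores, production_scores)
if p_value < 0.05:
    print(f"Possible distribution shift (KS statistic = {stat:.3f}, p = {p_value:.3g})")
```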