# PyHealth Training, Evaluation, and Interpretability

## Overview

PyHealth provides comprehensive tools for training models, evaluating predictions, ensuring model reliability, and interpreting results for clinical applications.

## Trainer Class

### Core Functionality

The `Trainer` class manages the complete model training and evaluation workflow with PyTorch integration.

**Initialization:**

```python
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,    # PyHealth or PyTorch model
    device="cuda",  # or "cpu"
)
```

### Training

**train() method**

Trains models with comprehensive monitoring and checkpointing.

**Parameters:**
- `train_dataloader`: Training data loader
- `val_dataloader`: Validation data loader (optional)
- `test_dataloader`: Test data loader (optional)
- `epochs`: Number of training epochs
- `optimizer`: Optimizer instance or class
- `learning_rate`: Learning rate (default: 1e-3)
- `weight_decay`: L2 regularization (default: 0)
- `max_grad_norm`: Gradient clipping threshold
- `monitor`: Metric to monitor (e.g., "pr_auc_score")
- `monitor_criterion`: "max" or "min"
- `save_path`: Checkpoint save directory

**Usage:**

```python
import torch

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    max_grad_norm=5.0,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints"
)
```

**Training Features:**

1. **Automatic Checkpointing**: Saves the best model based on the monitored metric
2. **Early Stopping**: Stops training when the monitored metric stops improving
3. **Gradient Clipping**: Prevents exploding gradients
4. **Progress Tracking**: Displays training progress and metrics
5. **Multi-GPU Support**: Automatic device placement

### Inference

**inference() method**

Performs predictions on datasets.

**Parameters:**
- `dataloader`: Data loader for inference
- `additional_outputs`: List of additional outputs to return
- `return_patient_ids`: Return patient identifiers

**Usage:**

```python
predictions = trainer.inference(
    dataloader=test_loader,
    additional_outputs=["attention_weights", "embeddings"],
    return_patient_ids=True
)
```

**Returns:**
- `y_pred`: Model predictions
- `y_true`: Ground truth labels
- `patient_ids`: Patient identifiers (if requested)
- Additional outputs (if specified)
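A common follow-up is to collect these outputs into a per-patient table for error analysis. A minimal sketch, assuming a binary task whose outputs unpack in the order listed above and come back as 1-D NumPy arrays; pandas and the 0.5 threshold are illustrative choices, not part of the PyHealth API:

```python
import pandas as pd

# Unpack following the return order listed above (assumed here)
y_pred, y_true, patient_ids = trainer.inference(
    dataloader=test_loader,
    return_patient_ids=True,
)

results = pd.DataFrame({
    "patient_id": patient_ids,
    "y_true": y_true,
    "y_prob": y_pred,                     # predicted probability of the positive class
    "y_hat": (y_pred > 0.5).astype(int),  # thresholded prediction
})

# Inspect the highest-confidence false positives
false_pos = results[(results["y_hat"] == 1) & (results["y_true"] == 0)]
print(false_pos.sort_values("y_prob", ascending=False).head(10))
```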
### Evaluation

**evaluate() method**

Computes comprehensive evaluation metrics.

**Parameters:**
- `dataloader`: Data loader for evaluation
- `metrics`: List of metric functions

**Usage:**

```python
results = trainer.evaluate(
    dataloader=test_loader,
    metrics=["accuracy", "pr_auc_score", "roc_auc_score", "f1_score"]
)

print(results)
# Output: {'accuracy': 0.85, 'pr_auc_score': 0.78, 'roc_auc_score': 0.82, 'f1_score': 0.73}
```

### Checkpoint Management

**save() method**

```python
trainer.save("./models/best_model.pt")
```

**load() method**

```python
trainer.load("./models/best_model.pt")
```

## Evaluation Metrics

### Binary Classification Metrics

**Available Metrics:**
- `accuracy`: Overall accuracy
- `precision`: Positive predictive value
- `recall`: Sensitivity / true positive rate
- `f1_score`: F1 score (harmonic mean of precision and recall)
- `roc_auc_score`: Area under the ROC curve
- `pr_auc_score`: Area under the precision-recall curve
- `cohen_kappa`: Inter-rater reliability

**Usage:**

```python
from pyhealth.metrics import binary_metrics_fn

# Comprehensive binary metrics
metrics = binary_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "f1_score", "pr_auc_score", "roc_auc_score"]
)
```

**Threshold Selection:**

```python
import numpy as np
from sklearn.metrics import f1_score

# Default threshold: 0.5
predictions_binary = (predictions > 0.5).astype(int)

# Optimal threshold by F1
thresholds = np.arange(0.1, 0.9, 0.05)
f1_scores = [f1_score(y_true, (y_pred > t).astype(int)) for t in thresholds]
optimal_threshold = thresholds[np.argmax(f1_scores)]
```

**Best Practices:**
- **Use AUROC**: Overall model discrimination
- **Use AUPRC**: Especially for imbalanced classes
- **Use F1**: Balance precision and recall
- **Report confidence intervals**: Bootstrap resampling (a bootstrap sketch follows the multi-label metrics below)

### Multi-Class Classification Metrics

**Available Metrics:**
- `accuracy`: Overall accuracy
- `macro_f1`: Unweighted mean F1 across classes
- `micro_f1`: Global F1 (total TP, FP, FN)
- `weighted_f1`: Mean F1 weighted by class frequency
- `cohen_kappa`: Multi-class kappa

**Usage:**

```python
from pyhealth.metrics import multiclass_metrics_fn

metrics = multiclass_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "macro_f1", "weighted_f1"]
)
```

**Per-Class Metrics:**

```python
from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred, target_names=["Wake", "N1", "N2", "N3", "REM"]))
```

**Confusion Matrix:**

```python
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
```

### Multi-Label Classification Metrics

**Available Metrics:**
- `jaccard_score`: Intersection over union
- `hamming_loss`: Fraction of incorrect labels
- `example_f1`: F1 per example (micro average)
- `label_f1`: F1 per label (macro average)

**Usage:**

```python
from pyhealth.metrics import multilabel_metrics_fn

# y_pred: [n_samples, n_labels] binary matrix
metrics = multilabel_metrics_fn(
    y_true=label_matrix,
    y_pred=pred_matrix,
    metrics=["jaccard_score", "example_f1", "label_f1"]
)
```

**Drug Recommendation Metrics:**

```python
# Jaccard similarity (intersection / union)
jaccard = len(set(true_drugs) & set(pred_drugs)) / len(set(true_drugs) | set(pred_drugs))

# Precision@k: precision for the top-k predictions
def precision_at_k(y_true, y_pred, k=10):
    top_k_pred = y_pred.argsort()[-k:]
    return len(set(y_true) & set(top_k_pred)) / k
```
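Several of the best-practice notes above recommend reporting bootstrap confidence intervals. A minimal percentile-bootstrap sketch, assuming 1-D NumPy arrays `y_true` (binary labels) and `y_prob` (predicted probabilities) and using sklearn's `roc_auc_score`; any other metric function could be substituted:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def bootstrap_ci(y_true, y_prob, metric_fn=roc_auc_score, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for an arbitrary metric."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample with replacement
        if len(np.unique(y_true[idx])) < 2:
            continue  # skip resamples with a single class (AUROC undefined)
        scores.append(metric_fn(y_true[idx], y_prob[idx]))
    lower, upper = np.percentile(scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper

low, high = bootstrap_ci(y_true, y_prob)
print(f"AUROC 95% CI: [{low:.3f}, {high:.3f}]")
```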
### Regression Metrics

**Available Metrics:**
- `mean_absolute_error`: Average absolute error
- `mean_squared_error`: Average squared error
- `root_mean_squared_error`: RMSE
- `r2_score`: Coefficient of determination

**Usage:**

```python
from pyhealth.metrics import regression_metrics_fn

metrics = regression_metrics_fn(
    y_true=true_values,
    y_pred=predictions,
    metrics=["mae", "rmse", "r2"]
)
```

**Percentage Error Metrics:**

```python
# Mean Absolute Percentage Error
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Median Absolute Percentage Error (robust to outliers)
medape = np.median(np.abs((y_true - y_pred) / y_true)) * 100
```

### Fairness Metrics

**Purpose:** Assess model bias across demographic groups.

**Available Metrics:**
- `demographic_parity`: Equal positive prediction rates
- `equalized_odds`: Equal TPR and FPR across groups
- `equal_opportunity`: Equal TPR across groups
- `predictive_parity`: Equal PPV across groups

**Usage:**

```python
from pyhealth.metrics import fairness_metrics_fn

fairness_results = fairness_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    sensitive_attributes=demographics,  # e.g., race, gender
    metrics=["demographic_parity", "equalized_odds"]
)
```

**Example:**

```python
from sklearn.metrics import recall_score

# Evaluate fairness across gender
male_mask = (demographics == "male")
female_mask = (demographics == "female")

male_tpr = recall_score(y_true[male_mask], y_pred[male_mask])
female_tpr = recall_score(y_true[female_mask], y_pred[female_mask])

tpr_disparity = abs(male_tpr - female_tpr)
print(f"TPR disparity: {tpr_disparity:.3f}")
```

## Calibration and Uncertainty Quantification

### Model Calibration

**Purpose:** Ensure predicted probabilities match observed event frequencies.

**Calibration Plot:**

```python
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

fraction_of_positives, mean_predicted_value = calibration_curve(
    y_true, y_prob, n_bins=10
)

plt.plot(mean_predicted_value, fraction_of_positives, marker='o')
plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
```

**Expected Calibration Error (ECE):**

```python
import numpy as np

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Compute ECE: the weighted gap between confidence and accuracy per bin."""
    bins = np.linspace(0, 1, n_bins + 1)
    # Clip so that probabilities equal to 1.0 fall into the last bin
    bin_indices = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)

    ece = 0
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.sum() > 0:
            bin_accuracy = y_true[mask].mean()
            bin_confidence = y_prob[mask].mean()
            ece += mask.sum() / len(y_true) * abs(bin_accuracy - bin_confidence)
    return ece
```

**Calibration Methods:**

1. **Platt Scaling**: Logistic regression on validation predictions

   ```python
   from sklearn.linear_model import LogisticRegression

   calibrator = LogisticRegression()
   calibrator.fit(val_predictions.reshape(-1, 1), val_labels)
   calibrated_probs = calibrator.predict_proba(test_predictions.reshape(-1, 1))[:, 1]
   ```

2. **Isotonic Regression**: Non-parametric calibration

   ```python
   from sklearn.isotonic import IsotonicRegression

   calibrator = IsotonicRegression(out_of_bounds='clip')
   calibrator.fit(val_predictions, val_labels)
   calibrated_probs = calibrator.predict(test_predictions)
   ```
3. **Temperature Scaling**: Scale logits before softmax

   ```python
   import torch.nn.functional as F
   from scipy.optimize import minimize

   def find_temperature(logits, labels):
       """Find the temperature that minimizes NLL on validation logits."""
       def nll(temp):
           scaled_logits = logits / float(temp[0])
           # cross_entropy expects (scaled) logits, not softmax probabilities
           return F.cross_entropy(scaled_logits, labels).item()

       # Keep the temperature bounded away from zero for a stable search
       result = minimize(nll, x0=[1.0], bounds=[(0.05, 10.0)], method='L-BFGS-B')
       return result.x[0]

   temperature = find_temperature(val_logits, val_labels)
   calibrated_logits = test_logits / temperature
   ```

### Uncertainty Quantification

**Conformal Prediction:** Provides prediction sets with guaranteed coverage.

**Usage:**

```python
from pyhealth.metrics import prediction_set_metrics_fn

# Calibrate on the validation set: nonconformity score = 1 - probability of the true class
scores = 1 - val_predictions[np.arange(len(val_labels)), val_labels]
qhat = np.quantile(scores, 0.9)  # score threshold targeting ~90% coverage

# Generate prediction sets on the test set
prediction_sets = test_predictions >= (1 - qhat)

# Evaluate
metrics = prediction_set_metrics_fn(
    y_true=test_labels,
    prediction_sets=prediction_sets,
    metrics=["coverage", "average_size"]
)
```

**Monte Carlo Dropout:** Estimate uncertainty by keeping dropout active at inference.

```python
def predict_with_uncertainty(model, dataloader, num_samples=20):
    """Predict with uncertainty using MC dropout."""
    model.train()  # keep dropout active (note: this also affects layers such as BatchNorm)
    predictions = []

    for _ in range(num_samples):
        batch_preds = []
        for batch in dataloader:
            with torch.no_grad():
                output = model(batch)
            batch_preds.append(output)
        predictions.append(torch.cat(batch_preds))

    predictions = torch.stack(predictions)
    mean_pred = predictions.mean(dim=0)
    std_pred = predictions.std(dim=0)  # uncertainty

    return mean_pred, std_pred
```

**Ensemble Uncertainty:**

```python
# Train multiple models with different seeds
models = [train_model(seed=i) for i in range(5)]

# Predict with the ensemble
ensemble_preds = []
for model in models:
    pred = model.predict(test_data)
    ensemble_preds.append(pred)

mean_pred = np.mean(ensemble_preds, axis=0)
std_pred = np.std(ensemble_preds, axis=0)  # uncertainty
```

## Interpretability

### Attention Visualization

**For Transformer and RETAIN models:**

```python
import matplotlib.pyplot as plt
import seaborn as sns

# Get attention weights during inference
outputs = trainer.inference(
    test_loader,
    additional_outputs=["attention_weights"]
)

attention = outputs["attention_weights"]

# Visualize attention for one sample
sample_idx = 0
sample_attention = attention[sample_idx]  # [seq_length, seq_length]

sns.heatmap(sample_attention, cmap='viridis')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights')
plt.show()
```

**RETAIN Interpretation:**

```python
# RETAIN provides visit-level and feature-level attention
visit_attention = outputs["visit_attention"]      # which visits are important
feature_attention = outputs["feature_attention"]  # which features are important

# Find the most influential visit
most_important_visit = visit_attention[sample_idx].argmax()

# Find the most influential features in that visit
important_features = feature_attention[sample_idx, most_important_visit].argsort()[-10:]
```

### Feature Importance

**Permutation Importance:**

```python
from sklearn.inspection import permutation_importance

# Requires a fitted estimator with an sklearn-style interface
result = permutation_importance(
    model, X_test, y_test,
    n_repeats=10,
    scoring='roc_auc'
)

# Sort features by importance
indices = result.importances_mean.argsort()[::-1]
for i in indices[:10]:
    print(f"{feature_names[i]}: {result.importances_mean[i]:.3f}")
```
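Because `permutation_importance` expects an sklearn-compatible estimator, it does not apply directly to a plain PyTorch or PyHealth model. A manual variant can be sketched instead; `predict_proba_fn`, `X_test`, and `y_test` below are placeholders for a scoring function and NumPy arrays you supply:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def manual_permutation_importance(predict_proba_fn, X, y, n_repeats=10, seed=0):
    """Drop in AUROC when each feature column is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = roc_auc_score(y, predict_proba_fn(X))
    importances = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break the feature-label link
            importances[j, r] = baseline - roc_auc_score(y, predict_proba_fn(X_perm))
    return importances.mean(axis=1), importances.std(axis=1)

mean_imp, std_imp = manual_permutation_importance(predict_proba_fn, X_test, y_test)
for j in mean_imp.argsort()[::-1][:10]:
    print(f"feature {j}: {mean_imp[j]:.3f} ± {std_imp[j]:.3f}")
```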
**SHAP Values:**

```python
import shap

# Create explainer
explainer = shap.DeepExplainer(model, train_data)

# Compute SHAP values
shap_values = explainer.shap_values(test_data)

# Visualize
shap.summary_plot(shap_values, test_data, feature_names=feature_names)
```

### ChEFER (Clinical Health Event Feature Extraction and Ranking)

**PyHealth's Interpretability Tool:**

```python
from pyhealth.explain import ChEFER

explainer = ChEFER(model=model, dataset=test_dataset)

# Get feature importance for a prediction
importance_scores = explainer.explain(
    patient_id="patient_123",
    visit_id="visit_456"
)

# Visualize top features
explainer.plot_importance(importance_scores, top_k=20)
```

## Complete Training Pipeline Example

```python
import torch

from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Load and prepare data
dataset = MIMIC4Dataset(root="/path/to/mimic4")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Split data
train_data, val_data, test_data = split_by_patient(
    sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42
)

# 3. Create data loaders
train_loader = get_dataloader(train_data, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=64, shuffle=False)

# 4. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "procedures", "medications"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 5. Train model
trainer = Trainer(model=model, device="cuda")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints/mortality_model"
)

# 6. Evaluate on the test set
test_results = trainer.evaluate(
    test_loader,
    metrics=["accuracy", "precision", "recall", "f1_score", "roc_auc_score", "pr_auc_score"]
)

print("Test Results:")
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

# 7. Get predictions for analysis
predictions = trainer.inference(test_loader, return_patient_ids=True)
y_pred, y_true, patient_ids = predictions

# 8. Calibration analysis (expected_calibration_error is defined earlier in this document)
from sklearn.calibration import calibration_curve
fraction_pos, mean_pred = calibration_curve(y_true, y_pred, n_bins=10)
ece = expected_calibration_error(y_true, y_pred)
print(f"Expected Calibration Error: {ece:.4f}")

# 9. Save the final model
trainer.save("./models/mortality_transformer_final.pt")
```

## Best Practices

### Training

1. **Monitor multiple metrics**: Track both loss and task-specific metrics
2. **Use a validation set**: Prevent overfitting with early stopping
3. **Gradient clipping**: Stabilize training (e.g., max_grad_norm=5.0)
4. **Learning rate scheduling**: Reduce the learning rate on plateau
5. **Checkpoint the best model**: Save based on validation performance

### Evaluation

1. **Use task-appropriate metrics**: AUROC/AUPRC for binary tasks, macro-F1 for imbalanced multi-class tasks
2. **Report confidence intervals**: Bootstrap or cross-validation
3. **Stratified evaluation**: Report metrics by subgroup (see the sketch after this list)
4. **Clinical metrics**: Include clinically relevant thresholds
5. **Fairness assessment**: Evaluate across demographic groups
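A minimal sketch of the stratified evaluation recommended above, assuming 1-D NumPy arrays `y_true` and `y_prob` plus a `groups` array of demographic labels aligned with the test set (all names are placeholders):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score

def subgroup_metrics(y_true, y_prob, groups, threshold=0.5):
    """Report AUROC and F1 separately for each demographic subgroup."""
    results = {}
    for g in np.unique(groups):
        mask = groups == g
        if len(np.unique(y_true[mask])) < 2:
            continue  # skip groups without both classes (AUROC undefined)
        y_hat = (y_prob[mask] > threshold).astype(int)
        results[g] = {
            "n": int(mask.sum()),
            "roc_auc": roc_auc_score(y_true[mask], y_prob[mask]),
            "f1": f1_score(y_true[mask], y_hat),
        }
    return results

for group, vals in subgroup_metrics(y_true, y_prob, groups).items():
    print(f"{group}: n={vals['n']}, AUROC={vals['roc_auc']:.3f}, F1={vals['f1']:.3f}")
```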
### Deployment

1. **Calibrate predictions**: Ensure probabilities are reliable
2. **Quantify uncertainty**: Provide confidence estimates
3. **Monitor performance**: Track metrics in production
4. **Handle distribution shift**: Detect when the input data changes (see the sketch after this list)
5. **Interpretability**: Provide explanations for predictions
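For the monitoring and distribution-shift points above, one lightweight check is to compare the live prediction distribution against a reference window with a two-sample Kolmogorov-Smirnov test. A minimal sketch, assuming NumPy arrays `reference_probs` (collected at validation time) and `production_probs` (recent traffic); both names are placeholders:

```python
from scipy.stats import ks_2samp

def prediction_drift(reference_probs, production_probs, alpha=0.01):
    """Flag drift when the two prediction distributions differ significantly."""
    stat, p_value = ks_2samp(reference_probs, production_probs)
    return {"ks_statistic": stat, "p_value": p_value, "drift": p_value < alpha}

report = prediction_drift(reference_probs, production_probs)
if report["drift"]:
    print(f"Possible distribution shift: KS={report['ks_statistic']:.3f}, p={report['p_value']:.1e}")
```

A drift flag like this is only a trigger for investigation; confirming genuine shift still requires comparing labeled outcomes or input feature distributions over time.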