PyHealth Clinical Prediction Tasks
Overview
PyHealth provides 20+ predefined clinical prediction tasks for common healthcare AI applications. Each task function transforms raw patient data into structured input-output pairs for model training.
Task Function Structure
Tasks come as plain functions (the *_fn names used throughout this document) or as classes that inherit from BaseTask. A class-based task provides:
- input_schema: Defines input features (diagnoses, medications, labs, etc.)
- output_schema: Defines prediction targets (labels, values)
- pre_filter(): Optional patient/visit filtering logic
Usage Pattern:
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
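The three components listed above can be pictured with a small stand-in written in plain Python. This is a sketch only and does not use PyHealth's BaseTask: the class name ToyTask, the schema type strings, and the patient dictionary layout are all assumptions made for illustration.

```python
from typing import Dict, List


class ToyTask:
    """Hypothetical stand-in for a task definition; not PyHealth's BaseTask."""

    # Feature name -> feature type (type names are assumed for illustration).
    input_schema: Dict[str, str] = {"diagnoses": "sequence", "medications": "sequence"}
    # Target name -> label type.
    output_schema: Dict[str, str] = {"label": "binary"}

    def pre_filter(self, patient: dict) -> bool:
        """Keep only patients that have at least one visit with a diagnosis."""
        return any(visit["diagnoses"] for visit in patient["visits"])

    def __call__(self, patient: dict) -> List[dict]:
        """Turn one patient into a list of samples, one per visit."""
        if not self.pre_filter(patient):
            return []
        samples = []
        for visit in patient["visits"]:
            sample = {"patient_id": patient["patient_id"], "visit_id": visit["visit_id"]}
            # Copy exactly the fields named by the two schemas.
            for key in list(self.input_schema) + list(self.output_schema):
                sample[key] = visit[key]
            samples.append(sample)
        return samples


patient = {
    "patient_id": "p1",
    "visits": [
        {"visit_id": "v1", "diagnoses": ["I10"], "medications": ["A01"], "label": 0},
        {"visit_id": "v2", "diagnoses": ["I50"], "medications": [], "label": 1},
    ],
}
samples = ToyTask()(patient)
print(len(samples), sorted(samples[0]))
# 2 ['diagnoses', 'label', 'medications', 'patient_id', 'visit_id']
```

Keeping the schemas as data lets the sample builder and any downstream model read the same field names from one place.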
Electronic Health Record (EHR) Tasks
Mortality Prediction
Purpose: Predict patient death risk at next visit or within specified timeframe
MIMIC-III Mortality (mortality_prediction_mimic3_fn)
- Predicts death at next hospital visit
- Binary classification task
- Input: Historical diagnoses, procedures, medications
- Output: Binary label (deceased/alive)
MIMIC-IV Mortality (mortality_prediction_mimic4_fn)
- Updated version for MIMIC-IV dataset
- Enhanced feature set
- Improved label quality
eICU Mortality (mortality_prediction_eicu_fn)
- Multi-center ICU mortality prediction
- Accounts for hospital-level variation
OMOP Mortality (mortality_prediction_omop_fn)
- Standardized mortality prediction
- Works with OMOP common data model
In-Hospital Mortality (inhospital_mortality_prediction_mimic4_fn)
- Predicts death during current hospitalization
- Real-time risk assessment
- Earlier prediction window than next-visit mortality
StageNet Mortality (mortality_prediction_mimic4_fn_stagenet)
- Specialized for StageNet model architecture
- Temporal stage-aware prediction
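The difference between next-visit and in-hospital mortality comes down to which visit supplies the label. The sketch below shows both labeling rules on a toy visit list; the tuple layout and the function names are hypothetical and are not taken from PyHealth.

```python
from typing import List, Tuple

# Hypothetical visit records: (visit_id, died_during_visit), in chronological order.
Visit = Tuple[str, bool]


def next_visit_labels(visits: List[Visit]) -> List[Tuple[str, int]]:
    """Label each visit with the death flag of the following visit.

    The last visit has no successor, so it yields no sample.
    """
    return [
        (current_id, int(next_died))
        for (current_id, _), (_, next_died) in zip(visits, visits[1:])
    ]


def in_hospital_labels(visits: List[Visit]) -> List[Tuple[str, int]]:
    """Label each visit with its own death flag."""
    return [(visit_id, int(died)) for visit_id, died in visits]


visits = [("v1", False), ("v2", False), ("v3", True)]
print(next_visit_labels(visits))   # [('v1', 0), ('v2', 1)]
print(in_hospital_labels(visits))  # [('v1', 0), ('v2', 0), ('v3', 1)]
```

Note that the next-visit rule produces one fewer sample per patient, and patients with a single visit produce none.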
Hospital Readmission Prediction
Purpose: Identify patients at risk of hospital readmission within specified timeframe (typically 30 days)
MIMIC-III Readmission (readmission_prediction_mimic3_fn)
- 30-day readmission prediction
- Binary classification
- Input: Diagnosis history, medications, demographics
- Output: Binary label (readmitted/not readmitted)
MIMIC-IV Readmission (readmission_prediction_mimic4_fn)
- Enhanced readmission features
- Improved temporal modeling
eICU Readmission (readmission_prediction_eicu_fn)
- ICU-specific readmission risk
- Multi-site data
OMOP Readmission (readmission_prediction_omop_fn)
- Standardized readmission prediction
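A readmission label depends on the gap between one discharge and the next admission. The following sketch computes that label for a configurable window; the record layout, the function name, and the time_window_days parameter are assumptions for illustration, and the window used by a given predefined task should be checked in its own documentation.

```python
from datetime import datetime
from typing import List, Tuple

# Hypothetical visit records: (visit_id, admission time, discharge time).
Visit = Tuple[str, datetime, datetime]


def readmission_labels(visits: List[Visit], time_window_days: int = 30) -> List[Tuple[str, int]]:
    """Label a visit 1 if the next admission starts within the window after discharge."""
    labels = []
    for (visit_id, _, discharge), (_, next_admit, _) in zip(visits, visits[1:]):
        gap_days = (next_admit - discharge).total_seconds() / 86400
        labels.append((visit_id, int(gap_days <= time_window_days)))
    return labels


visits = [
    ("v1", datetime(2020, 1, 1), datetime(2020, 1, 5)),
    ("v2", datetime(2020, 1, 20), datetime(2020, 1, 25)),  # 15 days after v1 discharge
    ("v3", datetime(2020, 6, 1), datetime(2020, 6, 3)),    # 128 days after v2 discharge
]
print(readmission_labels(visits))  # [('v1', 1), ('v2', 0)]
```

Shrinking the window changes the labels: with time_window_days=10, the first visit would be labeled 0 as well.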
Length of Stay Prediction
Purpose: Estimate hospital stay duration for resource planning and patient management
MIMIC-III Length of Stay (length_of_stay_prediction_mimic3_fn)
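- Multi-class classification: stay length is binned into ordinal categories rather than predicted as a raw number
- Input: Admission diagnoses, vitals, demographics
- Output: Length-of-stay category (for example, under 1 day, each day of the first week, 1 to 2 weeks, over 2 weeks)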
MIMIC-IV Length of Stay (length_of_stay_prediction_mimic4_fn)
- Enhanced features for LOS prediction
- Better temporal granularity
eICU Length of Stay (length_of_stay_prediction_eicu_fn)
- ICU stay duration prediction
- Multi-hospital data
OMOP Length of Stay (length_of_stay_prediction_omop_fn)
- Standardized LOS prediction
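A stay duration in days can be turned into a classification target by binning. The sketch below assumes a ten-class scheme (under one day, one class per day of the first week, one to two weeks, over two weeks); the function name categorize_los and the exact bin edges are illustrative assumptions, so check them against the task you actually use.

```python
def categorize_los(days: int) -> int:
    """Map a stay length in whole days to one of 10 classes (assumed binning).

    0: under 1 day; 1-7: one class per day of the first week;
    8: over one week and up to two weeks; 9: over two weeks.
    """
    if days < 1:
        return 0
    if days <= 7:
        return days
    if days <= 14:
        return 8
    return 9


for days in (0, 3, 7, 10, 30):
    print(days, categorize_los(days))
# 0 0
# 3 3
# 7 7
# 10 8
# 30 9
```

Binning trades resolution for robustness: long stays are rare and noisy, so grouping them into one class keeps the tail from dominating the loss.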
Drug Recommendation
Purpose: Suggest appropriate medications based on patient history and current conditions
MIMIC-III Drug Recommendation (drug_recommendation_mimic3_fn)
- Multi-label classification
- Input: Diagnoses, previous medications, demographics
- Output: Set of recommended drug codes
- Considers drug-drug interactions
MIMIC-IV Drug Recommendation (drug_recommendation_mimic4_fn)
- Updated medication data
- Enhanced interaction modeling
eICU Drug Recommendation (drug_recommendation_eicu_fn)
- Critical care medication recommendations
OMOP Drug Recommendation (drug_recommendation_omop_fn)
- Standardized drug recommendation
Key Considerations:
- Handles polypharmacy scenarios
- Multi-label prediction (multiple drugs per patient)
- Can integrate with SafeDrug/GAMENet models for safety-aware recommendations
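Because each sample's target is a set of drugs, the label is usually encoded as a multi-hot vector over a drug vocabulary. The sketch below shows one way to build that encoding in plain Python; the helper names and the ATC-style codes are made up for illustration.

```python
from typing import Dict, List, Set


def build_vocab(samples: List[Set[str]]) -> Dict[str, int]:
    """Assign each distinct drug code a column index, in sorted order."""
    codes = sorted(set().union(*samples))
    return {code: index for index, code in enumerate(codes)}


def multi_hot(drugs: Set[str], vocab: Dict[str, int]) -> List[int]:
    """Encode a set of drug codes as a 0/1 vector over the vocabulary."""
    vector = [0] * len(vocab)
    for code in drugs:
        vector[vocab[code]] = 1
    return vector


# Hypothetical ATC-style drug codes prescribed at three visits.
visit_drugs = [{"A10B", "C09A"}, {"C09A"}, {"N02B", "A10B"}]
vocab = build_vocab(visit_drugs)
print(vocab)                             # {'A10B': 0, 'C09A': 1, 'N02B': 2}
print(multi_hot(visit_drugs[2], vocab))  # [1, 0, 1]
```

Sorting the codes before indexing keeps the column order stable across runs, which matters when labels are saved and reloaded.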
Specialized Clinical Tasks
Medical Coding
MIMIC-III ICD-9 Coding (icd9_coding_mimic3_fn)
- Assigns ICD-9 diagnosis/procedure codes to clinical notes
- Multi-label text classification
- Input: Clinical text/documentation
- Output: Set of ICD-9 codes
- Supports both diagnosis and procedure coding
Patient Linkage
MIMIC-III Patient Linking (patient_linkage_mimic3_fn)
- Record matching and deduplication
- Binary classification (same patient or not)
- Input: Demographic and clinical features from two records
- Output: Match probability
Physiological Signal Tasks
Sleep Staging
Purpose: Classify sleep stages from EEG/physiological signals for sleep disorder diagnosis
ISRUC Sleep Staging (sleep_staging_isruc_fn)
- Multi-class classification (Wake, N1, N2, N3, REM)
- Input: Multi-channel EEG signals
- Output: Sleep stage per epoch (typically 30 seconds)
SleepEDF Sleep Staging (sleep_staging_sleepedf_fn)
- Standard sleep staging task
- PSG signal processing
SHHS Sleep Staging (sleep_staging_shhs_fn)
- Large-scale sleep study data
- Population-level sleep analysis
Standardized Labels:
- Wake (W)
- Non-REM Stage 1 (N1)
- Non-REM Stage 2 (N2)
- Non-REM Stage 3 (N3/Deep Sleep)
- REM (Rapid Eye Movement)
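Sleep staging samples are built by cutting a recording into fixed-length epochs and attaching one stage label to each. The sketch below does this for a single-channel signal held in a Python list; the integer ids assigned to the five stages and the helper name split_into_epochs are assumptions, not PyHealth definitions.

```python
from typing import Dict, List

# Label order follows the five standardized stages; the integer ids are an assumption.
STAGE_TO_ID: Dict[str, int] = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}


def split_into_epochs(
    signal: List[float], sampling_rate_hz: int, epoch_seconds: int = 30
) -> List[List[float]]:
    """Cut a signal into non-overlapping epochs, dropping any trailing partial epoch."""
    epoch_len = sampling_rate_hz * epoch_seconds
    n_epochs = len(signal) // epoch_len
    return [signal[i * epoch_len:(i + 1) * epoch_len] for i in range(n_epochs)]


# 95 seconds of a dummy signal sampled at 2 Hz: 190 values, so 3 full epochs.
signal = [0.0] * 190
epochs = split_into_epochs(signal, sampling_rate_hz=2)
stages = ["W", "N1", "N2"]  # one hypothetical annotation per epoch
samples = [{"signal": e, "label": STAGE_TO_ID[s]} for e, s in zip(epochs, stages)]
print(len(epochs), len(epochs[0]), samples[2]["label"])  # 3 60 2
```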
EEG Analysis
Abnormality Detection (abnormality_detection_tuab_fn)
- Binary classification (normal/abnormal EEG)
- Clinical screening application
- Input: Multi-channel EEG recordings
- Output: Binary label
Event Detection (event_detection_tuev_fn)
- Identify specific EEG events (spikes, seizures)
- Multi-class classification
- Input: EEG time series
- Output: Event type and timing
Seizure Detection (seizure_detection_tusz_fn)
- Specialized epileptic seizure detection
- Critical for epilepsy monitoring
- Input: Continuous EEG
- Output: Seizure/non-seizure classification
Medical Imaging Tasks
COVID-19 Chest X-ray Classification
COVID-19 CXR (covid_classification_cxr_fn)
- Multi-class image classification
- Classes: COVID-19, bacterial pneumonia, viral pneumonia, normal
- Input: Chest X-ray images
- Output: Disease classification
Text-Based Tasks
Medical Transcription Classification
Medical Specialty Classification (medical_transcription_classification_fn)
- Classify clinical notes by medical specialty
- Multi-class text classification
- Input: Clinical transcription text
- Output: Medical specialty (Cardiology, Neurology, etc.)
Custom Task Creation
Creating Custom Tasks
Define custom prediction tasks by specifying input/output schemas:
def custom_task_fn(patient):
    """Custom prediction task: one sample per visit that has at least two prior visits."""
    # Assumes patient.visits is an ordered list and each visit exposes .events;
    # adjust the attribute access to the PyHealth version you are using.
    samples = []
    for i, visit in enumerate(patient.visits):
        # Skip if not enough history
        if i < 2:
            continue
        # Create input from historical visits
        input_info = {
            "diagnoses": [],
            "medications": [],
            "procedures": [],  # left empty here; add a branch for your procedure vocabulary
        }
        # Collect features from previous visits
        for past_visit in patient.visits[:i]:
            for event in past_visit.events:
                if event.vocabulary == "ICD10CM":
                    input_info["diagnoses"].append(event.code)
                elif event.vocabulary == "NDC":
                    input_info["medications"].append(event.code)
        # Define prediction target
        # Example: predict specific outcome at current visit
        some_condition = False  # placeholder: replace with your outcome check for this visit
        output_info = {
            "label": 1 if some_condition else 0
        }
        samples.append({
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "input_info": input_info,
            "output_info": output_info,
        })
    return samples

# Apply custom task
sample_dataset = dataset.set_task(custom_task_fn)
Task Function Components
Input Schema Definition:
- Specify which features to extract
- Define feature types (codes, sequences, values)
- Set temporal windows
Output Schema Definition:
- Define prediction targets
- Set label types (binary, multi-class, multi-label, regression)
- Specify evaluation metrics
Filtering Logic:
- Exclude patients/visits with insufficient data
- Apply inclusion/exclusion criteria
- Handle missing data
Sample Generation:
- Create input-output pairs
- Maintain patient/visit identifiers
- Preserve temporal ordering
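The filtering component is easier to audit when every exclusion is counted by reason. The sketch below is one possible shape for that logic in plain Python; the patient dictionary layout, the function names, and the two example criteria are assumptions for illustration.

```python
from collections import Counter
from typing import List, Optional, Tuple


def exclusion_reason(patient: dict, min_visits: int = 2) -> Optional[str]:
    """Return why a patient is excluded, or None if the patient is kept."""
    if len(patient["visits"]) < min_visits:
        return "too_few_visits"
    if any(not visit["diagnoses"] for visit in patient["visits"]):
        return "visit_without_diagnoses"
    return None


def filter_patients(patients: List[dict]) -> Tuple[List[dict], Counter]:
    """Split patients into the kept cohort and a tally of exclusion reasons."""
    kept, reasons = [], Counter()
    for patient in patients:
        reason = exclusion_reason(patient)
        if reason is None:
            kept.append(patient)
        else:
            reasons[reason] += 1
    return kept, reasons


patients = [
    {"patient_id": "p1", "visits": [{"diagnoses": ["I10"]}, {"diagnoses": ["I50"]}]},
    {"patient_id": "p2", "visits": [{"diagnoses": ["E11"]}]},
    {"patient_id": "p3", "visits": [{"diagnoses": ["J18"]}, {"diagnoses": []}]},
]
kept, reasons = filter_patients(patients)
print([p["patient_id"] for p in kept])  # ['p1']
print(dict(reasons))  # {'too_few_visits': 1, 'visit_without_diagnoses': 1}
```

The tally can be reported alongside results so that readers can see how the cohort was narrowed.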
Task Selection Guidelines
Clinical Prediction Tasks
Use when: Working with structured EHR data (diagnoses, medications, procedures)
Datasets: MIMIC-III, MIMIC-IV, eICU, OMOP
Common tasks:
- Mortality prediction for risk stratification
- Readmission prediction for care transition planning
- Length of stay for resource allocation
- Drug recommendation for clinical decision support
Signal Processing Tasks
Use when: Working with physiological time-series data
Datasets: SleepEDF, SHHS, ISRUC, TUEV, TUAB, TUSZ
Common tasks:
- Sleep staging for sleep disorder diagnosis
- EEG abnormality detection for screening
- Seizure detection for epilepsy monitoring
Imaging Tasks
Use when: Working with medical images
Datasets: COVID-19 CXR
Common tasks:
- Disease classification from radiographs
- Abnormality detection
Text Tasks
Use when: Working with clinical notes and documentation
Datasets: Medical Transcriptions, MIMIC-III (with notes)
Common tasks:
- Medical coding from clinical text
- Specialty classification
- Clinical information extraction
Task Output Structure
Each task function returns a list of sample dictionaries, which dataset.set_task() collects into a SampleDataset. Each sample has the form:
sample = {
    "patient_id": "unique_patient_id",
    "visit_id": "unique_visit_id",  # if applicable
    "input_info": {
        # Input features (diagnoses, medications, etc.)
    },
    "output_info": {
        # Prediction targets (labels, values)
    }
}
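A quick structural check on generated samples catches malformed task output before training starts. The sketch below validates the layout shown above; the function name sample_problems and the specific checks are assumptions, not part of PyHealth.

```python
from typing import List

REQUIRED_KEYS = ("patient_id", "input_info", "output_info")


def sample_problems(sample: dict) -> List[str]:
    """Return a list of structural problems; an empty list means the sample looks valid."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in sample]
    for key in ("input_info", "output_info"):
        if key in sample and not isinstance(sample[key], dict):
            problems.append(f"{key} must be a dict")
    if isinstance(sample.get("output_info"), dict) and not sample["output_info"]:
        problems.append("output_info is empty")
    return problems


good = {
    "patient_id": "p1",
    "visit_id": "v3",
    "input_info": {"diagnoses": ["I10"]},
    "output_info": {"label": 1},
}
bad = {"patient_id": "p1", "input_info": []}
print(sample_problems(good))  # []
print(sample_problems(bad))   # ['missing key: output_info', 'input_info must be a dict']
```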
Integration with Models
Tasks define the input/output contract for models:
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
# 1. Create task-specific dataset
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
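# 2. Model automatically adapts to task schema
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],  # must match field names the task emits
    label_key="label",  # name of the target field; required by the PyHealth 1.x constructor
    mode="binary",  # matches task output
)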
Best Practices
- Match task to clinical question: Choose predefined tasks when available for standardized benchmarking
- Consider temporal windows: Ensure sufficient history for meaningful predictions
- Handle class imbalance: Many clinical outcomes are rare (mortality, readmission)
- Validate clinical relevance: Ensure prediction windows align with clinical decision-making timelines
- Use appropriate metrics: Different tasks require different evaluation metrics (AUROC for binary, macro-F1 for multi-class)
- Document exclusion criteria: Track which patients/visits are filtered and why
- Preserve patient privacy: Always use de-identified data and follow HIPAA/GDPR guidelines
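For the class imbalance point, one common starting move is to weight the positive class by the ratio of negatives to positives. The sketch below computes that ratio from a list of binary labels; the function name is hypothetical, and how the weight is passed to a loss function depends on the training framework in use.

```python
from typing import List


def positive_class_weight(labels: List[int]) -> float:
    """Ratio of negatives to positives, usable as a weight for the positive class."""
    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0:
        raise ValueError("no positive samples; cannot compute a weight")
    return negatives / positives


# Toy label set with a 5% positive rate.
labels = [0] * 95 + [1] * 5
print(positive_class_weight(labels))  # 19.0
```

With heavy imbalance, threshold-free metrics such as AUROC should be read alongside precision-recall metrics, since a model can score well on AUROC while still missing most positives at a practical threshold.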