Initial commit

skills/pyhealth/references/datasets.md (new file, 178 lines)

# PyHealth Datasets and Data Structures

## Core Data Structures

### Event

Individual medical occurrences with attributes including:

- **code**: Medical code (diagnosis, medication, procedure, lab test)
- **vocabulary**: Coding system (ICD-9-CM, NDC, LOINC, etc.)
- **timestamp**: Event occurrence time
- **value**: Numeric value (for labs, vital signs)
- **unit**: Measurement unit

### Patient

Collection of events organized chronologically across visits. Each patient contains:

- **patient_id**: Unique identifier
- **birth_datetime**: Date of birth
- **gender**: Patient gender
- **ethnicity**: Patient ethnicity
- **visits**: List of visit objects

### Visit

Healthcare encounter containing:

- **visit_id**: Unique identifier
- **encounter_time**: Visit timestamp
- **discharge_time**: Discharge timestamp
- **visit_type**: Type of encounter (inpatient, outpatient, emergency)
- **events**: List of events during this visit

The traversal sketch below shows how these three structures nest.
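
A minimal traversal sketch using the attribute names above; it assumes `dataset` is an already-loaded PyHealth dataset (see `BaseDataset` next), so treat it as illustrative rather than exact API:

```python
# Illustrative traversal of the Patient -> Visit -> Event hierarchy.
# Assumes `dataset` is an already-loaded PyHealth dataset (see below).
for patient in dataset.iter_patients():
    print(patient.patient_id, patient.gender)
    for visit in patient.visits:
        print(" ", visit.visit_id, visit.visit_type, visit.encounter_time)
        for event in visit.events:
            # Each event carries a code plus the vocabulary it comes from
            print("   ", event.vocabulary, event.code, event.timestamp)
```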

## BaseDataset Class

**Key Methods:**

- `get_patient(patient_id)`: Retrieve single patient record
- `iter_patients()`: Iterate through all patients
- `stats()`: Get dataset statistics (patients, visits, events)
- `set_task(task_fn)`: Define prediction task
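
A quick sketch of these methods in use, following the loading pattern shown in the Common Workflow at the end of this file; the path and patient ID are placeholders:

```python
from pyhealth.datasets import MIMIC3Dataset

# Path and patient ID are placeholders for your local data.
dataset = MIMIC3Dataset(root="/path/to/mimic3")

print(dataset.stats())                  # patient/visit/event counts

patient = dataset.get_patient("10006")  # fetch one record by ID
print(len(patient.visits))
```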

## Available Datasets

### Electronic Health Record (EHR) Datasets

**MIMIC-III Dataset** (`MIMIC3Dataset`)
- Intensive care unit data from Beth Israel Deaconess Medical Center
- 40,000+ critical care patients
- Diagnoses, procedures, medications, lab results
- Usage: `from pyhealth.datasets import MIMIC3Dataset`

**MIMIC-IV Dataset** (`MIMIC4Dataset`)
- Updated version with 70,000+ patients
- Improved data quality and coverage
- Enhanced demographic and clinical detail
- Usage: `from pyhealth.datasets import MIMIC4Dataset`

**eICU Dataset** (`eICUDataset`)
- Multi-center critical care database
- 200,000+ admissions from 200+ hospitals
- Standardized ICU data across facilities
- Usage: `from pyhealth.datasets import eICUDataset`

**OMOP Dataset** (`OMOPDataset`)
- Observational Medical Outcomes Partnership format
- Standardized common data model
- Interoperability across healthcare systems
- Usage: `from pyhealth.datasets import OMOPDataset`

**EHRShot Dataset** (`EHRShotDataset`)
- Benchmark dataset for few-shot learning
- Specialized for testing model generalization
- Usage: `from pyhealth.datasets import EHRShotDataset`

### Physiological Signal Datasets

**Sleep EEG Datasets:**
- `SleepEDFDataset`: Sleep-EDF database for sleep staging
- `SHHSDataset`: Sleep Heart Health Study data
- `ISRUCDataset`: ISRUC-Sleep database

**Temple University EEG Datasets:**
- `TUEVDataset`: Abnormal EEG event detection
- `TUABDataset`: Abnormal/normal EEG classification
- `TUSZDataset`: Seizure detection

**All signal datasets support:**
- Multi-channel EEG signals
- Standardized sampling rates
- Expert annotations
- Sleep stage or abnormality labels

### Medical Imaging Datasets

**COVID-19 CXR Dataset** (`COVID19CXRDataset`)
- Chest X-ray images for COVID-19 classification
- Multi-class labels (COVID-19, pneumonia, normal)
- Usage: `from pyhealth.datasets import COVID19CXRDataset`

### Text-Based Datasets

**Medical Transcriptions Dataset** (`MedicalTranscriptionsDataset`)
- Clinical notes and transcriptions
- Medical specialty classification
- Text-based prediction tasks
- Usage: `from pyhealth.datasets import MedicalTranscriptionsDataset`

**Cardiology Dataset** (`CardiologyDataset`)
- Cardiac patient records
- Cardiovascular disease prediction
- Usage: `from pyhealth.datasets import CardiologyDataset`

### Preprocessed Datasets

**MIMIC Extract Dataset** (`MIMICExtractDataset`)
- Pre-extracted MIMIC features
- Ready-to-use benchmarking data
- Reduced preprocessing requirements
- Usage: `from pyhealth.datasets import MIMICExtractDataset`

## SampleDataset Class

Converts raw datasets into task-specific formatted samples.

**Purpose:** Transform patient-level data into model-ready input/output pairs

**Key Attributes:**
- `input_schema`: Defines input data structure
- `output_schema`: Defines target labels/predictions
- `samples`: List of processed samples

**Usage Pattern:**
```python
# After setting a task on a BaseDataset
sample_dataset = dataset.set_task(task_fn)
```
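
A quick way to sanity-check the result of `set_task`, using the attributes listed above (exact sample keys depend on the task function):

```python
# Inspect the task-specific dataset produced by set_task().
print(len(sample_dataset.samples))   # number of samples
print(sample_dataset.input_schema)   # feature structure
print(sample_dataset.output_schema)  # label structure
print(sample_dataset.samples[0])     # one input/output pair
```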

## Data Splitting Functions

**Patient-Level Split** (`split_by_patient`)
- Ensures no patient appears in multiple splits
- Prevents data leakage
- Recommended for clinical prediction tasks

**Visit-Level Split** (`split_by_visit`)
- Splits by individual visits
- Allows same patient across splits (use cautiously)

**Sample-Level Split** (`split_by_sample`)
- Random sample splitting
- Most flexible, but may cause leakage

**Parameters:**
- `dataset`: SampleDataset to split
- `ratios`: Split ratios (e.g., `[0.7, 0.1, 0.2]`)
- `seed`: Random seed for reproducibility

## Common Workflow

```python
from pyhealth.datasets import MIMIC4Dataset, split_by_patient
from pyhealth.tasks import mortality_prediction_mimic4_fn

# 1. Load dataset
dataset = MIMIC4Dataset(root="/path/to/data")

# 2. Set prediction task
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 3. Split data (70/10/20, patient-level to avoid leakage)
train, val, test = split_by_patient(sample_dataset, [0.7, 0.1, 0.2])

# 4. Get statistics
print(dataset.stats())
```

## Performance Notes

- Data processing is reported to be roughly 3x faster than pandas-based pipelines
- Optimized for large-scale EHR datasets
- Memory-efficient patient iteration
- Vectorized operations for feature extraction

skills/pyhealth/references/medical_coding.md (new file, 284 lines)

# PyHealth Medical Code Translation

## Overview

Healthcare data uses multiple coding systems and standards. PyHealth's MedCode module enables translation and mapping between medical coding systems through ontology lookups and cross-system mappings.

## Core Classes

### InnerMap

Handles within-system ontology lookups and hierarchical navigation.

**Key Capabilities:**
- Code lookup with attributes (names, descriptions)
- Ancestor/descendant hierarchy traversal
- Code standardization and conversion
- Parent-child relationship navigation

### CrossMap

Manages cross-system mappings between different coding standards.

**Key Capabilities:**
- Translation between coding systems
- Many-to-many relationship handling
- Hierarchical level specification (for medications)
- Bidirectional mapping support

## Supported Coding Systems

### Diagnosis Codes

**ICD-9-CM (International Classification of Diseases, 9th Revision, Clinical Modification)**
- Legacy diagnosis coding system
- Hierarchical structure with 3-5 digit codes
- Used in US healthcare pre-2015
- Usage: `from pyhealth.medcode import InnerMap`
- `icd9_map = InnerMap.load("ICD9CM")`

**ICD-10-CM (International Classification of Diseases, 10th Revision, Clinical Modification)**
- Current diagnosis coding standard
- Alphanumeric codes (3-7 characters)
- More granular than ICD-9
- Usage: `from pyhealth.medcode import InnerMap`
- `icd10_map = InnerMap.load("ICD10CM")`

**CCSCM (Clinical Classifications Software for ICD-CM)**
- Groups ICD codes into clinically meaningful categories
- Reduces dimensionality for analysis
- Single-level and multi-level hierarchies
- Usage: `from pyhealth.medcode import CrossMap`
- `icd_to_ccs = CrossMap.load("ICD9CM", "CCSCM")`

### Procedure Codes

**ICD-9-PROC (ICD-9 Procedure Codes)**
- Inpatient procedure classification
- 3-4 digit numeric codes
- Legacy system (pre-2015)
- Usage: `from pyhealth.medcode import InnerMap`
- `icd9proc_map = InnerMap.load("ICD9PROC")`

**ICD-10-PROC (ICD-10 Procedure Coding System)**
- Current procedural coding standard
- 7-character alphanumeric codes
- More detailed than ICD-9-PROC
- Usage: `from pyhealth.medcode import InnerMap`
- `icd10proc_map = InnerMap.load("ICD10PROC")`

**CCSPROC (Clinical Classifications Software for Procedures)**
- Groups procedure codes into categories
- Simplifies procedure analysis
- Usage: `from pyhealth.medcode import CrossMap`
- `proc_to_ccs = CrossMap.load("ICD9PROC", "CCSPROC")`

### Medication Codes

**NDC (National Drug Code)**
- US FDA drug identification system
- 10 or 11-digit codes
- Product-level specificity (manufacturer, strength, package)
- Usage: `from pyhealth.medcode import InnerMap`
- `ndc_map = InnerMap.load("NDC")`

**RxNorm**
- Standardized drug terminology
- Normalized drug names and relationships
- Links multiple drug vocabularies
- Usage: `from pyhealth.medcode import CrossMap`
- `ndc_to_rxnorm = CrossMap.load("NDC", "RXNORM")`

**ATC (Anatomical Therapeutic Chemical Classification)**
- WHO drug classification system
- 5-level hierarchy:
  - **Level 1**: Anatomical main group (1 letter)
  - **Level 2**: Therapeutic subgroup (2 digits)
  - **Level 3**: Pharmacological subgroup (1 letter)
  - **Level 4**: Chemical subgroup (1 letter)
  - **Level 5**: Chemical substance (2 digits)
- Example: "C03CA01" = Furosemide
  - C = Cardiovascular system
  - C03 = Diuretics
  - C03C = High-ceiling diuretics
  - C03CA = Sulfonamides
  - C03CA01 = Furosemide

**Usage:**
```python
from pyhealth.medcode import CrossMap

ndc_to_atc = CrossMap.load("NDC", "ATC")
atc_codes = ndc_to_atc.map("00074-3799-13", level=3)  # get ATC level 3
```

## Common Operations

### InnerMap Operations

**1. Code Lookup**
```python
from pyhealth.medcode import InnerMap

icd9_map = InnerMap.load("ICD9CM")
info = icd9_map.lookup("428.0")  # Heart failure
# Returns: name, description, additional attributes
```

**2. Ancestor Traversal**
```python
# Get all parent codes in hierarchy
ancestors = icd9_map.get_ancestors("428.0")
# Returns: ["428", "420-429", "390-459"]
```

**3. Descendant Traversal**
```python
# Get all child codes
descendants = icd9_map.get_descendants("428")
# Returns: ["428.0", "428.1", "428.2", ...]
```

**4. Code Standardization**
```python
# Normalize code format
standard_code = icd9_map.standardize("4280")  # Returns "428.0"
```

### CrossMap Operations

**1. Direct Translation**
```python
from pyhealth.medcode import CrossMap

# ICD-9-CM to CCS
icd_to_ccs = CrossMap.load("ICD9CM", "CCSCM")
ccs_codes = icd_to_ccs.map("41401")  # Coronary atherosclerosis
# Returns: ["101"]  # CCS category for coronary atherosclerosis
```

**2. Hierarchical Drug Mapping**
```python
# NDC to ATC at different levels
ndc_to_atc = CrossMap.load("NDC", "ATC")

# Get specific ATC level
atc_level_1 = ndc_to_atc.map("00074-3799-13", level=1)  # Anatomical group
atc_level_3 = ndc_to_atc.map("00074-3799-13", level=3)  # Pharmacological
atc_level_5 = ndc_to_atc.map("00074-3799-13", level=5)  # Chemical substance
```

**3. Bidirectional Mapping**
```python
# Map in either direction
rxnorm_to_ndc = CrossMap.load("RXNORM", "NDC")
ndc_codes = rxnorm_to_ndc.map("197381")  # Get all NDC codes for RxNorm
```

## Workflow Examples

### Example 1: Standardize and Group Diagnoses
```python
from pyhealth.medcode import InnerMap, CrossMap

# Load maps
icd9_map = InnerMap.load("ICD9CM")
icd_to_ccs = CrossMap.load("ICD9CM", "CCSCM")

# Process diagnosis codes
raw_codes = ["4280", "428.0", "42800"]

standardized = [icd9_map.standardize(code) for code in raw_codes]
# Normalized to dotted ICD-9 format (e.g., "4280" -> "428.0")

ccs_categories = [icd_to_ccs.map(code)[0] for code in standardized]
# All map to CCS category "108" (Heart failure)
```

### Example 2: Drug Classification Analysis
```python
from pyhealth.medcode import CrossMap

# Map NDC to ATC for drug class analysis
ndc_to_atc = CrossMap.load("NDC", "ATC")

patient_drugs = ["00074-3799-13", "00074-7286-01", "00456-0765-01"]

# Get therapeutic subgroups (ATC level 2)
drug_classes = []
for ndc in patient_drugs:
    atc_codes = ndc_to_atc.map(ndc, level=2)
    if atc_codes:
        drug_classes.append(atc_codes[0])

# Analyze drug class distribution
```

### Example 3: ICD-9 to ICD-10 Migration
```python
from pyhealth.medcode import CrossMap

# Load ICD-9 to ICD-10 mapping
icd9_to_icd10 = CrossMap.load("ICD9CM", "ICD10CM")

# Convert historical ICD-9 codes
icd9_code = "428.0"
icd10_codes = icd9_to_icd10.map(icd9_code)
# Returns: ["I50.9", "I50.1", ...]  # Multiple possible ICD-10 codes

# Handle one-to-many mappings
for icd10_code in icd10_codes:
    print(f"ICD-9 {icd9_code} -> ICD-10 {icd10_code}")
```

## Integration with Datasets

Medical code translation integrates seamlessly with PyHealth datasets:

```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.medcode import CrossMap

# Load dataset
dataset = MIMIC4Dataset(root="/path/to/data")

# Load code mapping
icd_to_ccs = CrossMap.load("ICD10CM", "CCSCM")

# Process patient diagnoses
for patient in dataset.iter_patients():
    for visit in patient.visits:
        diagnosis_events = [e for e in visit.events if e.vocabulary == "ICD10CM"]

        for event in diagnosis_events:
            ccs_codes = icd_to_ccs.map(event.code)
            print(f"Diagnosis {event.code} -> CCS {ccs_codes}")
```

## Use Cases

### Clinical Research
- Standardize diagnoses across different coding systems
- Group related conditions for cohort identification
- Harmonize multi-site studies with different standards

### Drug Safety Analysis
- Classify medications by therapeutic class
- Identify drug-drug interactions at class level
- Analyze polypharmacy patterns

### Healthcare Analytics
- Reduce diagnosis/procedure dimensionality
- Create meaningful clinical categories
- Enable longitudinal analysis across coding system changes

### Machine Learning
- Create consistent feature representations
- Handle vocabulary mismatch in training/test data
- Generate hierarchical embeddings

## Best Practices

1. **Always standardize codes** before mapping to ensure consistent format
2. **Handle one-to-many mappings** appropriately (some codes map to multiple targets)
3. **Specify ATC level** explicitly when mapping drugs to avoid ambiguity
4. **Use CCS categories** to reduce diagnosis/procedure dimensionality
5. **Validate mappings**, as some codes may not have direct translations (see the sketch after this list)
6. **Document code versions** (ICD-9 vs ICD-10) to maintain data provenance
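
A short sketch tying practices 1, 2, and 5 together; the `to_ccs` helper and its fallback behavior are illustrative choices, not library API:

```python
from pyhealth.medcode import InnerMap, CrossMap

icd9_map = InnerMap.load("ICD9CM")
icd_to_ccs = CrossMap.load("ICD9CM", "CCSCM")

def to_ccs(raw_code):
    """Standardize an ICD-9 code, then map it to CCS; None if unmapped."""
    code = icd9_map.standardize(raw_code)  # practice 1: consistent format
    targets = icd_to_ccs.map(code)
    if not targets:                        # practice 5: validate mappings
        return None
    return targets                         # practice 2: may be one-to-many

print(to_ccs("4280"))
```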

skills/pyhealth/references/models.md (new file, 594 lines)

# PyHealth Models

## Overview

PyHealth provides 33+ models for healthcare prediction tasks, ranging from simple baselines to state-of-the-art deep learning architectures. Models are organized into general-purpose architectures and healthcare-specific models.

## Model Base Class

All models inherit from `BaseModel` with standard PyTorch functionality:

**Key Attributes:**
- `dataset`: Associated SampleDataset
- `feature_keys`: Input features to use (e.g., `["diagnoses", "medications"]`)
- `mode`: Task type ("binary", "multiclass", "multilabel", "regression")
- `embedding_dim`: Feature embedding dimension
- `device`: Computation device (CPU/GPU)

**Key Methods:**
- `forward()`: Model forward pass
- `train_step()`: Single training iteration
- `eval_step()`: Single evaluation iteration
- `save()`: Save model checkpoint
- `load()`: Load model checkpoint
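
This file does not show `save()`/`load()` in use; a minimal checkpointing sketch assuming both take a file path, which should be verified against your PyHealth version:

```python
# Checkpointing sketch; assumes save()/load() accept a file path,
# as is typical for checkpoint APIs (verify against your version).
model.save("checkpoints/transformer_mortality.ckpt")

# ... later, restore weights into a freshly constructed model
model.load("checkpoints/transformer_mortality.ckpt")
```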

## General-Purpose Models

### Baseline Models

**Logistic Regression** (`LogisticRegression`)
- Linear classifier with mean pooling
- Simple baseline for comparison
- Fast training and inference
- Good for interpretability

**Usage:**
```python
from pyhealth.models import LogisticRegression

model = LogisticRegression(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary"
)
```

**Multi-Layer Perceptron** (`MLP`)
- Feedforward neural network
- Configurable hidden layers
- Supports mean/sum/max pooling
- Good baseline for structured data

**Parameters:**
- `hidden_dim`: Hidden layer size
- `num_layers`: Number of hidden layers
- `dropout`: Dropout rate
- `pooling`: Aggregation method ("mean", "sum", "max")

**Usage:**
```python
from pyhealth.models import MLP

model = MLP(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_layers=3,
    dropout=0.5
)
```

### Convolutional Neural Networks

**CNN** (`CNN`)
- Convolutional layers for pattern detection
- Effective for sequential and spatial data
- Captures local temporal patterns
- Parameter efficient

**Architecture:**
- Multiple 1D convolutional layers
- Max pooling for dimension reduction
- Fully connected output layers

**Parameters:**
- `num_filters`: Number of convolutional filters
- `kernel_size`: Convolution kernel size
- `num_layers`: Number of conv layers
- `dropout`: Dropout rate

**Usage:**
```python
from pyhealth.models import CNN

model = CNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_filters=64,
    kernel_size=3,
    num_layers=3
)
```

**Temporal Convolutional Network** (`TCN`)
- Dilated convolutions for long-range dependencies
- Causal convolutions (no future information leakage)
- Efficient for long sequences
- Good for time-series prediction

**Advantages:**
- Captures long-term dependencies
- Parallelizable (faster than RNNs)
- Stable gradients
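
Unlike the other models in this section, no constructor example is given for `TCN`. A sketch under the assumption that it follows the same constructor pattern as the surrounding models:

```python
from pyhealth.models import TCN

# Assumed to mirror the shared constructor pattern; check the TCN
# signature in your PyHealth version for its specific kwargs.
model = TCN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
)
```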

### Recurrent Neural Networks

**RNN** (`RNN`)
- Basic recurrent architecture
- Supports LSTM, GRU, RNN variants
- Sequential processing
- Captures temporal dependencies

**Parameters:**
- `rnn_type`: "LSTM", "GRU", or "RNN"
- `hidden_dim`: Hidden state dimension
- `num_layers`: Number of recurrent layers
- `dropout`: Dropout rate
- `bidirectional`: Use bidirectional RNN

**Usage:**
```python
from pyhealth.models import RNN

model = RNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    rnn_type="LSTM",
    hidden_dim=128,
    num_layers=2,
    bidirectional=True
)
```

**Best for:**
- Sequential clinical events
- Temporal pattern learning
- Variable-length sequences

### Transformer Models

**Transformer** (`Transformer`)
- Self-attention mechanism
- Parallel processing of sequences
- State-of-the-art performance
- Effective for long-range dependencies

**Architecture:**
- Multi-head self-attention
- Position embeddings
- Feed-forward networks
- Layer normalization

**Parameters:**
- `num_heads`: Number of attention heads
- `num_layers`: Number of transformer layers
- `hidden_dim`: Hidden dimension
- `dropout`: Dropout rate
- `max_seq_length`: Maximum sequence length

**Usage:**
```python
from pyhealth.models import Transformer

model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_heads=8,
    num_layers=6,
    hidden_dim=256,
    dropout=0.1
)
```

**TransformersModel** (`TransformersModel`)
- Integration with HuggingFace transformers
- Pre-trained language models for clinical text
- Fine-tuning for healthcare tasks
- Examples: BERT, RoBERTa, BioClinicalBERT

**Usage:**
```python
from pyhealth.models import TransformersModel

model = TransformersModel(
    dataset=sample_dataset,
    feature_keys=["text"],
    mode="multiclass",
    pretrained_model="emilyalsentzer/Bio_ClinicalBERT"
)
```

### Graph Neural Networks

**GNN** (`GNN`)
- Graph-based learning
- Models relationships between entities
- Supports GAT (Graph Attention) and GCN (Graph Convolutional)

**Use Cases:**
- Drug-drug interactions
- Patient similarity networks
- Knowledge graph integration
- Comorbidity relationships

**Parameters:**
- `gnn_type`: "GAT" or "GCN"
- `hidden_dim`: Hidden dimension
- `num_layers`: Number of GNN layers
- `dropout`: Dropout rate
- `num_heads`: Attention heads (for GAT)

**Usage:**
```python
from pyhealth.models import GNN

model = GNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    gnn_type="GAT",
    hidden_dim=128,
    num_layers=3,
    num_heads=4
)
```

## Healthcare-Specific Models

### Interpretable Clinical Models

**RETAIN** (`RETAIN`)
- Reverse time attention mechanism
- Highly interpretable predictions
- Visit-level and event-level attention
- Identifies influential clinical events

**Key Features:**
- Two-level attention (visits and features)
- Temporal decay modeling
- Clinically meaningful explanations
- Published in NeurIPS 2016

**Usage:**
```python
from pyhealth.models import RETAIN

model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128
)

# Get attention weights for interpretation
outputs = model(batch)
visit_attention = outputs["visit_attention"]
feature_attention = outputs["feature_attention"]
```

**Best for:**
- Mortality prediction
- Readmission prediction
- Clinical risk scoring
- Interpretable predictions

**AdaCare** (`AdaCare`)
- Adaptive care model with feature calibration
- Disease-specific attention
- Handles irregular time intervals
- Interpretable feature importance

**ConCare** (`ConCare`)
- Cross-visit convolutional attention
- Temporal convolutional feature extraction
- Multi-level attention mechanism
- Good for longitudinal EHR modeling

### Medication Recommendation Models

**GAMENet** (`GAMENet`)
- Graph-based medication recommendation
- Drug-drug interaction modeling
- Memory network for patient history
- Multi-hop reasoning

**Architecture:**
- Drug knowledge graph
- Memory-augmented neural network
- DDI-aware prediction

**Usage:**
```python
from pyhealth.models import GAMENet

model = GAMENet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,
    ddi_adj_path="/path/to/ddi_adjacency_matrix.pkl"
)
```

**MICRON** (`MICRON`)
- Medication recommendation with DDI constraints
- Interaction-aware predictions
- Safety-focused drug selection

**SafeDrug** (`SafeDrug`)
- Safety-aware drug recommendation
- Molecular structure integration
- DDI constraint optimization
- Balances efficacy and safety

**Key Features:**
- Molecular graph encoding
- DDI graph neural network
- Controllable DDI loss for safety
- Published in IJCAI 2021

**Usage:**
```python
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    ddi_adj_path="/path/to/ddi_matrix.pkl",
    molecule_path="/path/to/molecule_graphs.pkl"
)
```

**MoleRec** (`MoleRec`)
- Molecular-level drug recommendations
- Sub-structure reasoning
- Fine-grained medication selection

### Disease Progression Models

**StageNet** (`StageNet`)
- Disease stage-aware prediction
- Learns clinical stages automatically
- Stage-adaptive feature extraction
- Effective for chronic disease monitoring

**Architecture:**
- Stage-aware LSTM
- Dynamic stage transitions
- Time-decay mechanism

**Usage:**
```python
from pyhealth.models import StageNet

model = StageNet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_stages=3,
    chunk_size=128
)
```

**Best for:**
- ICU mortality prediction
- Chronic disease progression
- Time-varying risk assessment

**Deepr** (`Deepr`)
- Convolutional architecture over medical records
- Medical concept embeddings
- Temporal pattern learning
- Published in IEEE JBHI

### Advanced Sequential Models

**Agent** (`Agent`)
- Reinforcement learning-based
- Treatment recommendation
- Action-value optimization
- Policy learning for sequential decisions

**GRASP** (`GRASP`)
- Graph-based sequence patterns
- Structural event relationships
- Hierarchical representation learning

**SparcNet** (`SparcNet`)
- Sparse clinical networks
- Efficient feature selection
- Reduced computational cost
- Interpretable predictions

**ContraWR** (`ContraWR`)
- Contrastive learning approach
- Self-supervised pre-training
- Robust representations
- Limited labeled data scenarios

### Medical Entity Linking

**MedLink** (`MedLink`)
- Medical entity linking to knowledge bases
- Clinical concept normalization
- UMLS integration
- Entity disambiguation

### Generative Models

**GAN** (`GAN`)
- Generative Adversarial Networks
- Synthetic EHR data generation
- Privacy-preserving data sharing
- Augmentation for rare conditions

**VAE** (`VAE`)
- Variational Autoencoder
- Patient representation learning
- Anomaly detection
- Latent space exploration
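
Neither generative model is shown with a constructor call. A sketch assuming `VAE` follows the library's shared `(dataset, feature_keys, mode)` pattern; verify the exact signature before use:

```python
from pyhealth.models import VAE

# Assumed to follow the shared constructor pattern; the feature key
# and mode here are illustrative, not a documented configuration.
model = VAE(
    dataset=sample_dataset,
    feature_keys=["labs"],
    mode="regression",   # reconstruction-style objective
    embedding_dim=128,
)
```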

### Social Determinants of Health

**SDOH** (`SDOH`)
- Social determinants integration
- Multi-modal prediction
- Addresses health disparities
- Combines clinical and social data

## Model Selection Guidelines

### By Task Type

**Binary Classification** (Mortality, Readmission)
- Start with: Logistic Regression (baseline)
- Standard: RNN, Transformer
- Interpretable: RETAIN, AdaCare
- Advanced: StageNet

**Multi-Label Classification** (Drug Recommendation)
- Standard: CNN, RNN
- Healthcare-specific: GAMENet, SafeDrug, MICRON, MoleRec
- Graph-based: GNN

**Regression** (Length of Stay)
- Start with: MLP (baseline)
- Sequential: RNN, TCN
- Advanced: Transformer

**Multi-Class Classification** (Medical Coding, Specialty)
- Standard: CNN, RNN, Transformer
- Text-based: TransformersModel (BERT variants)

### By Data Type

**Sequential Events** (Diagnoses, Medications, Procedures)
- RNN, LSTM, GRU
- Transformer
- RETAIN, AdaCare, ConCare

**Time-Series Signals** (EEG, ECG)
- CNN, TCN
- RNN
- Transformer

**Text** (Clinical Notes)
- TransformersModel (ClinicalBERT, BioBERT)
- CNN for shorter text
- RNN for sequential text

**Graphs** (Drug Interactions, Patient Networks)
- GNN (GAT, GCN)
- GAMENet, SafeDrug

**Images** (X-rays, CT scans)
- CNN backbones (e.g., ResNet, DenseNet)
- Vision Transformers

### By Interpretability Needs

**High Interpretability Required:**
- Logistic Regression
- RETAIN
- AdaCare
- SparcNet

**Moderate Interpretability:**
- CNN (filter visualization)
- Transformer (attention visualization)
- GNN (graph attention)

**Black-Box Acceptable:**
- Deep RNN models
- Complex ensembles

## Training Considerations

### Hyperparameter Tuning

**Embedding Dimension:**
- Small datasets: 64-128
- Large datasets: 128-256
- Complex tasks: 256-512

**Hidden Dimension:**
- Proportional to embedding_dim
- Typically 1-2x embedding_dim

**Number of Layers:**
- Start with 2-3 layers
- Deeper for complex patterns
- Watch for overfitting

**Dropout:**
- Start with 0.5
- Reduce if underfitting (0.1-0.3)
- Increase if overfitting (0.5-0.7)

### Computational Requirements

**Memory (GPU):**
- CNN: Low to moderate
- RNN: Moderate (sequence-length dependent)
- Transformer: High (quadratic in sequence length)
- GNN: Moderate to high (graph-size dependent)

**Training Speed:**
- Fastest: Logistic Regression, MLP, CNN
- Moderate: RNN, GNN
- Slower: Transformer (but parallelizable)

### Best Practices

1. **Start with simple baselines** (Logistic Regression, MLP)
2. **Use appropriate feature keys** based on data availability
3. **Match mode to task output** (binary, multiclass, multilabel, regression)
4. **Consider interpretability requirements** for clinical deployment
5. **Validate on a held-out test set** for realistic performance
6. **Monitor for overfitting**, especially with complex models
7. **Use pretrained models** when possible (TransformersModel)
8. **Consider computational constraints** for deployment

## Example Workflow

```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Prepare data
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications", "procedures"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 3. Train model
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="pr_auc_score",
    monitor_criterion="max"
)

# 4. Evaluate
results = trainer.evaluate(test_loader)
print(results)
```

skills/pyhealth/references/preprocessing.md (new file, 638 lines)

# PyHealth Data Preprocessing and Processors

## Overview

PyHealth provides comprehensive data processing utilities to transform raw healthcare data into model-ready formats. Processors handle feature extraction, sequence processing, signal transformation, and label preparation.

## Processor Base Class

All processors inherit from `Processor` with a standard interface:

**Key Methods:**
- `__call__()`: Transform input data
- `get_input_info()`: Return processed input schema
- `get_output_info()`: Return processed output schema

## Core Processor Types

### Feature Processors

**FeatureProcessor** (`FeatureProcessor`)
- Base class for feature extraction
- Handles vocabulary building
- Embedding preparation
- Feature encoding

**Common Operations:**
- Medical code tokenization
- Categorical encoding
- Feature normalization
- Missing value handling

**Usage:**
```python
from pyhealth.data import FeatureProcessor

processor = FeatureProcessor(
    vocabulary="diagnoses",
    min_freq=5,  # minimum code frequency
    max_vocab_size=10000
)

processed_features = processor(raw_features)
```

### Sequence Processors

**SequenceProcessor** (`SequenceProcessor`)
- Processes sequential clinical events
- Temporal ordering preservation
- Sequence padding/truncation
- Time gap encoding

**Key Features:**
- Variable-length sequence handling
- Temporal feature extraction
- Sequence statistics computation

**Parameters:**
- `max_seq_length`: Maximum sequence length (truncate if longer)
- `padding`: Padding strategy ("pre" or "post")
- `truncating`: Truncation strategy ("pre" or "post")

**Usage:**
```python
from pyhealth.data import SequenceProcessor

processor = SequenceProcessor(
    max_seq_length=100,
    padding="post",
    truncating="post"
)

# Process diagnosis sequences
processed_seq = processor(diagnosis_sequences)
```

**NestedSequenceProcessor** (`NestedSequenceProcessor`)
- Handles hierarchical sequences (e.g., visits containing events)
- Two-level processing (visit-level and event-level)
- Preserves nested structure

**Use Cases:**
- EHR with visits containing multiple events
- Multi-level temporal modeling
- Hierarchical attention models

**Structure:**
```python
# Input:  [[visit1_events], [visit2_events], ...]
# Output: processed nested sequences with proper padding
```
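
To make the two-level shape concrete, a sketch; the `max_outer_length`/`max_inner_length` parameter names are hypothetical stand-ins mirroring `SequenceProcessor` above, so check the real signature before use:

```python
from pyhealth.data import NestedSequenceProcessor

# Hypothetical parameter names: cap visits per patient and events
# per visit, padding at the end of each level.
processor = NestedSequenceProcessor(
    max_outer_length=20,   # visits per patient (assumption)
    max_inner_length=50,   # events per visit (assumption)
    padding="post",
)

visits = [["401.9", "428.0"], ["250.00"]]  # two visits of diagnosis codes
processed = processor(visits)
```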

### Numeric Data Processors

**NestedFloatsProcessor** (`NestedFloatsProcessor`)
- Processes nested numeric arrays
- Lab values, vital signs, measurements
- Multi-level numeric features

**Operations:**
- Normalization
- Standardization
- Missing value imputation
- Outlier handling

**Usage:**
```python
from pyhealth.data import NestedFloatsProcessor

processor = NestedFloatsProcessor(
    normalization="z-score",  # or "min-max"
    fill_missing="mean"       # imputation strategy
)

processed_labs = processor(lab_values)
```

**TensorProcessor** (`TensorProcessor`)
- Converts data to PyTorch tensors
- Type handling (long, float, etc.)
- Device placement (CPU/GPU)

**Parameters:**
- `dtype`: Tensor data type
- `device`: Computation device

### Time-Series Processors

**TimeseriesProcessor** (`TimeseriesProcessor`)
- Handles temporal data with timestamps
- Time gap computation
- Temporal feature engineering
- Irregular sampling handling

**Extracted Features:**
- Time since previous event
- Time to next event
- Event frequency
- Temporal patterns

**Usage:**
```python
from pyhealth.data import TimeseriesProcessor

processor = TimeseriesProcessor(
    time_unit="hour",  # "day", "hour", "minute"
    compute_gaps=True,
    compute_frequency=True
)

processed_ts = processor(timestamps, events)
```

**SignalProcessor** (`SignalProcessor`)
- Physiological signal processing
- EEG, ECG, PPG signals
- Filtering and preprocessing

**Operations:**
- Bandpass filtering
- Artifact removal
- Segmentation
- Feature extraction (frequency, amplitude)

**Usage:**
```python
from pyhealth.data import SignalProcessor

processor = SignalProcessor(
    sampling_rate=256,          # Hz
    bandpass_filter=(0.5, 50),  # Hz range
    segment_length=30           # seconds
)

processed_signal = processor(raw_eeg_signal)
```

### Image Processors

**ImageProcessor** (`ImageProcessor`)
- Medical image preprocessing
- Normalization and resizing
- Augmentation support
- Format standardization

**Operations:**
- Resize to standard dimensions
- Normalization (mean/std)
- Windowing (for CT/MRI)
- Data augmentation

**Usage:**
```python
from pyhealth.data import ImageProcessor

processor = ImageProcessor(
    image_size=(224, 224),
    normalization="imagenet",  # or custom mean/std
    augmentation=True
)

processed_image = processor(raw_image)
```

## Label Processors

### Binary Classification

**BinaryLabelProcessor** (`BinaryLabelProcessor`)
- Binary classification labels (0/1)
- Handles positive/negative classes
- Class weighting for imbalance

**Usage:**
```python
from pyhealth.data import BinaryLabelProcessor

processor = BinaryLabelProcessor(
    positive_class=1,
    class_weight="balanced"
)

processed_labels = processor(raw_labels)
```

### Multi-Class Classification

**MultiClassLabelProcessor** (`MultiClassLabelProcessor`)
- Multi-class classification (mutually exclusive classes)
- Label encoding
- Class balancing

**Parameters:**
- `num_classes`: Number of classes
- `class_weight`: Weighting strategy

**Usage:**
```python
from pyhealth.data import MultiClassLabelProcessor

processor = MultiClassLabelProcessor(
    num_classes=5,  # e.g., sleep stages: W, N1, N2, N3, REM
    class_weight="balanced"
)

processed_labels = processor(raw_labels)
```

### Multi-Label Classification

**MultiLabelProcessor** (`MultiLabelProcessor`)
- Multi-label classification (multiple labels per sample)
- Binary encoding for each label
- Label co-occurrence handling

**Use Cases:**
- Drug recommendation (multiple drugs)
- ICD coding (multiple diagnoses)
- Comorbidity prediction

**Usage:**
```python
from pyhealth.data import MultiLabelProcessor

processor = MultiLabelProcessor(
    num_labels=100,  # total possible labels
    threshold=0.5    # prediction threshold
)

processed_labels = processor(raw_label_sets)
```

### Regression

**RegressionLabelProcessor** (`RegressionLabelProcessor`)
- Continuous value prediction
- Target scaling and normalization
- Outlier handling

**Use Cases:**
- Length of stay prediction
- Lab value prediction
- Risk score estimation

**Usage:**
```python
from pyhealth.data import RegressionLabelProcessor

processor = RegressionLabelProcessor(
    normalization="z-score",  # or "min-max"
    clip_outliers=True,
    outlier_std=3  # clip at 3 standard deviations
)

processed_targets = processor(raw_values)
```

## Specialized Processors

### Text Processing

**TextProcessor** (`TextProcessor`)
- Clinical text preprocessing
- Tokenization
- Vocabulary building
- Sequence encoding

**Operations:**
- Lowercasing
- Punctuation removal
- Medical abbreviation handling
- Token frequency filtering

**Usage:**
```python
from pyhealth.data import TextProcessor

processor = TextProcessor(
    tokenizer="word",  # or "sentencepiece", "bpe"
    lowercase=True,
    max_vocab_size=50000,
    min_freq=5
)

processed_text = processor(clinical_notes)
```

### Model-Specific Processors

**StageNetProcessor** (`StageNetProcessor`)
- Specialized preprocessing for the StageNet model
- Chunk-based sequence processing
- Stage-aware feature extraction

**Usage:**
```python
from pyhealth.data import StageNetProcessor

processor = StageNetProcessor(
    chunk_size=128,
    num_stages=3
)

processed_data = processor(sequential_data)
```

**StageNetTensorProcessor** (`StageNetTensorProcessor`)
- Tensor conversion for StageNet
- Proper batching and padding
- Stage mask generation

### Raw Data Processing

**RawProcessor** (`RawProcessor`)
- Minimal preprocessing
- Pass-through for pre-processed data
- Custom preprocessing scenarios

**Usage:**
```python
from pyhealth.data import RawProcessor

processor = RawProcessor()
processed_data = processor(data)  # minimal transformation
```

## Sample-Level Processing

**SampleProcessor** (`SampleProcessor`)
- Processes complete samples (input + output)
- Coordinates multiple processors
- End-to-end preprocessing pipeline

**Workflow:**
1. Apply input processors to features
2. Apply output processors to labels
3. Combine into model-ready samples

**Usage:**
```python
from pyhealth.data import SampleProcessor

processor = SampleProcessor(
    input_processors={
        "diagnoses": SequenceProcessor(max_seq_length=50),
        "medications": SequenceProcessor(max_seq_length=30),
        "labs": NestedFloatsProcessor(normalization="z-score")
    },
    output_processor=BinaryLabelProcessor()
)

processed_sample = processor(raw_sample)
```

## Dataset-Level Processing

**DatasetProcessor** (`DatasetProcessor`)
- Processes entire datasets
- Batch processing
- Parallel processing support
- Caching for efficiency

**Operations:**
- Apply processors to all samples
- Generate vocabulary from dataset
- Compute dataset statistics
- Save processed data

**Usage:**
```python
from pyhealth.data import DatasetProcessor

processor = DatasetProcessor(
    sample_processor=sample_processor,
    num_workers=4,  # parallel processing
    cache_dir="/path/to/cache"
)

processed_dataset = processor(raw_dataset)
```

## Common Preprocessing Workflows

### Workflow 1: EHR Mortality Prediction

```python
from pyhealth.data import (
    SequenceProcessor,
    BinaryLabelProcessor,
    SampleProcessor
)

# Define processors
input_processors = {
    "diagnoses": SequenceProcessor(max_seq_length=50),
    "medications": SequenceProcessor(max_seq_length=30),
    "procedures": SequenceProcessor(max_seq_length=20)
}

output_processor = BinaryLabelProcessor(class_weight="balanced")

# Combine into a sample processor
sample_processor = SampleProcessor(
    input_processors=input_processors,
    output_processor=output_processor
)

# Process dataset
processed_samples = [sample_processor(s) for s in raw_samples]
```

### Workflow 2: Sleep Staging from EEG

```python
from pyhealth.data import (
    SignalProcessor,
    MultiClassLabelProcessor,
    SampleProcessor
)

# Signal preprocessing
signal_processor = SignalProcessor(
    sampling_rate=100,
    bandpass_filter=(0.3, 35),  # EEG frequency range
    segment_length=30           # 30-second epochs
)

# Label processing
label_processor = MultiClassLabelProcessor(
    num_classes=5,  # W, N1, N2, N3, REM
    class_weight="balanced"
)

# Combine
sample_processor = SampleProcessor(
    input_processors={"signal": signal_processor},
    output_processor=label_processor
)
```

### Workflow 3: Drug Recommendation

```python
from pyhealth.data import (
    SequenceProcessor,
    MultiLabelProcessor,
    SampleProcessor
)

# Input processing
input_processors = {
    "diagnoses": SequenceProcessor(max_seq_length=50),
    "previous_medications": SequenceProcessor(max_seq_length=40)
}

# Multi-label output (multiple drugs)
output_processor = MultiLabelProcessor(
    num_labels=150,  # number of possible drugs
    threshold=0.5
)

sample_processor = SampleProcessor(
    input_processors=input_processors,
    output_processor=output_processor
)
```

### Workflow 4: Length of Stay Prediction

```python
from pyhealth.data import (
    SequenceProcessor,
    NestedFloatsProcessor,
    RegressionLabelProcessor,
    SampleProcessor
)

# Process different feature types
input_processors = {
    "diagnoses": SequenceProcessor(max_seq_length=30),
    "procedures": SequenceProcessor(max_seq_length=20),
    "labs": NestedFloatsProcessor(
        normalization="z-score",
        fill_missing="mean"
    )
}

# Regression target
output_processor = RegressionLabelProcessor(
    normalization="log",  # log-transform LOS
    clip_outliers=True
)

sample_processor = SampleProcessor(
    input_processors=input_processors,
    output_processor=output_processor
)
```

## Best Practices

### Sequence Processing

1. **Choose appropriate max_seq_length**: Balance between context and computation
   - Short sequences (20-50): Fast, less context
   - Medium sequences (50-100): Good balance
   - Long sequences (100+): More context, slower

2. **Truncation strategy**:
   - "post": Keep most recent events (recommended for clinical prediction)
   - "pre": Keep earliest events

3. **Padding strategy**:
   - "post": Pad at end (standard)
   - "pre": Pad at beginning

### Feature Encoding

1. **Vocabulary size**: Limit to frequent codes
   - `min_freq=5`: Include codes appearing ≥5 times
   - `max_vocab_size=10000`: Cap total vocabulary size

2. **Handle rare codes**: Group into an "unknown" category

3. **Missing values**:
   - Imputation (mean, median, forward-fill)
   - Indicator variables
   - Special tokens

### Normalization

1. **Numeric features**: Always normalize
   - Z-score: Standard scaling (mean=0, std=1)
   - Min-max: Range scaling [0, 1]

2. **Compute statistics on the training set only**: Prevent data leakage

3. **Apply the same normalization to val/test sets** (see the sketch below)
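
A plain-NumPy sketch of points 2 and 3 (not a PyHealth API): fit the statistics on the training split, then reuse them unchanged on the other splits. The arrays here are synthetic stand-ins for lab matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
train_labs = rng.normal(loc=5.0, scale=2.0, size=(100, 4))  # toy lab matrix
val_labs = rng.normal(loc=5.0, scale=2.0, size=(20, 4))

# Fit normalization statistics on the training split only...
mu, sigma = train_labs.mean(axis=0), train_labs.std(axis=0)

# ...then apply the same statistics to every split (no leakage).
train_z = (train_labs - mu) / sigma
val_z = (val_labs - mu) / sigma
```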

### Class Imbalance

1. **Use class weighting**: `class_weight="balanced"` (see the sketch after this list)

2. **Consider oversampling**: For very rare positive cases

3. **Evaluate with appropriate metrics**: AUROC, AUPRC, F1
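
For point 1, balanced weights can be derived directly from the training labels, for example with scikit-learn (a generic sketch, not PyHealth-specific):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_train = np.array([0, 0, 0, 0, 1])  # toy imbalanced labels
weights = compute_class_weight(class_weight="balanced",
                               classes=np.unique(y_train),
                               y=y_train)
# weights[c] multiplies the loss contribution of class c
print(dict(zip(np.unique(y_train), weights)))
```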

### Performance Optimization

1. **Cache processed data**: Save preprocessing results

2. **Parallel processing**: Use `num_workers` for the DataLoader

3. **Batch processing**: Process multiple samples at once

4. **Feature selection**: Remove low-information features

### Validation

1. **Check processed shapes**: Ensure correct dimensions

2. **Verify value ranges**: After normalization

3. **Inspect samples**: Manually review processed data

4. **Monitor memory usage**: Especially for large datasets

## Troubleshooting

### Common Issues

**Memory Error:**
- Reduce `max_seq_length`
- Use smaller batches
- Process data in chunks
- Enable caching to disk

**Slow Processing:**
- Enable parallel processing (`num_workers`)
- Cache preprocessed data
- Reduce feature dimensionality
- Use more efficient data types

**Shape Mismatch:**
- Check sequence lengths
- Verify padding configuration
- Ensure consistent processor settings

**NaN Values:**
- Handle missing data explicitly
- Check normalization parameters
- Verify imputation strategy

**Class Imbalance:**
- Use class weighting
- Consider oversampling
- Adjust the decision threshold
- Use appropriate evaluation metrics

skills/pyhealth/references/tasks.md (new file, 379 lines)

# PyHealth Clinical Prediction Tasks

## Overview

PyHealth provides 20+ predefined clinical prediction tasks for common healthcare AI applications. Each task function transforms raw patient data into structured input-output pairs for model training.

## Task Function Structure

All task functions inherit from `BaseTask` and provide:

- **input_schema**: Defines input features (diagnoses, medications, labs, etc.)
- **output_schema**: Defines prediction targets (labels, values)
- **pre_filter()**: Optional patient/visit filtering logic

**Usage Pattern:**
```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn

dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
```
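
This file does not show what a task function looks like inside. Below is a minimal sketch in the classic callable style (one `Patient` in, a list of sample dicts out); if your PyHealth version uses `BaseTask` subclasses instead, the same logic lives in the class's call method. Field names such as `discharge_status` are illustrative stand-ins, not real schema.

```python
# Hypothetical custom task in the classic callable style.
# Keys in the returned dicts are illustrative; match them to your
# model's feature_keys and label key.
def my_mortality_task(patient):
    samples = []
    for i in range(len(patient.visits) - 1):
        visit, next_visit = patient.visits[i], patient.visits[i + 1]
        diagnoses = [e.code for e in visit.events if e.vocabulary == "ICD9CM"]
        if not diagnoses:
            continue  # pre_filter-style exclusion
        samples.append({
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "diagnoses": diagnoses,
            # `discharge_status` stands in for however mortality is
            # recorded in your source tables.
            "label": int(getattr(next_visit, "discharge_status", 0) == 1),
        })
    return samples

sample_dataset = dataset.set_task(my_mortality_task)
```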
## Electronic Health Record (EHR) Tasks
|
||||
|
||||
### Mortality Prediction
|
||||
|
||||
**Purpose:** Predict patient death risk at next visit or within specified timeframe
|
||||
|
||||
**MIMIC-III Mortality** (`mortality_prediction_mimic3_fn`)
|
||||
- Predicts death at next hospital visit
|
||||
- Binary classification task
|
||||
- Input: Historical diagnoses, procedures, medications
|
||||
- Output: Binary label (deceased/alive)
|
||||
|
||||
**MIMIC-IV Mortality** (`mortality_prediction_mimic4_fn`)
|
||||
- Updated version for MIMIC-IV dataset
|
||||
- Enhanced feature set
|
||||
- Improved label quality
|
||||
|
||||
**eICU Mortality** (`mortality_prediction_eicu_fn`)
|
||||
- Multi-center ICU mortality prediction
|
||||
- Accounts for hospital-level variation
|
||||
|
||||
**OMOP Mortality** (`mortality_prediction_omop_fn`)
|
||||
- Standardized mortality prediction
|
||||
- Works with OMOP common data model
|
||||
|
||||
**In-Hospital Mortality** (`inhospital_mortality_prediction_mimic4_fn`)
|
||||
- Predicts death during current hospitalization
|
||||
- Real-time risk assessment
|
||||
- Earlier prediction window than next-visit mortality
|
||||
|
||||
**StageNet Mortality** (`mortality_prediction_mimic4_fn_stagenet`)
|
||||
- Specialized for StageNet model architecture
|
||||
- Temporal stage-aware prediction
|
||||
|
||||
### Hospital Readmission Prediction

**Purpose:** Identify patients at risk of hospital readmission within a specified timeframe (typically 30 days)

**MIMIC-III Readmission** (`readmission_prediction_mimic3_fn`)
- 30-day readmission prediction
- Binary classification
- Input: Diagnosis history, medications, demographics
- Output: Binary label (readmitted/not readmitted)

**MIMIC-IV Readmission** (`readmission_prediction_mimic4_fn`)
- Enhanced readmission features
- Improved temporal modeling

**eICU Readmission** (`readmission_prediction_eicu_fn`)
- ICU-specific readmission risk
- Multi-site data

**OMOP Readmission** (`readmission_prediction_omop_fn`)
- Standardized readmission prediction

### Length of Stay Prediction

**Purpose:** Estimate hospital stay duration for resource planning and patient management

**MIMIC-III Length of Stay** (`length_of_stay_prediction_mimic3_fn`)
- Regression task
- Input: Admission diagnoses, vitals, demographics
- Output: Continuous value (days)

**MIMIC-IV Length of Stay** (`length_of_stay_prediction_mimic4_fn`)
- Enhanced features for LOS prediction
- Better temporal granularity

**eICU Length of Stay** (`length_of_stay_prediction_eicu_fn`)
- ICU stay duration prediction
- Multi-hospital data

**OMOP Length of Stay** (`length_of_stay_prediction_omop_fn`)
- Standardized LOS prediction

### Drug Recommendation

**Purpose:** Suggest appropriate medications based on patient history and current conditions

**MIMIC-III Drug Recommendation** (`drug_recommendation_mimic3_fn`)
- Multi-label classification
- Input: Diagnoses, previous medications, demographics
- Output: Set of recommended drug codes
- Considers drug-drug interactions

**MIMIC-IV Drug Recommendation** (`drug_recommendation_mimic4_fn`)
- Updated medication data
- Enhanced interaction modeling

**eICU Drug Recommendation** (`drug_recommendation_eicu_fn`)
- Critical care medication recommendations

**OMOP Drug Recommendation** (`drug_recommendation_omop_fn`)
- Standardized drug recommendation

**Key Considerations:**
- Handles polypharmacy scenarios
- Multi-label prediction (multiple drugs per patient; see the multi-hot sketch after this list)
- Can integrate with SafeDrug/GAMENet models for safety-aware recommendations

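Because the target is a set of drug codes rather than a single class, models typically train against a multi-hot label matrix. A minimal sketch of that conversion (the sample dicts and codes here are illustrative, not PyHealth output):

```python
import numpy as np

# Illustrative samples: each prediction instance carries a set of drug codes
samples = [
    {"visit_id": "v1", "drugs": ["A01", "B02"]},
    {"visit_id": "v2", "drugs": ["B02", "C03"]},
]

# Build a label vocabulary, then a multi-hot target matrix [n_samples, n_labels]
vocab = sorted({code for s in samples for code in s["drugs"]})
index = {code: i for i, code in enumerate(vocab)}

y = np.zeros((len(samples), len(vocab)), dtype=int)
for row, s in enumerate(samples):
    for code in s["drugs"]:
        y[row, index[code]] = 1

print(vocab)  # ['A01', 'B02', 'C03']
print(y)      # [[1 1 0] [0 1 1]]
```
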
## Specialized Clinical Tasks

### Medical Coding

**MIMIC-III ICD-9 Coding** (`icd9_coding_mimic3_fn`)
- Assigns ICD-9 diagnosis/procedure codes to clinical notes
- Multi-label text classification
- Input: Clinical text/documentation
- Output: Set of ICD-9 codes
- Supports both diagnosis and procedure coding

### Patient Linkage

**MIMIC-III Patient Linking** (`patient_linkage_mimic3_fn`)
- Record matching and deduplication
- Binary classification (same patient or not)
- Input: Demographic and clinical features from two records
- Output: Match probability

## Physiological Signal Tasks

### Sleep Staging

**Purpose:** Classify sleep stages from EEG/physiological signals for sleep disorder diagnosis

**ISRUC Sleep Staging** (`sleep_staging_isruc_fn`)
- Multi-class classification (Wake, N1, N2, N3, REM)
- Input: Multi-channel EEG signals
- Output: Sleep stage per epoch (typically 30 seconds)

**SleepEDF Sleep Staging** (`sleep_staging_sleepedf_fn`)
- Standard sleep staging task
- PSG signal processing

**SHHS Sleep Staging** (`sleep_staging_shhs_fn`)
- Large-scale sleep study data
- Population-level sleep analysis

**Standardized Labels** (see the encoding sketch after this list):
- Wake (W)
- Non-REM Stage 1 (N1)
- Non-REM Stage 2 (N2)
- Non-REM Stage 3 (N3/Deep Sleep)
- REM (Rapid Eye Movement)

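Models consume integer class indices rather than stage names, so a fixed mapping keeps labels consistent across the three datasets. The encoding below is an illustration, not the library's fixed convention:

```python
# Map standardized stage names to integer class indices (illustrative encoding)
STAGE_TO_INDEX = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}

def encode_stages(stages):
    """Convert a sequence of per-epoch stage labels into integer targets."""
    return [STAGE_TO_INDEX[s] for s in stages]

print(encode_stages(["W", "N2", "N2", "N3", "REM"]))  # [0, 2, 2, 3, 4]
```
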
### EEG Analysis

**Abnormality Detection** (`abnormality_detection_tuab_fn`)
- Binary classification (normal/abnormal EEG)
- Clinical screening application
- Input: Multi-channel EEG recordings
- Output: Binary label

**Event Detection** (`event_detection_tuev_fn`)
- Identify specific EEG events (spikes, seizures)
- Multi-class classification
- Input: EEG time series
- Output: Event type and timing

**Seizure Detection** (`seizure_detection_tusz_fn`)
- Specialized epileptic seizure detection
- Critical for epilepsy monitoring
- Input: Continuous EEG
- Output: Seizure/non-seizure classification

## Medical Imaging Tasks

### COVID-19 Chest X-ray Classification

**COVID-19 CXR** (`covid_classification_cxr_fn`)
- Multi-class image classification
- Classes: COVID-19, bacterial pneumonia, viral pneumonia, normal
- Input: Chest X-ray images
- Output: Disease classification

## Text-Based Tasks

### Medical Transcription Classification

**Medical Specialty Classification** (`medical_transcription_classification_fn`)
- Classify clinical notes by medical specialty
- Multi-class text classification
- Input: Clinical transcription text
- Output: Medical specialty (Cardiology, Neurology, etc.)

## Custom Task Creation

### Creating Custom Tasks

Define custom prediction tasks by specifying input/output schemas:

```python
def custom_task_fn(patient):
    """Custom prediction task: emit one sample per visit with enough history."""
    samples = []

    for i, visit in enumerate(patient.visits):
        # Skip visits without at least two prior visits of history
        if i < 2:
            continue

        # Create input features from historical visits
        input_info = {
            "diagnoses": [],
            "medications": [],
            "procedures": []
        }

        # Collect features from previous visits
        for past_visit in patient.visits[:i]:
            for event in past_visit.events:
                if event.vocabulary == "ICD10CM":
                    input_info["diagnoses"].append(event.code)
                elif event.vocabulary == "NDC":
                    input_info["medications"].append(event.code)

        # Define the prediction target at the current visit; this placeholder
        # labels stays of one day or longer -- substitute your own outcome
        long_stay = (
            visit.encounter_time is not None
            and visit.discharge_time is not None
            and (visit.discharge_time - visit.encounter_time).days >= 1
        )
        output_info = {
            "label": 1 if long_stay else 0
        }

        samples.append({
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "input_info": input_info,
            "output_info": output_info
        })

    return samples

# Apply the custom task
sample_dataset = dataset.set_task(custom_task_fn)
```

### Task Function Components

1. **Input Schema Definition** (a class-style sketch follows this list)
   - Specify which features to extract
   - Define feature types (codes, sequences, values)
   - Set temporal windows

2. **Output Schema Definition**
   - Define prediction targets
   - Set label types (binary, multi-class, multi-label, regression)
   - Specify evaluation metrics

3. **Filtering Logic**
   - Exclude patients/visits with insufficient data
   - Apply inclusion/exclusion criteria
   - Handle missing data

4. **Sample Generation**
   - Create input-output pairs
   - Maintain patient/visit identifiers
   - Preserve temporal ordering

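The schema declarations in items 1 and 2 can also live directly on a `BaseTask` subclass, the style the "Task Function Structure" section above describes. A hedged sketch; `task_name` and the schema type strings are assumptions that may differ across PyHealth versions:

```python
from pyhealth.tasks import BaseTask

class CustomOutcomeTask(BaseTask):
    """Class-style task: schemas declared up front, samples generated per patient."""
    task_name = "CustomOutcomeTask"
    # Schema keys/type strings are illustrative; match them to your PyHealth version
    input_schema = {"diagnoses": "sequence", "medications": "sequence"}
    output_schema = {"label": "binary"}

    def __call__(self, patient):
        samples = []
        # ... per-visit sample generation as in custom_task_fn above ...
        return samples
```
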
## Task Selection Guidelines

### Clinical Prediction Tasks
**Use when:** Working with structured EHR data (diagnoses, medications, procedures)

**Datasets:** MIMIC-III, MIMIC-IV, eICU, OMOP

**Common tasks:**
- Mortality prediction for risk stratification
- Readmission prediction for care transition planning
- Length of stay for resource allocation
- Drug recommendation for clinical decision support

### Signal Processing Tasks
**Use when:** Working with physiological time-series data

**Datasets:** SleepEDF, SHHS, ISRUC, TUEV, TUAB, TUSZ

**Common tasks:**
- Sleep staging for sleep disorder diagnosis
- EEG abnormality detection for screening
- Seizure detection for epilepsy monitoring

### Imaging Tasks
**Use when:** Working with medical images

**Datasets:** COVID-19 CXR

**Common tasks:**
- Disease classification from radiographs
- Abnormality detection

### Text Tasks
**Use when:** Working with clinical notes and documentation

**Datasets:** Medical Transcriptions, MIMIC-III (with notes)

**Common tasks:**
- Medical coding from clinical text
- Specialty classification
- Clinical information extraction

## Task Output Structure

All task functions return a `SampleDataset` whose samples have the form:

```python
sample = {
    "patient_id": "unique_patient_id",
    "visit_id": "unique_visit_id",  # if applicable
    "input_info": {
        # Input features (diagnoses, medications, etc.)
    },
    "output_info": {
        # Prediction targets (labels, values)
    }
}
```

## Integration with Models

Tasks define the input/output contract for models:

```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer

# 1. Create task-specific dataset
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Model automatically adapts to the task schema
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",  # matches the task's output schema
)
```

## Best Practices

1. **Match task to clinical question**: Choose predefined tasks when available for standardized benchmarking

2. **Consider temporal windows**: Ensure sufficient history for meaningful predictions

3. **Handle class imbalance**: Many clinical outcomes are rare (mortality, readmission); see the weighting sketch after this list

4. **Validate clinical relevance**: Ensure prediction windows align with clinical decision-making timelines

5. **Use appropriate metrics**: Different tasks require different evaluation metrics (AUROC for binary, macro-F1 for multi-class)

6. **Document exclusion criteria**: Track which patients/visits are filtered and why

7. **Preserve patient privacy**: Always use de-identified data and follow HIPAA/GDPR guidelines

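For point 3, one common remedy is to up-weight the rare positive class in the loss rather than resampling. A minimal sketch with PyTorch (the label vector is synthetic):

```python
import numpy as np
import torch

# Synthetic labels with ~10% positives, as is typical for mortality/readmission
labels = np.array([0] * 900 + [1] * 100)

# Inverse-frequency weight for the positive class
pos_weight = torch.tensor(
    (labels == 0).sum() / (labels == 1).sum(), dtype=torch.float
)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(len(labels))  # stand-in model outputs
loss = criterion(logits, torch.tensor(labels, dtype=torch.float))
print(f"pos_weight={pos_weight.item():.1f}, loss={loss.item():.3f}")
```
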
648
skills/pyhealth/references/training_evaluation.md
Normal file
@@ -0,0 +1,648 @@

# PyHealth Training, Evaluation, and Interpretability

## Overview

PyHealth provides comprehensive tools for training models, evaluating predictions, ensuring model reliability, and interpreting results for clinical applications.

## Trainer Class

### Core Functionality

The `Trainer` class manages the complete model training and evaluation workflow with PyTorch integration.

**Initialization:**
```python
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,    # PyHealth or PyTorch model
    device="cuda",  # or "cpu"
)
```

### Training

**train() method**

Trains models with comprehensive monitoring and checkpointing.

**Parameters:**
- `train_dataloader`: Training data loader
- `val_dataloader`: Validation data loader (optional)
- `test_dataloader`: Test data loader (optional)
- `epochs`: Number of training epochs
- `optimizer`: Optimizer instance or class
- `learning_rate`: Learning rate (default: 1e-3)
- `weight_decay`: L2 regularization (default: 0)
- `max_grad_norm`: Gradient clipping threshold
- `monitor`: Metric to monitor (e.g., "pr_auc_score")
- `monitor_criterion`: "max" or "min"
- `save_path`: Checkpoint save directory

**Usage:**
```python
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    max_grad_norm=5.0,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints"
)
```

**Training Features:**

1. **Automatic Checkpointing**: Saves the best model based on the monitored metric

2. **Early Stopping**: Stops training when the monitored metric stops improving

3. **Gradient Clipping**: Prevents exploding gradients

4. **Progress Tracking**: Displays training progress and metrics

5. **Multi-GPU Support**: Automatic device placement

### Inference

**inference() method**

Performs predictions on datasets.

**Parameters:**
- `dataloader`: Data loader for inference
- `additional_outputs`: List of additional outputs to return
- `return_patient_ids`: Return patient identifiers

**Usage:**
```python
predictions = trainer.inference(
    dataloader=test_loader,
    additional_outputs=["attention_weights", "embeddings"],
    return_patient_ids=True
)
```

**Returns:**
- `y_pred`: Model predictions
- `y_true`: Ground truth labels
- `patient_ids`: Patient identifiers (if requested)
- Additional outputs (if specified)

### Evaluation

**evaluate() method**

Computes comprehensive evaluation metrics.

**Parameters:**
- `dataloader`: Data loader for evaluation
- `metrics`: List of metric names

**Usage:**
```python
results = trainer.evaluate(
    dataloader=test_loader,
    metrics=["accuracy", "pr_auc_score", "roc_auc_score", "f1_score"]
)

print(results)
# Output: {'accuracy': 0.85, 'pr_auc_score': 0.78, 'roc_auc_score': 0.82, 'f1_score': 0.73}
```

### Checkpoint Management

**save() method**
```python
trainer.save("./models/best_model.pt")
```

**load() method**
```python
trainer.load("./models/best_model.pt")
```

## Evaluation Metrics

### Binary Classification Metrics

**Available Metrics:**
- `accuracy`: Overall accuracy
- `precision`: Positive predictive value
- `recall`: Sensitivity/true positive rate
- `f1_score`: F1 score (harmonic mean of precision and recall)
- `roc_auc_score`: Area under the ROC curve
- `pr_auc_score`: Area under the precision-recall curve
- `cohen_kappa`: Inter-rater reliability

**Usage:**
```python
from pyhealth.metrics import binary_metrics_fn

# Comprehensive binary metrics
metrics = binary_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "f1_score", "pr_auc_score", "roc_auc_score"]
)
```

**Threshold Selection:**
```python
import numpy as np
from sklearn.metrics import f1_score

# Default threshold: 0.5
predictions_binary = (y_pred > 0.5).astype(int)

# Optimal threshold by F1: sweep candidate cutoffs and keep the best
thresholds = np.arange(0.1, 0.9, 0.05)
f1_scores = [f1_score(y_true, (y_pred > t).astype(int)) for t in thresholds]
optimal_threshold = thresholds[np.argmax(f1_scores)]
```

**Best Practices:**
- **Use AUROC**: Overall model discrimination
- **Use AUPRC**: Especially for imbalanced classes
- **Use F1**: Balance precision and recall
- **Report confidence intervals**: Bootstrap resampling (see the sketch after this list)

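A minimal sketch of a percentile-bootstrap confidence interval for AUROC; `y_true` and `y_prob` are assumed to be NumPy arrays of labels and predicted probabilities:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def bootstrap_auroc_ci(y_true, y_prob, n_boot=1000, alpha=0.05, seed=42):
    """Percentile bootstrap confidence interval for AUROC."""
    rng = np.random.default_rng(seed)
    scores = []
    n = len(y_true)
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample with replacement
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUROC is undefined without both classes
        scores.append(roc_auc_score(y_true[idx], y_prob[idx]))
    lower, upper = np.percentile(scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper
```
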
### Multi-Class Classification Metrics

**Available Metrics:**
- `accuracy`: Overall accuracy
- `macro_f1`: Unweighted mean F1 across classes
- `micro_f1`: Global F1 (total TP, FP, FN)
- `weighted_f1`: Mean F1 weighted by class frequency
- `cohen_kappa`: Multi-class kappa

**Usage:**
```python
from pyhealth.metrics import multiclass_metrics_fn

metrics = multiclass_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    metrics=["accuracy", "macro_f1", "weighted_f1"]
)
```

**Per-Class Metrics:**
```python
from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred,
                            target_names=["Wake", "N1", "N2", "N3", "REM"]))
```

**Confusion Matrix:**
```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
```

### Multi-Label Classification Metrics

**Available Metrics:**
- `jaccard_score`: Intersection over union
- `hamming_loss`: Fraction of incorrect labels
- `example_f1`: F1 per example (micro average)
- `label_f1`: F1 per label (macro average)

**Usage:**
```python
from pyhealth.metrics import multilabel_metrics_fn

# y_pred: [n_samples, n_labels] binary matrix
metrics = multilabel_metrics_fn(
    y_true=label_matrix,
    y_pred=pred_matrix,
    metrics=["jaccard_score", "example_f1", "label_f1"]
)
```

**Drug Recommendation Metrics:**
```python
import numpy as np

# Jaccard similarity between true and predicted drug sets (intersection/union)
jaccard = len(set(true_drugs) & set(pred_drugs)) / len(set(true_drugs) | set(pred_drugs))

# Precision@k: fraction of the top-k scored drugs that were truly prescribed;
# y_true holds true label indices, y_pred holds per-label scores
def precision_at_k(y_true, y_pred, k=10):
    top_k_pred = np.argsort(y_pred)[-k:]
    return len(set(y_true) & set(top_k_pred)) / k
```

### Regression Metrics

**Available Metrics:**
- `mean_absolute_error`: Average absolute error
- `mean_squared_error`: Average squared error
- `root_mean_squared_error`: RMSE
- `r2_score`: Coefficient of determination

**Usage:**
```python
from pyhealth.metrics import regression_metrics_fn

metrics = regression_metrics_fn(
    y_true=true_values,
    y_pred=predictions,
    metrics=["mae", "rmse", "r2"]
)
```

**Percentage Error Metrics:**
```python
import numpy as np

# Mean Absolute Percentage Error (undefined when y_true contains zeros)
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Median Absolute Percentage Error (robust to outliers)
medape = np.median(np.abs((y_true - y_pred) / y_true)) * 100
```

### Fairness Metrics

**Purpose:** Assess model bias across demographic groups

**Available Metrics:**
- `demographic_parity`: Equal positive prediction rates
- `equalized_odds`: Equal TPR and FPR across groups
- `equal_opportunity`: Equal TPR across groups
- `predictive_parity`: Equal PPV across groups

**Usage:**
```python
from pyhealth.metrics import fairness_metrics_fn

fairness_results = fairness_metrics_fn(
    y_true=labels,
    y_pred=predictions,
    sensitive_attributes=demographics,  # e.g., race, gender
    metrics=["demographic_parity", "equalized_odds"]
)
```

**Example:**
```python
from sklearn.metrics import recall_score

# Evaluate fairness across gender (demographics is a NumPy array of strings)
male_mask = (demographics == "male")
female_mask = (demographics == "female")

male_tpr = recall_score(y_true[male_mask], y_pred[male_mask])
female_tpr = recall_score(y_true[female_mask], y_pred[female_mask])

tpr_disparity = abs(male_tpr - female_tpr)
print(f"TPR disparity: {tpr_disparity:.3f}")
```

## Calibration and Uncertainty Quantification

### Model Calibration

**Purpose:** Ensure predicted probabilities match observed frequencies

**Calibration Plot:**
```python
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

fraction_of_positives, mean_predicted_value = calibration_curve(
    y_true, y_prob, n_bins=10
)

plt.plot(mean_predicted_value, fraction_of_positives, marker='o', label='Model')
plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
plt.show()
```

**Expected Calibration Error (ECE):**
```python
import numpy as np

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """ECE: bin-frequency-weighted gap between accuracy and confidence."""
    bins = np.linspace(0, 1, n_bins + 1)
    # Clip so that y_prob == 1.0 falls into the last bin rather than past it
    bin_indices = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)

    ece = 0
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.sum() > 0:
            bin_accuracy = y_true[mask].mean()
            bin_confidence = y_prob[mask].mean()
            ece += mask.sum() / len(y_true) * abs(bin_accuracy - bin_confidence)

    return ece
```

**Calibration Methods:**

1. **Platt Scaling**: Logistic regression on validation predictions
```python
from sklearn.linear_model import LogisticRegression

calibrator = LogisticRegression()
calibrator.fit(val_predictions.reshape(-1, 1), val_labels)
calibrated_probs = calibrator.predict_proba(test_predictions.reshape(-1, 1))[:, 1]
```

2. **Isotonic Regression**: Non-parametric calibration
```python
from sklearn.isotonic import IsotonicRegression

calibrator = IsotonicRegression(out_of_bounds='clip')
calibrator.fit(val_predictions, val_labels)
calibrated_probs = calibrator.predict(test_predictions)
```

3. **Temperature Scaling**: Scale logits before the softmax
```python
import torch.nn.functional as F
from scipy.optimize import minimize

def find_temperature(logits, labels):
    """Find the temperature that minimizes NLL on validation logits."""
    def nll(temp):
        # cross_entropy expects (scaled) logits, not probabilities
        scaled_logits = logits / float(temp[0])
        return F.cross_entropy(scaled_logits, labels).item()

    # One-dimensional search over a single scalar parameter
    result = minimize(nll, x0=[1.0], method='Nelder-Mead')
    return float(result.x[0])

temperature = find_temperature(val_logits, val_labels)
calibrated_logits = test_logits / temperature
```

### Uncertainty Quantification

**Conformal Prediction:**

Produces prediction sets with a guaranteed coverage level.

**Usage:**
```python
import numpy as np
from pyhealth.metrics import prediction_set_metrics_fn

# Calibrate on the validation set: nonconformity score = 1 - p(true class)
scores = 1 - val_predictions[np.arange(len(val_labels)), val_labels]
quantile_level = np.quantile(scores, 0.9)  # target ~90% coverage

# Generate prediction sets on the test set: keep sufficiently probable classes
prediction_sets = test_predictions > (1 - quantile_level)

# Evaluate coverage and set size
metrics = prediction_set_metrics_fn(
    y_true=test_labels,
    prediction_sets=prediction_sets,
    metrics=["coverage", "average_size"]
)
```

**Monte Carlo Dropout:**

Estimates uncertainty by keeping dropout active at inference time.

```python
import torch

def predict_with_uncertainty(model, dataloader, num_samples=20):
    """Predict with uncertainty using MC dropout."""
    # train() keeps dropout active; note it also affects batch-norm layers
    model.train()

    predictions = []
    for _ in range(num_samples):
        batch_preds = []
        for batch in dataloader:
            with torch.no_grad():
                output = model(batch)
                batch_preds.append(output)
        predictions.append(torch.cat(batch_preds))

    predictions = torch.stack(predictions)
    mean_pred = predictions.mean(dim=0)
    std_pred = predictions.std(dim=0)  # per-sample uncertainty

    return mean_pred, std_pred
```

**Ensemble Uncertainty:**

```python
import numpy as np

# Train multiple models with different seeds (train_model is your own routine)
models = [train_model(seed=i) for i in range(5)]

# Predict with the ensemble
ensemble_preds = []
for model in models:
    pred = model.predict(test_data)
    ensemble_preds.append(pred)

mean_pred = np.mean(ensemble_preds, axis=0)
std_pred = np.std(ensemble_preds, axis=0)  # disagreement as uncertainty
```

## Interpretability

### Attention Visualization

**For Transformer and RETAIN models:**

```python
import matplotlib.pyplot as plt
import seaborn as sns

# Get attention weights during inference
outputs = trainer.inference(
    test_loader,
    additional_outputs=["attention_weights"]
)

attention = outputs["attention_weights"]

# Visualize attention for one sample
sample_idx = 0
sample_attention = attention[sample_idx]  # [seq_length, seq_length]

sns.heatmap(sample_attention, cmap='viridis')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights')
plt.show()
```

**RETAIN Interpretation:**

```python
# RETAIN provides visit-level and feature-level attention
visit_attention = outputs["visit_attention"]      # which visits are important
feature_attention = outputs["feature_attention"]  # which features are important

# Find the most influential visit
most_important_visit = visit_attention[sample_idx].argmax()

# Find the most influential features within that visit
important_features = feature_attention[sample_idx, most_important_visit].argsort()[-10:]
```

### Feature Importance

**Permutation Importance:**

```python
from sklearn.inspection import permutation_importance

# model must expose a scikit-learn-style fit/predict/score interface
result = permutation_importance(
    model, X_test, y_test,
    n_repeats=10,
    scoring='roc_auc'
)

# Sort features by mean importance
indices = result.importances_mean.argsort()[::-1]
for i in indices[:10]:
    print(f"{feature_names[i]}: {result.importances_mean[i]:.3f}")
```

**SHAP Values:**

```python
import shap

# Create an explainer around the trained model
explainer = shap.DeepExplainer(model, train_data)

# Compute SHAP values
shap_values = explainer.shap_values(test_data)

# Visualize
shap.summary_plot(shap_values, test_data, feature_names=feature_names)
```

### ChEFER (Clinical Health Event Feature Extraction and Ranking)

**PyHealth's Interpretability Tool:**

```python
from pyhealth.explain import ChEFER

explainer = ChEFER(model=model, dataset=test_dataset)

# Get feature importance for a prediction
importance_scores = explainer.explain(
    patient_id="patient_123",
    visit_id="visit_456"
)

# Visualize top features
explainer.plot_importance(importance_scores, top_k=20)
```

## Complete Training Pipeline Example

```python
import torch

from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Load and prepare data
dataset = MIMIC4Dataset(root="/path/to/mimic4")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Split data by patient to avoid leakage across splits
train_data, val_data, test_data = split_by_patient(
    sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42
)

# 3. Create data loaders
train_loader = get_dataloader(train_data, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=64, shuffle=False)

# 4. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "procedures", "medications"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 5. Train model
trainer = Trainer(model=model, device="cuda")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
    weight_decay=1e-5,
    monitor="pr_auc_score",
    monitor_criterion="max",
    save_path="./checkpoints/mortality_model"
)

# 6. Evaluate on the test set
test_results = trainer.evaluate(
    test_loader,
    metrics=["accuracy", "precision", "recall", "f1_score",
             "roc_auc_score", "pr_auc_score"]
)

print("Test Results:")
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

# 7. Get predictions for analysis
predictions = trainer.inference(test_loader, return_patient_ids=True)
y_pred, y_true, patient_ids = predictions

# 8. Calibration analysis (expected_calibration_error as defined above)
from sklearn.calibration import calibration_curve

fraction_pos, mean_pred = calibration_curve(y_true, y_pred, n_bins=10)
ece = expected_calibration_error(y_true, y_pred)
print(f"Expected Calibration Error: {ece:.4f}")

# 9. Save the final model
trainer.save("./models/mortality_transformer_final.pt")
```

## Best Practices

### Training

1. **Monitor multiple metrics**: Track both loss and task-specific metrics
2. **Use a validation set**: Prevent overfitting with early stopping
3. **Gradient clipping**: Stabilize training (e.g., max_grad_norm=5.0)
4. **Learning rate scheduling**: Reduce the LR on plateau (see the sketch after this list)
5. **Checkpoint the best model**: Save based on validation performance

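For point 4, a hedged sketch of plateau-based scheduling around a manual PyTorch loop; `model`, `train_one_epoch`, and `validate` are placeholders for your own code:

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3
)

for epoch in range(50):
    train_one_epoch(model, optimizer)  # placeholder training step
    val_auprc = validate(model)        # placeholder validation metric
    scheduler.step(val_auprc)          # halve the LR after 3 stale epochs
```
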
### Evaluation

1. **Use task-appropriate metrics**: AUROC/AUPRC for binary, macro-F1 for imbalanced multi-class
2. **Report confidence intervals**: Bootstrap or cross-validation
3. **Stratified evaluation**: Report metrics by subgroup
4. **Clinical metrics**: Include clinically relevant thresholds
5. **Fairness assessment**: Evaluate across demographic groups

### Deployment

1. **Calibrate predictions**: Ensure probabilities are reliable
2. **Quantify uncertainty**: Provide confidence estimates
3. **Monitor performance**: Track metrics in production
4. **Handle distribution shift**: Detect when incoming data drifts from the training distribution (see the sketch after this list)
5. **Interpretability**: Provide explanations for predictions

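For point 4, a minimal drift check under stated assumptions: compare the model's score distribution on recent data against a reference window with a two-sample KS test (the beta-distributed scores here are synthetic stand-ins):

```python
import numpy as np
from scipy.stats import ks_2samp

# Synthetic stand-ins for held-out predictions vs. live predictions
reference_scores = np.random.beta(2, 8, size=5000)
production_scores = np.random.beta(2, 6, size=1000)

stat, p_value = ks_2samp(reference_scores, production_scores)
if p_value < 0.01:
    print(f"Possible distribution shift (KS={stat:.3f}, p={p_value:.1e})")
```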