# PyHealth Models
## Overview

PyHealth provides 33+ models for healthcare prediction tasks, ranging from simple baselines to state-of-the-art deep learning architectures. Models are organized into general-purpose architectures and healthcare-specific models.

## Model Base Class

All models inherit from `BaseModel` with standard PyTorch functionality:

**Key Attributes:**
- `dataset`: Associated SampleDataset
- `feature_keys`: Input features to use (e.g., ["diagnoses", "medications"])
- `mode`: Task type ("binary", "multiclass", "multilabel", "regression")
- `embedding_dim`: Feature embedding dimension
- `device`: Computation device (CPU/GPU)

**Key Methods:**
- `forward()`: Model forward pass
- `train_step()`: Single training iteration
- `eval_step()`: Single evaluation iteration
- `save()`: Save model checkpoint
- `load()`: Load model checkpoint
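To make the interface concrete, here is a minimal sketch of the forward/checkpoint cycle. It assumes a `sample_dataset` built via `set_task()` (as in the workflow at the end of this page), that models accept a dataloader batch as keyword arguments, and that `save()`/`load()` take a file path; exact signatures may vary by PyHealth version.

```python
import torch
from pyhealth.datasets import get_dataloader
from pyhealth.models import MLP  # MLP is described below

# Assumes `sample_dataset` was produced by dataset.set_task(...).
model = MLP(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
)

loader = get_dataloader(sample_dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))

# Forward pass: PyHealth models return a dict with loss and predictions.
with torch.no_grad():
    output = model(**batch)
print(output["loss"], output["y_prob"].shape)

# Checkpointing via the BaseModel helpers listed above (path argument
# assumed; plain torch.save(model.state_dict(), ...) also works for
# any PyTorch module).
model.save("checkpoint.ckpt")
model.load("checkpoint.ckpt")
```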
## General-Purpose Models

### Baseline Models

**Logistic Regression** (`LogisticRegression`)
- Linear classifier with mean pooling
- Simple baseline for comparison
- Fast training and inference
- Good for interpretability

**Usage:**
```python
from pyhealth.models import LogisticRegression

model = LogisticRegression(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary"
)
```
**Multi-Layer Perceptron** (`MLP`)
- Feedforward neural network
- Configurable hidden layers
- Supports mean/sum/max pooling
- Good baseline for structured data

**Parameters:**
- `hidden_dim`: Hidden layer size
- `num_layers`: Number of hidden layers
- `dropout`: Dropout rate
- `pooling`: Aggregation method ("mean", "sum", "max")

**Usage:**
```python
from pyhealth.models import MLP

model = MLP(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_layers=3,
    dropout=0.5
)
```
### Convolutional Neural Networks

**CNN** (`CNN`)
- Convolutional layers for pattern detection
- Effective for sequential and spatial data
- Captures local temporal patterns
- Parameter efficient

**Architecture:**
- Multiple 1D convolutional layers
- Max pooling for dimension reduction
- Fully connected output layers

**Parameters:**
- `num_filters`: Number of convolutional filters
- `kernel_size`: Convolution kernel size
- `num_layers`: Number of conv layers
- `dropout`: Dropout rate

**Usage:**
```python
from pyhealth.models import CNN

model = CNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_filters=64,
    kernel_size=3,
    num_layers=3
)
```
**Temporal Convolutional Networks** (`TCN`)
- Dilated convolutions for long-range dependencies
- Causal convolutions (no future information leakage)
- Efficient for long sequences
- Good for time-series prediction

**Advantages:**
- Captures long-term dependencies
- Parallelizable (faster than RNNs)
- Stable gradients
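No usage snippet accompanies `TCN` here, so the following is a hedged sketch that mirrors the other constructors; `num_channels` and `kernel_size` are assumed hyperparameter names and may differ in your PyHealth version.

```python
from pyhealth.models import TCN

# Hedged sketch: arguments beyond the common dataset/feature_keys/mode
# trio are assumptions, not confirmed API.
model = TCN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_channels=128,  # assumed: width of the dilated conv stack
    kernel_size=2,     # assumed: causal convolution kernel size
)
```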
### Recurrent Neural Networks

**RNN** (`RNN`)
- Basic recurrent architecture
- Supports LSTM, GRU, and vanilla RNN variants
- Sequential processing
- Captures temporal dependencies

**Parameters:**
- `rnn_type`: "LSTM", "GRU", or "RNN"
- `hidden_dim`: Hidden state dimension
- `num_layers`: Number of recurrent layers
- `dropout`: Dropout rate
- `bidirectional`: Use a bidirectional RNN

**Usage:**
```python
from pyhealth.models import RNN

model = RNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    rnn_type="LSTM",
    hidden_dim=128,
    num_layers=2,
    bidirectional=True
)
```
**Best for:**
- Sequential clinical events
- Temporal pattern learning
- Variable-length sequences
### Transformer Models

**Transformer** (`Transformer`)
- Self-attention mechanism
- Parallel processing of sequences
- State-of-the-art performance
- Effective for long-range dependencies

**Architecture:**
- Multi-head self-attention
- Position embeddings
- Feed-forward networks
- Layer normalization

**Parameters:**
- `num_heads`: Number of attention heads
- `num_layers`: Number of transformer layers
- `hidden_dim`: Hidden dimension
- `dropout`: Dropout rate
- `max_seq_length`: Maximum sequence length

**Usage:**
```python
from pyhealth.models import Transformer

model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_heads=8,
    num_layers=6,
    hidden_dim=256,
    dropout=0.1
)
```
**TransformersModel** (`TransformersModel`)
- Integration with HuggingFace transformers
- Pre-trained language models for clinical text
- Fine-tuning for healthcare tasks
- Examples: BERT, RoBERTa, Bio_ClinicalBERT

**Usage:**
```python
from pyhealth.models import TransformersModel

model = TransformersModel(
    dataset=sample_dataset,
    feature_keys=["text"],
    mode="multiclass",
    pretrained_model="emilyalsentzer/Bio_ClinicalBERT"
)
```
### Graph Neural Networks

**GNN** (`GNN`)
- Graph-based learning
- Models relationships between entities
- Supports GAT (Graph Attention) and GCN (Graph Convolutional)

**Use Cases:**
- Drug-drug interactions
- Patient similarity networks
- Knowledge graph integration
- Comorbidity relationships

**Parameters:**
- `gnn_type`: "GAT" or "GCN"
- `hidden_dim`: Hidden dimension
- `num_layers`: Number of GNN layers
- `dropout`: Dropout rate
- `num_heads`: Attention heads (for GAT)

**Usage:**
```python
from pyhealth.models import GNN

model = GNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    gnn_type="GAT",
    hidden_dim=128,
    num_layers=3,
    num_heads=4
)
```
## Healthcare-Specific Models

### Interpretable Clinical Models

**RETAIN** (`RETAIN`)
- Reverse time attention mechanism
- Highly interpretable predictions
- Visit-level and event-level attention
- Identifies influential clinical events

**Key Features:**
- Two-level attention (visits and features)
- Temporal decay modeling
- Clinically meaningful explanations
- Published in NeurIPS 2016

**Usage:**
```python
from pyhealth.models import RETAIN

model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128
)

# Get attention weights for interpretation
# (batch dicts are passed as keyword arguments)
outputs = model(**batch)
visit_attention = outputs["visit_attention"]
feature_attention = outputs["feature_attention"]
```
**Best for:**
- Mortality prediction
- Readmission prediction
- Clinical risk scoring
- Interpretable predictions
**AdaCare** (`AdaCare`)
- Adaptive care model with feature calibration
- Disease-specific attention
- Handles irregular time intervals
- Interpretable feature importance

**ConCare** (`ConCare`)
- Cross-visit convolutional attention
- Temporal convolutional feature extraction
- Multi-level attention mechanism
- Good for longitudinal EHR modeling
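Both models live in `pyhealth.models`. As a hedged sketch (constructor arguments beyond the common `dataset`/`feature_keys`/`mode` trio are assumptions), AdaCare can be instantiated like the other interpretable models:

```python
from pyhealth.models import AdaCare

# Hedged sketch: hidden_dim is an assumed hyperparameter name;
# check your PyHealth version for the exact signature.
model = AdaCare(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,  # assumed
)
```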
### Medication Recommendation Models

**GAMENet** (`GAMENet`)
- Graph-based medication recommendation
- Drug-drug interaction modeling
- Memory network for patient history
- Multi-hop reasoning

**Architecture:**
- Drug knowledge graph
- Memory-augmented neural network
- DDI-aware prediction

**Usage:**
```python
from pyhealth.models import GAMENet

model = GAMENet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,
    ddi_adj_path="/path/to/ddi_adjacency_matrix.pkl"
)
```
**MICRON** (`MICRON`)
- Medication recommendation with DDI constraints
- Interaction-aware predictions
- Safety-focused drug selection

**SafeDrug** (`SafeDrug`)
- Safety-aware drug recommendation
- Molecular structure integration
- DDI constraint optimization
- Balances efficacy and safety

**Key Features:**
- Molecular graph encoding
- DDI graph neural network
- Controllable DDI loss for safe recommendations
- Published in IJCAI 2021
**Usage:**
```python
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    ddi_adj_path="/path/to/ddi_matrix.pkl",
    molecule_path="/path/to/molecule_graphs.pkl"
)
```
**MoleRec** (`MoleRec`)
- Molecular-level drug recommendations
- Sub-structure reasoning
- Fine-grained medication selection
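No snippet is given for `MICRON` or `MoleRec`; below is a hedged sketch for `MoleRec`, assuming it follows the same multilabel constructor pattern as GAMENet and SafeDrug above.

```python
from pyhealth.models import MoleRec

# Hedged sketch: arguments beyond dataset/feature_keys/mode are assumed.
model = MoleRec(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,  # assumed
)
```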
### Disease Progression Models

**StageNet** (`StageNet`)
- Disease stage-aware prediction
- Learns clinical stages automatically
- Stage-adaptive feature extraction
- Effective for chronic disease monitoring

**Architecture:**
- Stage-aware LSTM
- Dynamic stage transitions
- Time-decay mechanism

**Usage:**
```python
from pyhealth.models import StageNet

model = StageNet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_stages=3,
    chunk_size=128
)
```
**Best for:**
- ICU mortality prediction
- Chronic disease progression
- Time-varying risk assessment
**Deepr** (`Deepr`)
- Convolutional architecture over medical records
- Medical concept embeddings
- Temporal pattern learning
- Published in IEEE JBHI
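A hedged usage sketch for `Deepr`, following the common constructor pattern (the `hidden_dim` argument is an assumption):

```python
from pyhealth.models import Deepr

# Hedged sketch: hidden_dim is an assumed hyperparameter name.
model = Deepr(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,  # assumed
)
```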
### Advanced Sequential Models

**Agent** (`Agent`)
- Reinforcement learning-based
- Treatment recommendation
- Action-value optimization
- Policy learning for sequential decisions

**GRASP** (`GRASP`)
- Graph-based sequence patterns
- Structural event relationships
- Hierarchical representation learning

**SparcNet** (`SparcNet`)
- Sparse clinical networks
- Efficient feature selection
- Reduced computational cost
- Interpretable predictions

**ContraWR** (`ContraWR`)
- Contrastive learning approach
- Self-supervised pre-training
- Robust representations
- Suited to limited-label scenarios
### Medical Entity Linking

**MedLink** (`MedLink`)
- Medical entity linking to knowledge bases
- Clinical concept normalization
- UMLS integration
- Entity disambiguation
### Generative Models

**GAN** (`GAN`)
- Generative Adversarial Networks
- Synthetic EHR data generation
- Privacy-preserving data sharing
- Augmentation for rare conditions

**VAE** (`VAE`)
- Variational Autoencoder
- Patient representation learning
- Anomaly detection
- Latent space exploration
### Social Determinants of Health

**SDOH** (`SDOH`)
- Social determinants integration
- Multi-modal prediction
- Addresses health disparities
- Combines clinical and social data
## Model Selection Guidelines

### By Task Type

**Binary Classification** (Mortality, Readmission)
- Start with: Logistic Regression (baseline)
- Standard: RNN, Transformer
- Interpretable: RETAIN, AdaCare
- Advanced: StageNet

**Multi-Label Classification** (Drug Recommendation)
- Standard: CNN, RNN
- Healthcare-specific: GAMENet, SafeDrug, MICRON, MoleRec
- Graph-based: GNN

**Regression** (Length of Stay)
- Start with: MLP (baseline)
- Sequential: RNN, TCN
- Advanced: Transformer

**Multi-Class Classification** (Medical Coding, Specialty)
- Standard: CNN, RNN, Transformer
- Text-based: TransformersModel (BERT variants)
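For experimentation, this guidance can be encoded as a simple lookup. `suggest_models` below is a hypothetical helper that merely mirrors the lists above; it is not part of PyHealth.

```python
# Hypothetical helper mirroring the guidance above; not a PyHealth API.
MODEL_SUGGESTIONS = {
    "binary": ["LogisticRegression", "RNN", "Transformer", "RETAIN", "AdaCare", "StageNet"],
    "multilabel": ["CNN", "RNN", "GAMENet", "SafeDrug", "MICRON", "MoleRec", "GNN"],
    "regression": ["MLP", "RNN", "TCN", "Transformer"],
    "multiclass": ["CNN", "RNN", "Transformer", "TransformersModel"],
}

def suggest_models(mode):
    """Return candidate model names for a given task mode."""
    return MODEL_SUGGESTIONS.get(mode, [])

print(suggest_models("multilabel"))
```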
### By Data Type

**Sequential Events** (Diagnoses, Medications, Procedures)
- RNN, LSTM, GRU
- Transformer
- RETAIN, AdaCare, ConCare

**Time-Series Signals** (EEG, ECG)
- CNN, TCN
- RNN
- Transformer

**Text** (Clinical Notes)
- TransformersModel (ClinicalBERT, BioBERT)
- CNN for shorter text
- RNN for sequential text

**Graphs** (Drug Interactions, Patient Networks)
- GNN (GAT, GCN)
- GAMENet, SafeDrug

**Images** (X-rays, CT scans)
- CNN backbones (ResNet, DenseNet via TorchvisionModel)
- Vision Transformers
### By Interpretability Needs

**High Interpretability Required:**
- Logistic Regression
- RETAIN
- AdaCare
- SparcNet

**Moderate Interpretability:**
- CNN (filter visualization)
- Transformer (attention visualization)
- GNN (graph attention)

**Black-Box Acceptable:**
- Deep RNN models
- Complex ensembles
## Training Considerations

### Hyperparameter Tuning

**Embedding Dimension:**
- Small datasets: 64-128
- Large datasets: 128-256
- Complex tasks: 256-512

**Hidden Dimension:**
- Proportional to embedding_dim
- Typically 1-2x embedding_dim

**Number of Layers:**
- Start with 2-3 layers
- Go deeper for complex patterns
- Watch for overfitting

**Dropout:**
- Start with 0.5
- Reduce if underfitting (0.1-0.3)
- Increase if overfitting (0.5-0.7)
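A minimal grid sweep over these knobs, using the `MLP` constructor shown earlier and the `Trainer` from the workflow below; the ranges are the rules of thumb above, not tuned values, and `train_loader`/`val_loader` are assumed to be built as in the split sketch under Best Practices.

```python
from itertools import product

from pyhealth.models import MLP
from pyhealth.trainer import Trainer

# Rules of thumb from this section: embeddings 64-256, dropout near 0.5.
grid = product([64, 128], [128, 256], [0.3, 0.5])

for embedding_dim, hidden_dim, dropout in grid:
    model = MLP(
        dataset=sample_dataset,
        feature_keys=["diagnoses", "medications"],
        mode="binary",
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
    )
    trainer = Trainer(model=model)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=10,
        monitor="pr_auc_score",
    )
```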
### Computational Requirements

**Memory (GPU):**
- CNN: Low to moderate
- RNN: Moderate (sequence length dependent)
- Transformer: High (quadratic in sequence length)
- GNN: Moderate to high (graph size dependent)

**Training Speed:**
- Fastest: Logistic Regression, MLP, CNN
- Moderate: RNN, GNN
- Slower: Transformer (but parallelizable)
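To make the "quadratic in sequence length" point concrete, here is a back-of-the-envelope calculation (plain Python, no PyHealth dependency) for the attention score tensor alone; activations, parameters, and optimizer state add more on top.

```python
def attention_score_memory_mb(batch, heads, seq_len, bytes_per_float=4):
    """Memory for the (batch, heads, seq_len, seq_len) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_float / 1024**2

# Doubling sequence length quadruples attention memory:
# 128 -> 16 MB, 256 -> 64 MB, 512 -> 256 MB, 1024 -> 1024 MB (batch=32, heads=8).
for L in (128, 256, 512, 1024):
    print(L, f"{attention_score_memory_mb(batch=32, heads=8, seq_len=L):.1f} MB")
```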
### Best Practices

1. **Start with simple baselines** (Logistic Regression, MLP)
2. **Use appropriate feature keys** based on data availability
3. **Match mode to task output** (binary, multiclass, multilabel, regression)
4. **Consider interpretability requirements** for clinical deployment
5. **Validate on a held-out test set** for realistic performance (see the patient-level split sketch after this list)
6. **Monitor for overfitting**, especially with complex models
7. **Use pretrained models** when possible (TransformersModel)
8. **Consider computational constraints** for deployment
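For point 5, PyHealth's `split_by_patient` keeps all visits of a patient in a single split, which avoids leakage across train and test; a minimal sketch:

```python
from pyhealth.datasets import split_by_patient, get_dataloader

# Patient-level split avoids leakage across visits of the same patient.
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])

train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
```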
## Example Workflow

```python
from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Prepare data (tables needed for the feature keys below)
dataset = MIMIC4Dataset(
    root="/path/to/data",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Split by patient and build dataloaders
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

# 3. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications", "procedures"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 4. Train model
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="pr_auc_score",
    monitor_criterion="max"
)

# 5. Evaluate
results = trainer.evaluate(test_loader)
print(results)
```