# PyHealth Models
## Overview
PyHealth provides 33+ models for healthcare prediction tasks, ranging from simple baselines to state-of-the-art deep learning architectures. Models are organized into general-purpose architectures and healthcare-specific models.
## Model Base Class
All models inherit from `BaseModel`, which builds on standard PyTorch `nn.Module` functionality:
**Key Attributes:**
- `dataset`: Associated SampleDataset
- `feature_keys`: Input features to use (e.g., ["diagnoses", "medications"])
- `mode`: Task type ("binary", "multiclass", "multilabel", "regression")
- `embedding_dim`: Feature embedding dimension
- `device`: Computation device (CPU/GPU)
**Key Methods:**
- `forward()`: Model forward pass
- `train_step()`: Single training iteration
- `eval_step()`: Single evaluation iteration
- `save()`: Save model checkpoint
- `load()`: Load model checkpoint
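In practice, models follow a dictionary-in, dictionary-out convention: `forward()` consumes a batch from a PyHealth dataloader and returns a dictionary containing the loss and predictions. A minimal sketch of that contract (the `loss`/`y_prob`/`y_true` keys reflect PyHealth's usual convention; treat any variation as an assumption):
```python
# Minimal sketch of the forward contract shared by PyHealth models.
# `model` is any instantiated PyHealth model and `batch` is one batch
# from a PyHealth dataloader built for the same task.
outputs = model(batch)
loss = outputs["loss"]      # scalar training loss
y_prob = outputs["y_prob"]  # predicted probabilities
y_true = outputs["y_true"]  # ground-truth labels
loss.backward()             # standard PyTorch backpropagation
```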
## General-Purpose Models
### Baseline Models
**Logistic Regression** (`LogisticRegression`)
- Linear classifier with mean pooling
- Simple baseline for comparison
- Fast training and inference
- Good for interpretability
**Usage:**
```python
from pyhealth.models import LogisticRegression
model = LogisticRegression(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
)
```
**Multi-Layer Perceptron** (`MLP`)
- Feedforward neural network
- Configurable hidden layers
- Supports mean/sum/max pooling
- Good baseline for structured data
**Parameters:**
- `hidden_dim`: Hidden layer size
- `num_layers`: Number of hidden layers
- `dropout`: Dropout rate
- `pooling`: Aggregation method ("mean", "sum", "max")
**Usage:**
```python
from pyhealth.models import MLP
model = MLP(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_layers=3,
    dropout=0.5,
)
```
### Convolutional Neural Networks
**CNN** (`CNN`)
- Convolutional layers for pattern detection
- Effective for sequential and spatial data
- Captures local temporal patterns
- Parameter efficient
**Architecture:**
- Multiple 1D convolutional layers
- Max pooling for dimension reduction
- Fully connected output layers
**Parameters:**
- `num_filters`: Number of convolutional filters
- `kernel_size`: Convolution kernel size
- `num_layers`: Number of conv layers
- `dropout`: Dropout rate
**Usage:**
```python
from pyhealth.models import CNN
model = CNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_filters=64,
    kernel_size=3,
    num_layers=3,
)
```
**Temporal Convolutional Network** (`TCN`)
- Dilated convolutions for long-range dependencies
- Causal convolutions (no future information leakage)
- Efficient for long sequences
- Good for time-series prediction
**Advantages:**
- Captures long-term dependencies
- Parallelizable (faster than RNNs)
- Stable gradients
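A minimal usage sketch, assuming `TCN` accepts the same core constructor arguments as the other models in this section; `num_channels` and `kernel_size` are assumed parameter names, not confirmed API:
```python
from pyhealth.models import TCN

# Hedged sketch: parameter names below are assumptions based on the
# constructor pattern shared by the other PyHealth models.
model = TCN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_channels=128,  # assumed name for per-layer channel width
    kernel_size=3,     # assumed name for convolution kernel size
)
```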
### Recurrent Neural Networks
**RNN** (`RNN`)
- Basic recurrent architecture
- Supports LSTM, GRU, RNN variants
- Sequential processing
- Captures temporal dependencies
**Parameters:**
- `rnn_type`: "LSTM", "GRU", or "RNN"
- `hidden_dim`: Hidden state dimension
- `num_layers`: Number of recurrent layers
- `dropout`: Dropout rate
- `bidirectional`: Use bidirectional RNN
**Usage:**
```python
from pyhealth.models import RNN
model = RNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    rnn_type="LSTM",
    hidden_dim=128,
    num_layers=2,
    bidirectional=True,
)
```
**Best for:**
- Sequential clinical events
- Temporal pattern learning
- Variable-length sequences
### Transformer Models
**Transformer** (`Transformer`)
- Self-attention mechanism
- Parallel processing of sequences
- State-of-the-art performance
- Effective for long-range dependencies
**Architecture:**
- Multi-head self-attention
- Position embeddings
- Feed-forward networks
- Layer normalization
**Parameters:**
- `num_heads`: Number of attention heads
- `num_layers`: Number of transformer layers
- `hidden_dim`: Hidden dimension
- `dropout`: Dropout rate
- `max_seq_length`: Maximum sequence length
**Usage:**
```python
from pyhealth.models import Transformer
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_heads=8,
    num_layers=6,
    hidden_dim=256,
    dropout=0.1,
)
```
**TransformersModel** (`TransformersModel`)
- Integration with HuggingFace transformers
- Pre-trained language models for clinical text
- Fine-tuning for healthcare tasks
- Examples: BERT, RoBERTa, BioClinicalBERT
**Usage:**
```python
from pyhealth.models import TransformersModel
model = TransformersModel(
    dataset=sample_dataset,
    feature_keys=["text"],
    mode="multiclass",
    pretrained_model="emilyalsentzer/Bio_ClinicalBERT",
)
```
### Graph Neural Networks
**GNN** (`GNN`)
- Graph-based learning
- Models relationships between entities
- Supports GAT (Graph Attention Network) and GCN (Graph Convolutional Network)
**Use Cases:**
- Drug-drug interactions
- Patient similarity networks
- Knowledge graph integration
- Comorbidity relationships
**Parameters:**
- `gnn_type`: "GAT" or "GCN"
- `hidden_dim`: Hidden dimension
- `num_layers`: Number of GNN layers
- `dropout`: Dropout rate
- `num_heads`: Attention heads (for GAT)
**Usage:**
```python
from pyhealth.models import GNN
model = GNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    gnn_type="GAT",
    hidden_dim=128,
    num_layers=3,
    num_heads=4,
)
```
## Healthcare-Specific Models
### Interpretable Clinical Models
**RETAIN** (`RETAIN`)
- Reverse time attention mechanism
- Highly interpretable predictions
- Visit-level and event-level attention
- Identifies influential clinical events
**Key Features:**
- Two-level attention (visits and features)
- Temporal decay modeling
- Clinically meaningful explanations
- Published in NeurIPS 2016
**Usage:**
```python
from pyhealth.models import RETAIN
model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
)
# Get attention weights for interpretation
outputs = model(batch)
visit_attention = outputs["visit_attention"]
feature_attention = outputs["feature_attention"]
```
**Best for:**
- Mortality prediction
- Readmission prediction
- Clinical risk scoring
- Interpretable predictions
**AdaCare** (`AdaCare`)
- Adaptive care model with feature calibration
- Disease-specific attention
- Handles irregular time intervals
- Interpretable feature importance
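A minimal usage sketch, assuming `AdaCare` follows the shared constructor pattern above; `hidden_dim` is an assumed parameter name:
```python
from pyhealth.models import AdaCare

# Hedged sketch: AdaCare is assumed to take the same core constructor
# arguments as the other PyHealth models; hidden_dim is an assumption.
model = AdaCare(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
)
```
`ConCare` below is expected to be instantiated the same way.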
**ConCare** (`ConCare`)
- Cross-visit convolutional attention
- Temporal convolutional feature extraction
- Multi-level attention mechanism
- Good for longitudinal EHR modeling
### Medication Recommendation Models
**GAMENet** (`GAMENet`)
- Graph-based medication recommendation
- Drug-drug interaction modeling
- Memory network for patient history
- Multi-hop reasoning
**Architecture:**
- Drug knowledge graph
- Memory-augmented neural network
- DDI-aware prediction
**Usage:**
```python
from pyhealth.models import GAMENet
model = GAMENet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,
    ddi_adj_path="/path/to/ddi_adjacency_matrix.pkl",
)
```
**MICRON** (`MICRON`)
- Medication recommendation via medication-change modeling
- Recurrent residual updates to the previous visit's prescription
- Safety-focused drug selection
- Published in IJCAI 2021
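A minimal usage sketch, assuming `MICRON` follows the shared constructor pattern for medication recommendation; `embedding_dim` is an assumed parameter name:
```python
from pyhealth.models import MICRON

# Hedged sketch: MICRON is assumed to take the same core constructor
# arguments as GAMENet above; embedding_dim is an assumption.
model = MICRON(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,
)
```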
**SafeDrug** (`SafeDrug`)
- Safety-aware drug recommendation
- Molecular structure integration
- DDI constraint optimization
- Balances efficacy and safety
**Key Features:**
- Molecular graph encoding
- DDI graph neural network
- Controllable DDI loss for safety
- Published in IJCAI 2021
**Usage:**
```python
from pyhealth.models import SafeDrug
model = SafeDrug(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    ddi_adj_path="/path/to/ddi_matrix.pkl",
    molecule_path="/path/to/molecule_graphs.pkl",
)
```
**MoleRec** (`MoleRec`)
- Molecular-level drug recommendations
- Sub-structure reasoning
- Fine-grained medication selection
### Disease Progression Models
**StageNet** (`StageNet`)
- Disease stage-aware prediction
- Learns clinical stages automatically
- Stage-adaptive feature extraction
- Effective for chronic disease monitoring
**Architecture:**
- Stage-aware LSTM
- Dynamic stage transitions
- Time-decay mechanism
**Usage:**
```python
from pyhealth.models import StageNet
model = StageNet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_stages=3,
    chunk_size=128,
)
```
**Best for:**
- ICU mortality prediction
- Chronic disease progression
- Time-varying risk assessment
**Deepr** (`Deepr`)
- Convolutional architecture over medical records
- Medical concept embeddings
- Learns local temporal patterns (motifs)
- Published in IEEE JBHI
### Advanced Sequential Models
**Agent** (`Agent`)
- Reinforcement learning-based
- Treatment recommendation
- Action-value optimization
- Policy learning for sequential decisions
**GRASP** (`GRASP`)
- Graph-based sequence patterns
- Structural event relationships
- Hierarchical representation learning
**SparcNet** (`SparcNet`)
- Sparse clinical networks
- Efficient feature selection
- Reduced computational cost
- Interpretable predictions
**ContraWR** (`ContraWR`)
- Contrastive learning approach
- Self-supervised pre-training
- Robust representations
- Limited labeled data scenarios
### Medical Entity Linking
**MedLink** (`MedLink`)
- Medical entity linking to knowledge bases
- Clinical concept normalization
- UMLS integration
- Entity disambiguation
### Generative Models
**GAN** (`GAN`)
- Generative Adversarial Networks
- Synthetic EHR data generation
- Privacy-preserving data sharing
- Augmentation for rare conditions
**VAE** (`VAE`)
- Variational Autoencoder
- Patient representation learning
- Anomaly detection
- Latent space exploration
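A minimal usage sketch, assuming `VAE` takes the same core constructor arguments as the other models; `hidden_dim` and the choice of `mode` here are assumptions:
```python
from pyhealth.models import VAE

# Hedged sketch: VAE is assumed to follow the shared constructor
# pattern; hidden_dim and the mode value are assumptions.
model = VAE(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="regression",  # assumption: reconstruction treated as regression
    hidden_dim=128,
)
# The learned latent space can then be used for patient representation
# or anomaly detection (e.g., flagging samples with high reconstruction error).
```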
### Social Determinants of Health
**SDOH** (`SDOH`)
- Social determinants integration
- Multi-modal prediction
- Addresses health disparities
- Combines clinical and social data
## Model Selection Guidelines
### By Task Type
**Binary Classification** (Mortality, Readmission)
- Start with: Logistic Regression (baseline)
- Standard: RNN, Transformer
- Interpretable: RETAIN, AdaCare
- Advanced: StageNet
**Multi-Label Classification** (Drug Recommendation)
- Standard: CNN, RNN
- Healthcare-specific: GAMENet, SafeDrug, MICRON, MoleRec
- Graph-based: GNN
**Regression** (Length of Stay)
- Start with: MLP (baseline)
- Sequential: RNN, TCN
- Advanced: Transformer
**Multi-Class Classification** (Medical Coding, Specialty)
- Standard: CNN, RNN, Transformer
- Text-based: TransformersModel (BERT variants)
### By Data Type
**Sequential Events** (Diagnoses, Medications, Procedures)
- RNN, LSTM, GRU
- Transformer
- RETAIN, AdaCare, ConCare
**Time-Series Signals** (EEG, ECG)
- CNN, TCN
- RNN
- Transformer
**Text** (Clinical Notes)
- TransformersModel (ClinicalBERT, BioBERT)
- CNN for shorter text
- RNN for sequential text
**Graphs** (Drug Interactions, Patient Networks)
- GNN (GAT, GCN)
- GAMENet, SafeDrug
**Images** (X-rays, CT scans)
- CNN architectures (e.g., ResNet, DenseNet)
- Vision Transformers
### By Interpretability Needs
**High Interpretability Required:**
- Logistic Regression
- RETAIN
- AdaCare
- SparcNet
**Moderate Interpretability:**
- CNN (filter visualization)
- Transformer (attention visualization)
- GNN (graph attention)
**Black-Box Acceptable:**
- Deep RNN models
- Complex ensembles
## Training Considerations
### Hyperparameter Tuning
**Embedding Dimension:**
- Small datasets: 64-128
- Large datasets: 128-256
- Complex tasks: 256-512
**Hidden Dimension:**
- Proportional to embedding_dim
- Typically 1-2x embedding_dim
**Number of Layers:**
- Start with 2-3 layers
- Deeper for complex patterns
- Watch for overfitting
**Dropout:**
- Start with 0.5
- Reduce if underfitting (0.1-0.3)
- Increase if overfitting (0.5-0.7)
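These ranges are a starting point; a simple validation sweep over them is often enough. A minimal sketch using the `Trainer` (assumes `sample_dataset`, `train_loader`, and `val_loader` already exist; the `pr_auc` metric name is an assumption about PyHealth's binary metrics):
```python
from pyhealth.models import MLP
from pyhealth.trainer import Trainer

# Minimal grid-search sketch over the ranges suggested above.
best_score, best_cfg = -1.0, None
for hidden_dim in (128, 256):
    for dropout in (0.3, 0.5):
        model = MLP(
            dataset=sample_dataset,
            feature_keys=["diagnoses", "medications"],
            mode="binary",
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
        trainer = Trainer(model=model)
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=10,
            monitor="pr_auc",  # assumed metric name
        )
        score = trainer.evaluate(val_loader)["pr_auc"]  # assumed key
        if score > best_score:
            best_score, best_cfg = score, (hidden_dim, dropout)
print(f"best (hidden_dim, dropout): {best_cfg}, PR-AUC: {best_score:.3f}")
```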
### Computational Requirements
**Memory (GPU):**
- CNN: Low to moderate
- RNN: Moderate (sequence length dependent)
- Transformer: High (quadratic in sequence length)
- GNN: Moderate to high (graph size dependent)
**Training Speed:**
- Fastest: Logistic Regression, MLP, CNN
- Moderate: RNN, GNN
- Slower: Transformer (but parallelizable)
### Best Practices
1. **Start with simple baselines** (Logistic Regression, MLP)
2. **Use appropriate feature keys** based on data availability
3. **Match mode to task output** (binary, multiclass, multilabel, regression)
4. **Consider interpretability requirements** for clinical deployment
5. **Validate on held-out test set** for realistic performance
6. **Monitor for overfitting** especially with complex models
7. **Use pretrained models** when possible (TransformersModel)
8. **Consider computational constraints** for deployment
## Example Workflow
```python
from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
# 1. Prepare data (table names follow MIMIC-IV conventions)
dataset = MIMIC4Dataset(
    root="/path/to/data",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
# 2. Split by patient and build dataloaders
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=64, shuffle=False)
# 3. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications", "procedures"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3,
)
# 4. Train model
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="pr_auc",
    monitor_criterion="max",
)
# 5. Evaluate on the held-out test set
results = trainer.evaluate(test_loader)
print(results)
```