# PyHealth Models

## Overview

PyHealth provides 33+ models for healthcare prediction tasks, ranging from simple baselines to state-of-the-art deep learning architectures. Models are organized into general-purpose architectures and healthcare-specific models.

## Model Base Class

All models inherit from `BaseModel` with standard PyTorch functionality:

**Key Attributes:**
- `dataset`: Associated SampleDataset
- `feature_keys`: Input features to use (e.g., ["diagnoses", "medications"])
- `mode`: Task type ("binary", "multiclass", "multilabel", "regression")
- `embedding_dim`: Feature embedding dimension
- `device`: Computation device (CPU/GPU)

**Key Methods:**
- `forward()`: Model forward pass
- `train_step()`: Single training iteration
- `eval_step()`: Single evaluation iteration
- `save()`: Save model checkpoint
- `load()`: Load model checkpoint

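
As a rough illustration of how the `mode` attribute usually shapes a model's output layer, the sketch below maps each mode to an output size and an activation. The helper names `output_size` and `activate` are hypothetical and not part of PyHealth; the mapping follows common practice (sigmoid for binary and multilabel, softmax for multiclass, identity for regression) and is an assumption about, not a description of, `BaseModel` internals.

```python
import math


def output_size(mode, num_classes):
    """Number of output units conventionally used for each task mode (hypothetical helper)."""
    if mode in ("binary", "regression"):
        return 1
    if mode in ("multiclass", "multilabel"):
        return num_classes
    raise ValueError(f"unknown mode: {mode}")


def activate(mode, logits):
    """Turn raw logits into predictions using the usual activation for the mode."""
    if mode in ("binary", "multilabel"):
        # Independent sigmoid per output unit
        return [1.0 / (1.0 + math.exp(-z)) for z in logits]
    if mode == "multiclass":
        # Numerically stable softmax over all classes
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]
    if mode == "regression":
        # Regression outputs are used as-is
        return list(logits)
    raise ValueError(f"unknown mode: {mode}")


probs = activate("multiclass", [1.0, 2.0, 3.0])
```

The practical consequence is that a mismatch between `mode` and the label format (for example, multilabel targets with `mode="multiclass"`) tends to produce the wrong loss and the wrong output shape.
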
## General-Purpose Models

### Baseline Models

**Logistic Regression** (`LogisticRegression`)
- Linear classifier with mean pooling
- Simple baseline for comparison
- Fast training and inference
- Good for interpretability

**Usage:**
```python
from pyhealth.models import LogisticRegression

model = LogisticRegression(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary"
)
```

**Multi-Layer Perceptron** (`MLP`)
- Feedforward neural network
- Configurable hidden layers
- Supports mean/sum/max pooling
- Good baseline for structured data

**Parameters:**
- `hidden_dim`: Hidden layer size
- `num_layers`: Number of hidden layers
- `dropout`: Dropout rate
- `pooling`: Aggregation method ("mean", "sum", "max")

**Usage:**
```python
from pyhealth.models import MLP

model = MLP(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_layers=3,
    dropout=0.5
)
```

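
To make the mean/sum/max pooling options concrete, here is a minimal pure-Python sketch of how a set of code embeddings within one visit could be aggregated into a single vector. The function `pool` and the toy data are hypothetical; PyHealth performs the equivalent operation on tensors, and its exact implementation is not shown here.

```python
def pool(embeddings, method="mean"):
    """Aggregate a list of equal-length embedding vectors into one vector."""
    if not embeddings:
        raise ValueError("need at least one embedding")
    # Transpose so each tuple holds one embedding dimension across all codes
    columns = list(zip(*embeddings))
    if method == "sum":
        return [sum(col) for col in columns]
    if method == "mean":
        return [sum(col) / len(col) for col in columns]
    if method == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling method: {method}")


# Three codes in one visit, each with a 2-dimensional embedding (toy values)
visit_codes = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]

mean_vec = pool(visit_codes, "mean")
sum_vec = pool(visit_codes, "sum")
max_vec = pool(visit_codes, "max")
```

Mean pooling is insensitive to how many codes a visit contains, sum pooling grows with the number of codes, and max pooling keeps only the strongest activation per dimension.
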
### Convolutional Neural Networks

**CNN** (`CNN`)
- Convolutional layers for pattern detection
- Effective for sequential and spatial data
- Captures local temporal patterns
- Parameter efficient

**Architecture:**
- Multiple 1D convolutional layers
- Max pooling for dimension reduction
- Fully connected output layers

**Parameters:**
- `num_filters`: Number of convolutional filters
- `kernel_size`: Convolution kernel size
- `num_layers`: Number of conv layers
- `dropout`: Dropout rate

**Usage:**
```python
from pyhealth.models import CNN

model = CNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_filters=64,
    kernel_size=3,
    num_layers=3
)
```

**Temporal Convolutional Networks** (`TCN`)
- Dilated convolutions for long-range dependencies
- Causal convolutions (no future information leakage)
- Efficient for long sequences
- Good for time-series prediction

**Advantages:**
- Captures long-term dependencies
- Parallelizable (faster than RNNs)
- Stable gradients

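
The two ideas behind a TCN, causality and dilation, can be sketched in a few lines of plain Python. The functions below are illustrative only (hypothetical names, not PyHealth code) and assume one convolution per layer; real TCN blocks typically stack two convolutions with a residual connection, which enlarges the receptive field further.

```python
def causal_dilated_conv(x, weights, dilation=1):
    """1D causal convolution: y[t] = sum_j weights[j] * x[t - j * dilation].

    Positions before the start of the sequence are treated as zeros,
    so y[t] never depends on x[t + 1] or later.
    """
    y = []
    for t in range(len(x)):
        acc = 0.0
        for j, w in enumerate(weights):
            idx = t - j * dilation
            if idx >= 0:
                acc += w * x[idx]
        y.append(acc)
    return y


def receptive_field(kernel_size, dilations):
    """Time steps visible to one output, assuming one convolution per layer."""
    return 1 + (kernel_size - 1) * sum(dilations)


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
out = causal_dilated_conv(signal, weights=[1.0, 1.0], dilation=2)

# Doubling the dilation at each layer grows the receptive field exponentially with depth
rf = receptive_field(kernel_size=3, dilations=[1, 2, 4, 8])
```

With `kernel_size=3` and dilations 1, 2, 4, 8, four layers already cover 31 time steps, which is why dilated stacks handle long sequences with few layers.
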
### Recurrent Neural Networks

**RNN** (`RNN`)
- Basic recurrent architecture
- Supports LSTM, GRU, RNN variants
- Sequential processing
- Captures temporal dependencies

**Parameters:**
- `rnn_type`: "LSTM", "GRU", or "RNN"
- `hidden_dim`: Hidden state dimension
- `num_layers`: Number of recurrent layers
- `dropout`: Dropout rate
- `bidirectional`: Use bidirectional RNN

**Usage:**
```python
from pyhealth.models import RNN

model = RNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    rnn_type="LSTM",
    hidden_dim=128,
    num_layers=2,
    bidirectional=True
)
```

**Best for:**
- Sequential clinical events
- Temporal pattern learning
- Variable-length sequences

### Transformer Models

**Transformer** (`Transformer`)
- Self-attention mechanism
- Parallel processing of sequences
- State-of-the-art performance
- Effective for long-range dependencies

**Architecture:**
- Multi-head self-attention
- Position embeddings
- Feed-forward networks
- Layer normalization

**Parameters:**
- `num_heads`: Number of attention heads
- `num_layers`: Number of transformer layers
- `hidden_dim`: Hidden dimension
- `dropout`: Dropout rate
- `max_seq_length`: Maximum sequence length

**Usage:**
```python
from pyhealth.models import Transformer

model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    num_heads=8,
    num_layers=6,
    hidden_dim=256,
    dropout=0.1
)
```

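
In standard multi-head attention the hidden dimension is split evenly across heads, so `hidden_dim` is normally chosen as a multiple of `num_heads`. The small check below illustrates that constraint; `head_dim` is a hypothetical helper, and whether PyHealth's `Transformer` enforces the constraint in exactly this way is an assumption.

```python
def head_dim(hidden_dim, num_heads):
    """Per-head dimension in standard multi-head attention."""
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by num_heads={num_heads}"
        )
    return hidden_dim // num_heads


# The configuration used above: 256 hidden units shared by 8 heads
per_head = head_dim(256, 8)
```

Running such a check before constructing the model turns a late shape error into an early, readable one.
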
**TransformersModel** (`TransformersModel`)
- Integration with HuggingFace transformers
- Pre-trained language models for clinical text
- Fine-tuning for healthcare tasks
- Examples: BERT, RoBERTa, BioClinicalBERT

**Usage:**
```python
from pyhealth.models import TransformersModel

model = TransformersModel(
    dataset=sample_dataset,
    feature_keys=["text"],
    mode="multiclass",
    pretrained_model="emilyalsentzer/Bio_ClinicalBERT"
)
```

### Graph Neural Networks

**GNN** (`GNN`)
- Graph-based learning
- Models relationships between entities
- Supports GAT (Graph Attention) and GCN (Graph Convolutional)

**Use Cases:**
- Drug-drug interactions
- Patient similarity networks
- Knowledge graph integration
- Comorbidity relationships

**Parameters:**
- `gnn_type`: "GAT" or "GCN"
- `hidden_dim`: Hidden dimension
- `num_layers`: Number of GNN layers
- `dropout`: Dropout rate
- `num_heads`: Attention heads (for GAT)

**Usage:**
```python
from pyhealth.models import GNN

model = GNN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    gnn_type="GAT",
    hidden_dim=128,
    num_layers=3,
    num_heads=4
)
```

## Healthcare-Specific Models

### Interpretable Clinical Models

**RETAIN** (`RETAIN`)
- Reverse time attention mechanism
- Highly interpretable predictions
- Visit-level and event-level attention
- Identifies influential clinical events

**Key Features:**
- Two-level attention (visits and features)
- Temporal decay modeling
- Clinically meaningful explanations
- Published in NeurIPS 2016

**Usage:**
```python
from pyhealth.models import RETAIN

model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128
)

# Get attention weights for interpretation
outputs = model(batch)
visit_attention = outputs["visit_attention"]
feature_attention = outputs["feature_attention"]
```

**Best for:**
- Mortality prediction
- Readmission prediction
- Clinical risk scoring
- Interpretable predictions

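
The two-level attention in RETAIN combines a scalar weight per visit with a gating vector per visit, giving a context vector of the form c = sum_i alpha_i * (beta_i * v_i). The sketch below shows only that combination step in plain Python. In the actual model the visit scores and feature gates come from two RNNs run in reverse time order; here they are passed in as given inputs, and all names are hypothetical.

```python
import math


def softmax(scores):
    """Normalize raw scores into weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retain_context(visit_embeddings, visit_scores, feature_gates):
    """Combine visit-level weights (alpha) and feature-level gates (beta).

    Returns the context vector and the visit weights, which can be read
    as how much each visit contributed to the prediction.
    """
    alphas = softmax(visit_scores)
    dim = len(visit_embeddings[0])
    context = [0.0] * dim
    for alpha, gate, emb in zip(alphas, feature_gates, visit_embeddings):
        for k in range(dim):
            context[k] += alpha * gate[k] * emb[k]
    return context, alphas


# Two visits with 2-dimensional embeddings (toy values)
embeddings = [[1.0, 0.0], [0.0, 1.0]]
scores = [0.0, 0.0]               # equal visit scores give equal visit weights
gates = [[1.0, 1.0], [1.0, 1.0]]  # gates of 1 leave every feature unchanged

context, alphas = retain_context(embeddings, scores, gates)
```

Because the context vector is a weighted sum, each visit's and each feature's share of the final prediction can be traced back directly, which is the source of the model's interpretability.
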
**AdaCare** (`AdaCare`)
- Adaptive care model with feature calibration
- Disease-specific attention
- Handles irregular time intervals
- Interpretable feature importance

**ConCare** (`ConCare`)
- Cross-visit convolutional attention
- Temporal convolutional feature extraction
- Multi-level attention mechanism
- Good for longitudinal EHR modeling

### Medication Recommendation Models

**GAMENet** (`GAMENet`)
- Graph-based medication recommendation
- Drug-drug interaction modeling
- Memory network for patient history
- Multi-hop reasoning

**Architecture:**
- Drug knowledge graph
- Memory-augmented neural network
- DDI-aware prediction

**Usage:**
```python
from pyhealth.models import GAMENet

model = GAMENet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    embedding_dim=128,
    ddi_adj_path="/path/to/ddi_adjacency_matrix.pkl"
)
```

**MICRON** (`MICRON`)
- Medication recommendation with DDI constraints
- Interaction-aware predictions
- Safety-focused drug selection

**SafeDrug** (`SafeDrug`)
- Safety-aware drug recommendation
- Molecular structure integration
- DDI constraint optimization
- Balances efficacy and safety

**Key Features:**
- Molecular graph encoding
- DDI graph neural network
- DDI-controllable loss for safety
- Published in IJCAI 2021

**Usage:**
```python
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="multilabel",
    ddi_adj_path="/path/to/ddi_matrix.pkl",
    molecule_path="/path/to/molecule_graphs.pkl"
)
```

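
DDI-aware models are usually judged not only on accuracy but also on how many interacting drug pairs they recommend together. The sketch below computes that DDI rate for one recommendation from a binary adjacency matrix. The function `ddi_rate` and the toy matrix are hypothetical and only illustrate the idea; they do not reproduce PyHealth's metric code or the format of the pickled files above.

```python
from itertools import combinations


def ddi_rate(recommended, ddi_adj):
    """Fraction of drug pairs in `recommended` that are flagged as interacting.

    `recommended` holds drug indices; `ddi_adj[a][b]` is truthy when
    drugs a and b have a known interaction.
    """
    pairs = list(combinations(sorted(set(recommended)), 2))
    if not pairs:
        return 0.0  # zero or one drug: no pair can interact
    hits = sum(1 for a, b in pairs if ddi_adj[a][b] or ddi_adj[b][a])
    return hits / len(pairs)


# Toy interaction matrix for 4 drugs: 0 interacts with 1, and 2 interacts with 3
ddi_adj = [
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0],
]

risky = ddi_rate([0, 1, 2], ddi_adj)  # one interacting pair out of three
safe = ddi_rate([0, 2], ddi_adj)      # no interacting pair
```

A lower rate means a safer recommendation, so this quantity is typically reported alongside Jaccard or F1 when comparing medication recommendation models.
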
**MoleRec** (`MoleRec`)
- Molecular-level drug recommendations
- Sub-structure reasoning
- Fine-grained medication selection

### Disease Progression Models

**StageNet** (`StageNet`)
- Disease stage-aware prediction
- Learns clinical stages automatically
- Stage-adaptive feature extraction
- Effective for chronic disease monitoring

**Architecture:**
- Stage-aware LSTM
- Dynamic stage transitions
- Time-decay mechanism

**Usage:**
```python
from pyhealth.models import StageNet

model = StageNet(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    hidden_dim=128,
    num_stages=3,
    chunk_size=128
)
```

**Best for:**
- ICU mortality prediction
- Chronic disease progression
- Time-varying risk assessment

**Deepr** (`Deepr`)
- Convolutional architecture over sequences of medical codes
- Medical concept embeddings
- Temporal pattern learning
- Published in IEEE Journal of Biomedical and Health Informatics (2017)

### Advanced Sequential Models

**Agent** (`Agent`)
- Reinforcement learning-based
- Treatment recommendation
- Action-value optimization
- Policy learning for sequential decisions

**GRASP** (`GRASP`)
- Graph-based sequence patterns
- Structural event relationships
- Hierarchical representation learning

**SparcNet** (`SparcNet`)
- Sparse clinical networks
- Efficient feature selection
- Reduced computational cost
- Interpretable predictions

**ContraWR** (`ContraWR`)
- Contrastive learning approach
- Self-supervised pre-training
- Robust representations
- Limited labeled data scenarios

### Medical Entity Linking

**MedLink** (`MedLink`)
- Medical entity linking to knowledge bases
- Clinical concept normalization
- UMLS integration
- Entity disambiguation

### Generative Models

**GAN** (`GAN`)
- Generative Adversarial Networks
- Synthetic EHR data generation
- Privacy-preserving data sharing
- Augmentation for rare conditions

**VAE** (`VAE`)
- Variational Autoencoder
- Patient representation learning
- Anomaly detection
- Latent space exploration

### Social Determinants of Health

**SDOH** (`SDOH`)
- Social determinants integration
- Multi-modal prediction
- Addresses health disparities
- Combines clinical and social data

## Model Selection Guidelines

### By Task Type

**Binary Classification** (Mortality, Readmission)
- Start with: Logistic Regression (baseline)
- Standard: RNN, Transformer
- Interpretable: RETAIN, AdaCare
- Advanced: StageNet

**Multi-Label Classification** (Drug Recommendation)
- Standard: CNN, RNN
- Healthcare-specific: GAMENet, SafeDrug, MICRON, MoleRec
- Graph-based: GNN

**Regression** (Length of Stay)
- Start with: MLP (baseline)
- Sequential: RNN, TCN
- Advanced: Transformer

**Multi-Class Classification** (Medical Coding, Specialty)
- Standard: CNN, RNN, Transformer
- Text-based: TransformersModel (BERT variants)

### By Data Type

**Sequential Events** (Diagnoses, Medications, Procedures)
- RNN, LSTM, GRU
- Transformer
- RETAIN, AdaCare, ConCare

**Time-Series Signals** (EEG, ECG)
- CNN, TCN
- RNN
- Transformer

**Text** (Clinical Notes)
- TransformersModel (ClinicalBERT, BioBERT)
- CNN for shorter text
- RNN for sequential text

**Graphs** (Drug Interactions, Patient Networks)
- GNN (GAT, GCN)
- GAMENet, SafeDrug

**Images** (X-rays, CT scans)
- CNN (ResNet, DenseNet via TransformersModel)
- Vision Transformers

### By Interpretability Needs

**High Interpretability Required:**
- Logistic Regression
- RETAIN
- AdaCare
- SparcNet

**Moderate Interpretability:**
- CNN (filter visualization)
- Transformer (attention visualization)
- GNN (graph attention)

**Black-Box Acceptable:**
- Deep RNN models
- Complex ensembles

## Training Considerations

### Hyperparameter Tuning

**Embedding Dimension:**
- Small datasets: 64-128
- Large datasets: 128-256
- Complex tasks: 256-512

**Hidden Dimension:**
- Proportional to embedding_dim
- Typically 1-2x embedding_dim

**Number of Layers:**
- Start with 2-3 layers
- Deeper for complex patterns
- Watch for overfitting

**Dropout:**
- Start with 0.5
- Reduce if underfitting (0.1-0.3)
- Increase if overfitting (0.5-0.7)

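
The ranges above translate naturally into a small search grid. The sketch below, with hypothetical names and illustrative values taken from those ranges, expands a search space into concrete configurations and derives `hidden_dim` as 1x or 2x `embedding_dim`. It only builds the configurations; which keyword arguments a particular model accepts is not assumed here.

```python
from itertools import product

# Illustrative values picked from the guideline ranges (small-dataset setting)
search_space = {
    "embedding_dim": [64, 128],
    "hidden_multiplier": [1, 2],  # hidden_dim = multiplier * embedding_dim
    "num_layers": [2, 3],
    "dropout": [0.3, 0.5],
}


def expand_grid(space):
    """Return one config dict per combination of values in the search space."""
    keys = list(space)
    configs = []
    for values in product(*(space[k] for k in keys)):
        cfg = dict(zip(keys, values))
        # Replace the multiplier with the concrete hidden dimension
        cfg["hidden_dim"] = cfg["embedding_dim"] * cfg.pop("hidden_multiplier")
        configs.append(cfg)
    return configs


configs = expand_grid(search_space)
```

Each resulting dictionary can then be used to train one candidate, with the best configuration chosen on the validation split rather than the test split.
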
### Computational Requirements

**Memory (GPU):**
- CNN: Low to moderate
- RNN: Moderate (sequence length dependent)
- Transformer: High (quadratic in sequence length)
- GNN: Moderate to high (graph size dependent)

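
The quadratic cost of self-attention is easy to see with a back-of-the-envelope estimate. The hypothetical helper below counts only the attention score matrices of a single layer, assuming standard (non-memory-efficient) attention and float32 values; real memory use also includes activations, parameters, gradients, and optimizer state.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_value=4):
    """Memory for one layer's attention score matrices (float32 by default).

    Each head stores a seq_len x seq_len matrix for every sample in the batch.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


short = attention_scores_bytes(batch_size=32, num_heads=8, seq_len=512)
long = attention_scores_bytes(batch_size=32, num_heads=8, seq_len=1024)

short_mib = short / (1024 ** 2)  # convert bytes to MiB
```

Doubling the sequence length quadruples this term, which is why truncating or chunking long patient histories is often the first step when a Transformer runs out of GPU memory.
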
**Training Speed:**
- Fastest: Logistic Regression, MLP, CNN
- Moderate: RNN, GNN
- Slower: Transformer (but parallelizable)

### Best Practices

1. **Start with simple baselines** (Logistic Regression, MLP)
2. **Use appropriate feature keys** based on data availability
3. **Match mode to task output** (binary, multiclass, multilabel, regression)
4. **Consider interpretability requirements** for clinical deployment
5. **Validate on held-out test set** for realistic performance
6. **Monitor for overfitting** especially with complex models
7. **Use pretrained models** when possible (TransformersModel)
8. **Consider computational constraints** for deployment

## Example Workflow

```python
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Prepare data
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# Split by patient (80/10/10) so no patient appears in more than one split,
# then wrap each split in a dataloader
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

# 2. Initialize model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications", "procedures"],
    mode="binary",
    embedding_dim=128,
    num_heads=8,
    num_layers=3,
    dropout=0.3
)

# 3. Train model, keeping the checkpoint with the highest validation PR-AUC
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="pr_auc",
    monitor_criterion="max"
)

# 4. Evaluate on the held-out test split
results = trainer.evaluate(test_loader)
print(results)
```

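
The `monitor` and `monitor_criterion` arguments describe a simple rule: after each epoch, compare the monitored validation metric with the best value so far and keep the better checkpoint. The sketch below shows that rule on a hypothetical per-epoch history; the function name, the dictionary keys, and the numbers are made up for illustration and are not PyHealth's internal implementation.

```python
def best_epoch(history, monitor, criterion="max"):
    """Index of the epoch whose monitored metric is best.

    `history` is a list with one metrics dict per epoch. Use "max" for
    scores such as PR-AUC and "min" for losses.
    """
    if criterion not in ("max", "min"):
        raise ValueError("criterion must be 'max' or 'min'")
    pick = max if criterion == "max" else min
    return pick(range(len(history)), key=lambda i: history[i][monitor])


# Hypothetical validation metrics for three epochs
history = [
    {"metric": 0.61, "loss": 0.70},
    {"metric": 0.68, "loss": 0.62},
    {"metric": 0.66, "loss": 0.60},
]

by_metric = best_epoch(history, "metric", "max")
by_loss = best_epoch(history, "loss", "min")
```

The two criteria can disagree, as they do here, so the monitored metric should be the one that matters for the clinical task rather than whichever is most convenient.
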