PyHealth Models
Overview
PyHealth provides 33+ models for healthcare prediction tasks, ranging from simple baselines to state-of-the-art deep learning architectures. Models are organized into general-purpose architectures and healthcare-specific models.
Model Base Class
All models inherit from BaseModel, a subclass of PyTorch's nn.Module, and share the following interface:
Key Attributes:
- dataset: Associated SampleDataset
- feature_keys: Input features to use (e.g., ["diagnoses", "medications"])
- mode: Task type ("binary", "multiclass", "multilabel", "regression")
- embedding_dim: Feature embedding dimension
- device: Computation device (CPU/GPU)
Key Methods:
- forward(): Model forward pass
- train_step(): Single training iteration
- eval_step(): Single evaluation iteration
- save(): Save model checkpoint
- load(): Load model checkpoint
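A minimal sketch of this shared interface, assuming a model and PyHealth dataloader built as in the examples below (the output keys follow PyHealth's convention of returning a dict of loss and predictions; verify the exact keys for your version):
# Sketch of the shared BaseModel interface
batch = next(iter(train_loader))   # train_loader: a PyHealth dataloader (see workflow below)
outputs = model(**batch)           # forward pass returns a dict
loss = outputs["loss"]             # training objective for this batch
y_prob = outputs["y_prob"]         # predicted probabilities
y_true = outputs["y_true"]         # ground-truth labels
model.save("checkpoint.pt")        # persist weights
model.load("checkpoint.pt")        # restore weights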
General-Purpose Models
Baseline Models
Logistic Regression (LogisticRegression)
- Linear classifier with mean pooling
- Simple baseline for comparison
- Fast training and inference
- Good for interpretability
Usage:
from pyhealth.models import LogisticRegression
model = LogisticRegression(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary"
)
Multi-Layer Perceptron (MLP)
- Feedforward neural network
- Configurable hidden layers
- Supports mean/sum/max pooling
- Good baseline for structured data
Parameters:
- hidden_dim: Hidden layer size
- num_layers: Number of hidden layers
- dropout: Dropout rate
- pooling: Aggregation method ("mean", "sum", "max")
Usage:
from pyhealth.models import MLP
model = MLP(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
hidden_dim=128,
num_layers=3,
dropout=0.5
)
Convolutional Neural Networks
CNN (CNN)
- Convolutional layers for pattern detection
- Effective for sequential and spatial data
- Captures local temporal patterns
- Parameter efficient
Architecture:
- Multiple 1D convolutional layers
- Max pooling for dimension reduction
- Fully connected output layers
Parameters:
- num_filters: Number of convolutional filters
- kernel_size: Convolution kernel size
- num_layers: Number of conv layers
- dropout: Dropout rate
Usage:
from pyhealth.models import CNN
model = CNN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
num_filters=64,
kernel_size=3,
num_layers=3
)
Temporal Convolutional Networks (TCN)
- Dilated convolutions for long-range dependencies
- Causal convolutions (no future information leakage)
- Efficient for long sequences
- Good for time-series prediction
Advantages:
- Captures long-term dependencies
- Parallelizable (faster than RNNs)
- Stable gradients
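Usage (a minimal sketch mirroring the constructor pattern of the other models on this page; the TCN-specific keyword num_channels is an assumption to verify against the pyhealth.models.TCN signature):
from pyhealth.models import TCN
model = TCN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
num_channels=128  # assumed TCN-specific parameter; check the API
)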
Recurrent Neural Networks
RNN (RNN)
- Basic recurrent architecture
- Supports LSTM, GRU, RNN variants
- Sequential processing
- Captures temporal dependencies
Parameters:
- rnn_type: "LSTM", "GRU", or "RNN"
- hidden_dim: Hidden state dimension
- num_layers: Number of recurrent layers
- dropout: Dropout rate
- bidirectional: Use bidirectional RNN
Usage:
from pyhealth.models import RNN
model = RNN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
rnn_type="LSTM",
hidden_dim=128,
num_layers=2,
bidirectional=True
)
Best for:
- Sequential clinical events
- Temporal pattern learning
- Variable-length sequences
Transformer Models
Transformer (Transformer)
- Self-attention mechanism
- Parallel processing of sequences
- State-of-the-art performance
- Effective for long-range dependencies
Architecture:
- Multi-head self-attention
- Position embeddings
- Feed-forward networks
- Layer normalization
Parameters:
- num_heads: Number of attention heads
- num_layers: Number of transformer layers
- hidden_dim: Hidden dimension
- dropout: Dropout rate
- max_seq_length: Maximum sequence length
Usage:
from pyhealth.models import Transformer
model = Transformer(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
num_heads=8,
num_layers=6,
hidden_dim=256,
dropout=0.1
)
TransformersModel (TransformersModel)
- Integration with HuggingFace transformers
- Pre-trained language models for clinical text
- Fine-tuning for healthcare tasks
- Examples: BERT, RoBERTa, BioClinicalBERT
Usage:
from pyhealth.models import TransformersModel
model = TransformersModel(
dataset=sample_dataset,
feature_keys=["text"],
mode="multiclass",
pretrained_model="emilyalsentzer/Bio_ClinicalBERT"
)
Graph Neural Networks
GNN (GNN)
- Graph-based learning
- Models relationships between entities
- Supports GAT (Graph Attention) and GCN (Graph Convolutional)
Use Cases:
- Drug-drug interactions
- Patient similarity networks
- Knowledge graph integration
- Comorbidity relationships
Parameters:
- gnn_type: "GAT" or "GCN"
- hidden_dim: Hidden dimension
- num_layers: Number of GNN layers
- dropout: Dropout rate
- num_heads: Attention heads (for GAT)
Usage:
from pyhealth.models import GNN
model = GNN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="multilabel",
gnn_type="GAT",
hidden_dim=128,
num_layers=3,
num_heads=4
)
Healthcare-Specific Models
Interpretable Clinical Models
RETAIN (RETAIN)
- Reverse time attention mechanism
- Highly interpretable predictions
- Visit-level and event-level attention
- Identifies influential clinical events
Key Features:
- Two-level attention (visits and features)
- Temporal decay modeling
- Clinically meaningful explanations
- Published in NeurIPS 2016
Usage:
from pyhealth.models import RETAIN
model = RETAIN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
hidden_dim=128
)
# Get attention weights for interpretation (batch comes from a PyHealth dataloader)
outputs = model(**batch)
visit_attention = outputs["visit_attention"]
feature_attention = outputs["feature_attention"]
Best for:
- Mortality prediction
- Readmission prediction
- Clinical risk scoring
- Interpretable predictions
AdaCare (AdaCare)
- Scale-adaptive feature extraction and recalibration
- Multi-scale dilated convolutions over visit sequences
- Interpretable feature importance
- Published in AAAI 2020
ConCare (ConCare)
- Encodes each feature's sequence separately, then applies self-attention across features
- Time-aware attention for irregular visit intervals
- Captures cross-feature dependencies in longitudinal EHR data
- Published in AAAI 2020
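Usage for both follows the same constructor pattern as RETAIN; a minimal sketch (hidden_dim is an assumed keyword to verify against each model's signature):
from pyhealth.models import AdaCare, ConCare
model = AdaCare(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
hidden_dim=128  # assumed keyword, mirroring RETAIN above
)
# ConCare is constructed the same way:
# model = ConCare(dataset=sample_dataset, feature_keys=["diagnoses", "medications"], mode="binary")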
Medication Recommendation Models
GAMENet (GAMENet)
- Graph-based medication recommendation
- Drug-drug interaction modeling
- Memory network for patient history
- Multi-hop reasoning
Architecture:
- Drug knowledge graph
- Memory-augmented neural network
- DDI-aware prediction
Usage:
from pyhealth.models import GAMENet
model = GAMENet(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="multilabel",
embedding_dim=128,
ddi_adj_path="/path/to/ddi_adjacency_matrix.pkl"
)
MICRON (MICRON)
- Recurrent residual network for medication change prediction
- Updates the previous visit's medication set instead of predicting from scratch
- Efficient for sequential prescribing
- Published in IJCAI 2021
SafeDrug (SafeDrug)
- Safety-aware drug recommendation
- Molecular structure integration
- DDI constraint optimization
- Balances efficacy and safety
Key Features:
- Molecular graph encoding
- DDI graph neural network
- Controllable DDI loss (adjusted by a proportional-integral controller)
- Published in IJCAI 2021
Usage:
from pyhealth.models import SafeDrug
model = SafeDrug(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="multilabel",
ddi_adj_path="/path/to/ddi_matrix.pkl",
molecule_path="/path/to/molecule_graphs.pkl"
)
MoleRec (MoleRec)
- Molecular-level drug recommendations
- Sub-structure reasoning
- Fine-grained medication selection
Disease Progression Models
StageNet (StageNet)
- Disease stage-aware prediction
- Learns clinical stages automatically
- Stage-adaptive feature extraction
- Effective for chronic disease monitoring
Architecture:
- Stage-aware LSTM
- Dynamic stage transitions
- Time-decay mechanism
Usage:
from pyhealth.models import StageNet
model = StageNet(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
hidden_dim=128,
num_stages=3,
chunk_size=128
)
Best for:
- ICU mortality prediction
- Chronic disease progression
- Time-varying risk assessment
Deepr (Deepr)
- Convolutional architecture over medical records
- Medical concept embeddings
- Detects local temporal motifs in visit sequences
- Published in IEEE JBHI
Advanced Sequential Models
Agent (Agent)
- Reinforcement learning-based
- Treatment recommendation
- Action-value optimization
- Policy learning for sequential decisions
GRASP (GRASP)
- Incorporates knowledge from similar patients
- Patient clustering combined with graph neural networks
- Strengthens predictions for patients with sparse histories
- Published in AAAI 2021
SparcNet (SparcNet)
- Dense convolutional network for EEG analysis
- Classifies seizures and rhythmic/periodic EEG patterns
- Reaches expert-level performance on EEG interpretation
ContraWR (ContraWR)
- Contrastive learning for biosignal representations (e.g., EEG)
- Self-supervised pre-training
- Robust representations when labeled data are limited
Medical Entity Linking
MedLink (MedLink)
- Medical entity linking to knowledge bases
- Clinical concept normalization
- UMLS integration
- Entity disambiguation
Generative Models
GAN (GAN)
- Generative Adversarial Networks
- Synthetic EHR data generation
- Privacy-preserving data sharing
- Augmentation for rare conditions
VAE (VAE)
- Variational Autoencoder
- Patient representation learning
- Anomaly detection
- Latent space exploration
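A minimal sketch, assuming the generative models follow the shared BaseModel constructor (model-specific keywords such as latent dimensions are omitted; check the pyhealth.models.VAE signature):
from pyhealth.models import VAE
model = VAE(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",        # placeholder; reconstruction objectives may not use it
embedding_dim=128
)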
Social Determinants of Health
SDOH (SDOH)
- Social determinants integration
- Multi-modal prediction
- Addresses health disparities
- Combines clinical and social data
Model Selection Guidelines
By Task Type
Binary Classification (Mortality, Readmission)
- Start with: Logistic Regression (baseline)
- Standard: RNN, Transformer
- Interpretable: RETAIN, AdaCare
- Advanced: StageNet
Multi-Label Classification (Drug Recommendation)
- Standard: CNN, RNN
- Healthcare-specific: GAMENet, SafeDrug, MICRON, MoleRec
- Graph-based: GNN
Regression (Length of Stay)
- Start with: MLP (baseline)
- Sequential: RNN, TCN
- Advanced: Transformer
Multi-Class Classification (Medical Coding, Specialty)
- Standard: CNN, RNN, Transformer
- Text-based: TransformersModel (BERT variants)
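These recommendations can be encoded as a simple lookup for experiment scaffolding; a sketch in plain Python (only models that need no extra resources are included, since e.g. GAMENet and SafeDrug also require DDI files):
from pyhealth.models import CNN, LogisticRegression, MLP, RNN, Transformer, RETAIN

# Candidate models per task mode, baseline first (restates the guidance above)
CANDIDATES = {
"binary": [LogisticRegression, RNN, Transformer, RETAIN],
"multilabel": [CNN, RNN],
"regression": [MLP, RNN, Transformer],
"multiclass": [CNN, RNN, Transformer],
}

def build_candidates(sample_dataset, feature_keys, mode):
"""Instantiate every recommended model for a given task mode."""
return [cls(dataset=sample_dataset, feature_keys=feature_keys, mode=mode)
for cls in CANDIDATES[mode]]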
By Data Type
Sequential Events (Diagnoses, Medications, Procedures)
- RNN, LSTM, GRU
- Transformer
- RETAIN, AdaCare, ConCare
Time-Series Signals (EEG, ECG)
- CNN, TCN
- RNN
- Transformer
Text (Clinical Notes)
- TransformersModel (ClinicalBERT, BioBERT)
- CNN for shorter text
- RNN for sequential text
Graphs (Drug Interactions, Patient Networks)
- GNN (GAT, GCN)
- GAMENet, SafeDrug
Images (X-rays, CT scans)
- CNN backbones (ResNet, DenseNet via TorchvisionModel)
- Vision Transformers
By Interpretability Needs
High Interpretability Required:
- Logistic Regression
- RETAIN
- AdaCare
Moderate Interpretability:
- CNN (filter visualization)
- Transformer (attention visualization)
- GNN (graph attention)
Black-Box Acceptable:
- Deep RNN models
- Complex ensembles
Training Considerations
Hyperparameter Tuning
Embedding Dimension:
- Small datasets: 64-128
- Large datasets: 128-256
- Complex tasks: 256-512
Hidden Dimension:
- Proportional to embedding_dim
- Typically 1-2x embedding_dim
Number of Layers:
- Start with 2-3 layers
- Deeper for complex patterns
- Watch for overfitting
Dropout:
- Start with 0.5
- Reduce if underfitting (0.1-0.3)
- Increase if overfitting (0.5-0.7)
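A sketch of a small grid search over these ranges (plain Python; sample_dataset, train_loader, and val_loader come from the workflow at the end of this page, and the metric key returned by evaluate() is an assumption to adapt to your task):
from itertools import product
from pyhealth.models import RNN
from pyhealth.trainer import Trainer

# Small grid over the ranges suggested above
grid = {"embedding_dim": [64, 128], "hidden_dim": [128, 256], "dropout": [0.3, 0.5]}

best_score, best_config = float("-inf"), None
for emb, hid, drop in product(*grid.values()):
model = RNN(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications"],
mode="binary",
embedding_dim=emb,
hidden_dim=hid,
dropout=drop,
)
trainer = Trainer(model=model)
trainer.train(train_dataloader=train_loader, val_dataloader=val_loader,
epochs=10, monitor="pr_auc_score", monitor_criterion="max")
score = trainer.evaluate(val_loader)["pr_auc_score"]  # assumed metric key
if score > best_score:
best_score, best_config = score, {"embedding_dim": emb, "hidden_dim": hid, "dropout": drop}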
Computational Requirements
Memory (GPU):
- CNN: Low to moderate
- RNN: Moderate (sequence length dependent)
- Transformer: High (quadratic in sequence length)
- GNN: Moderate to high (graph size dependent)
Training Speed:
- Fastest: Logistic Regression, MLP, CNN
- Moderate: RNN, GNN
- Slower: Transformer (but parallelizable)
Best Practices
- Start with simple baselines (Logistic Regression, MLP)
- Use appropriate feature keys based on data availability
- Match mode to task output (binary, multiclass, multilabel, regression)
- Consider interpretability requirements for clinical deployment
- Validate on held-out test set for realistic performance
- Monitor for overfitting especially with complex models
- Use pretrained models when possible (TransformersModel)
- Consider computational constraints for deployment
Example Workflow
from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
# 1. Prepare data: load tables, apply the task function, split, and build dataloaders
dataset = MIMIC4Dataset(
root="/path/to/data",
tables=["diagnoses_icd", "procedures_icd", "prescriptions"]
)
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
# 2. Initialize model
model = Transformer(
dataset=sample_dataset,
feature_keys=["diagnoses", "medications", "procedures"],
mode="binary",
embedding_dim=128,
num_heads=8,
num_layers=3,
dropout=0.3
)
# 3. Train model
trainer = Trainer(model=model)
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=50,
monitor="pr_auc_score",
monitor_criterion="max"
)
# 4. Evaluate
results = trainer.evaluate(test_loader)
print(results)