Machine Learning
Overview
PathML provides comprehensive machine learning capabilities for computational pathology, including pre-built models for nucleus detection and segmentation, PyTorch-integrated training workflows, public dataset access, and ONNX-based inference deployment. The framework bridges image preprocessing and deep learning, enabling end-to-end pathology ML pipelines.
Pre-Built Models
PathML includes state-of-the-art pre-trained models for nucleus analysis:
HoVer-Net
HoVer-Net (Horizontal and Vertical Network) performs simultaneous nucleus instance segmentation and classification.
Architecture:
- Encoder-decoder structure with three prediction branches:
  - Nuclear Pixel (NP) - Binary segmentation of nuclear regions
  - Horizontal-Vertical (HV) - Distance maps to nucleus centroids
  - Classification (NC) - Nucleus type classification
Nucleus types (PanNuke classes; background is a separate label, 0):
- Neoplastic
- Inflammatory
- Connective/Soft tissue
- Dead/Necrotic
- Epithelial
Usage:
from pathml.ml import HoVerNet
import torch
# Load pre-trained model
model = HoVerNet(
num_types=5, # Number of nucleus types
mode='fast', # 'fast' or 'original'
pretrained=True # Load pre-trained weights
)
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Inference on tile
tile_image = torch.from_numpy(tile.image).permute(2, 0, 1).unsqueeze(0).float()
tile_image = tile_image.to(device)
with torch.no_grad():
output = model(tile_image)
# Output contains:
# - output['np']: Nuclear pixel predictions
# - output['hv']: Horizontal-vertical maps
# - output['nc']: Classification predictions
Post-processing:
from pathml.ml import hovernet_postprocess
# Convert model outputs to instance segmentation
instance_map, type_map = hovernet_postprocess(
np_pred=output['np'],
hv_pred=output['hv'],
nc_pred=output['nc']
)
# instance_map: Each nucleus has unique ID
# type_map: Each nucleus assigned a type (1-5)
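The post-processed maps can be summarized directly with NumPy, for example to count nuclei per predicted type. A minimal sketch, assuming instance_map and type_map are 2D integer arrays as described above; the integer-to-name mapping is illustrative (PanNuke-style), not a PathML constant:
import numpy as np
# Assumed PanNuke-style label names; adjust to your trained model's classes
type_names = {1: 'Neoplastic', 2: 'Inflammatory', 3: 'Connective', 4: 'Dead', 5: 'Epithelial'}
counts = {name: 0 for name in type_names.values()}
for inst_id in np.unique(instance_map):
    if inst_id == 0:
        continue  # 0 is background
    nucleus_mask = instance_map == inst_id
    # Take the majority type label inside this nucleus
    labels, freqs = np.unique(type_map[nucleus_mask], return_counts=True)
    majority = int(labels[freqs.argmax()])
    if majority in type_names:
        counts[type_names[majority]] += 1
print(counts)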
HACTNet
HACTNet (Hierarchical Cell-Type Network) performs hierarchical nucleus classification with uncertainty quantification.
Features:
- Hierarchical classification (coarse to fine-grained types)
- Uncertainty estimation for predictions
- Improved performance on imbalanced datasets
from pathml.ml import HACTNet
# Load model
model = HACTNet(
num_classes_coarse=3,
num_classes_fine=8,
pretrained=True
)
# Inference
output = model(tile_image)
coarse_pred = output['coarse'] # Broad categories
fine_pred = output['fine'] # Specific cell types
uncertainty = output['uncertainty'] # Prediction confidence
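The uncertainty output can be used to flag low-confidence predictions for manual review. A minimal sketch, assuming uncertainty is a per-sample score where higher values mean less confidence (the threshold is illustrative):
threshold = 0.5                                   # illustrative cutoff; tune on validation data
fine_labels = fine_pred.argmax(dim=1)             # predicted fine-grained class per sample
needs_review = uncertainty.squeeze() > threshold  # boolean mask of uncertain predictions
print(f"{int(needs_review.sum())} of {len(fine_labels)} predictions flagged for review")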
Training Workflows
Dataset Preparation
PathML provides PyTorch-compatible dataset classes:
TileDataset:
from pathml.ml import TileDataset
from pathml.core import SlideDataset
# Create dataset from processed slides
tile_dataset = TileDataset(
slide_dataset,
tile_size=256,
transform=None # Optional augmentation transforms
)
# Access tiles
image, label = tile_dataset[0]
DataModule Integration:
import pytorch_lightning as pl
from pathml.ml import PathMLDataModule
# Create train/val/test splits
data_module = PathMLDataModule(
train_dataset=train_tile_dataset,
val_dataset=val_tile_dataset,
test_dataset=test_tile_dataset,
batch_size=32,
num_workers=4
)
# Use with PyTorch Lightning
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, data_module)
Training HoVer-Net
Complete workflow for training HoVer-Net, illustrated here with the PanNuke data module (substitute your own dataset as needed):
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pathml.ml import HoVerNet
from pathml.ml.datasets import PanNukeDataModule
# 1. Prepare data
data_module = PanNukeDataModule(
data_dir='path/to/pannuke',
batch_size=8,
num_workers=4,
tissue_types=['Breast', 'Colon'] # Specific tissue types
)
# 2. Initialize model
model = HoVerNet(
num_types=5,
mode='fast',
pretrained=False # Train from scratch or use pretrained=True for fine-tuning
)
# 3. Define loss function
class HoVerNetLoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
self.bce_loss = nn.BCEWithLogitsLoss()
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, output, target):
# Nuclear pixel branch loss
np_loss = self.bce_loss(output['np'], target['np'])
# Horizontal-vertical branch loss
hv_loss = self.mse_loss(output['hv'], target['hv'])
# Classification branch loss
nc_loss = self.ce_loss(output['nc'], target['nc'])
# Combined loss
total_loss = np_loss + hv_loss + 2.0 * nc_loss
return total_loss, {'np': np_loss, 'hv': hv_loss, 'nc': nc_loss}
criterion = HoVerNetLoss()
# 4. Configure optimizer
optimizer = torch.optim.Adam(
model.parameters(),
lr=1e-4,
weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=10
)
# 5. Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
num_epochs = 100
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
for batch in data_module.train_dataloader():
images = batch['image'].to(device)
targets = {
'np': batch['np_map'].to(device),
'hv': batch['hv_map'].to(device),
'nc': batch['type_map'].to(device)
}
optimizer.zero_grad()
outputs = model(images)
loss, loss_dict = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch in data_module.val_dataloader():
images = batch['image'].to(device)
targets = {
'np': batch['np_map'].to(device),
'hv': batch['hv_map'].to(device),
'nc': batch['type_map'].to(device)
}
outputs = model(images)
loss, _ = criterion(outputs, targets)
val_loss += loss.item()
scheduler.step(val_loss)
print(f"Epoch {epoch+1}/{num_epochs}")
print(f" Train Loss: {train_loss/len(data_module.train_dataloader()):.4f}")
print(f" Val Loss: {val_loss/len(data_module.val_dataloader()):.4f}")
# Save checkpoint
if (epoch + 1) % 10 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': val_loss,
}, f'hovernet_checkpoint_epoch_{epoch+1}.pth')
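To resume an interrupted run from one of these checkpoints, restore the saved state dicts before re-entering the loop (the epoch number in the filename is just an example):
checkpoint = torch.load('hovernet_checkpoint_epoch_50.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch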
PyTorch Lightning Integration
PathML models integrate with PyTorch Lightning for streamlined training:
import pytorch_lightning as pl
from pathml.ml import HoVerNet
from pathml.ml.datasets import PanNukeDataModule
class HoVerNetModule(pl.LightningModule):
def __init__(self, num_types=5, lr=1e-4):
super().__init__()
self.model = HoVerNet(num_types=num_types, pretrained=True)
self.lr = lr
self.criterion = HoVerNetLoss()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
images = batch['image']
targets = {
'np': batch['np_map'],
'hv': batch['hv_map'],
'nc': batch['type_map']
}
outputs = self(images)
loss, loss_dict = self.criterion(outputs, targets)
# Log metrics
self.log('train_loss', loss, prog_bar=True)
for key, val in loss_dict.items():
self.log(f'train_{key}_loss', val)
return loss
def validation_step(self, batch, batch_idx):
images = batch['image']
targets = {
'np': batch['np_map'],
'hv': batch['hv_map'],
'nc': batch['type_map']
}
outputs = self(images)
loss, loss_dict = self.criterion(outputs, targets)
self.log('val_loss', loss, prog_bar=True)
for key, val in loss_dict.items():
self.log(f'val_{key}_loss', val)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=10
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss'
}
}
# Train with PyTorch Lightning
data_module = PanNukeDataModule(data_dir='path/to/pannuke', batch_size=8)
model = HoVerNetModule(num_types=5, lr=1e-4)
trainer = pl.Trainer(
max_epochs=100,
accelerator='gpu',
devices=1,
callbacks=[
pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min'),
pl.callbacks.EarlyStopping(monitor='val_loss', patience=20)
]
)
trainer.fit(model, data_module)
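After training, the checkpoint tracked by ModelCheckpoint can be reloaded with Lightning's standard API. This assumes the module's default constructor arguments match the trained configuration (or that self.save_hyperparameters() is called in __init__):
best_path = trainer.checkpoint_callback.best_model_path
best_model = HoVerNetModule.load_from_checkpoint(best_path)
best_model.eval()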
Public Datasets
PathML provides convenient access to public pathology datasets:
PanNuke Dataset
PanNuke contains 7,901 histology image patches from 19 tissue types with nucleus annotations for 5 cell types.
from pathml.ml.datasets import PanNukeDataModule
# Load PanNuke dataset
pannuke = PanNukeDataModule(
data_dir='path/to/pannuke',
batch_size=16,
num_workers=4,
tissue_types=None, # Use all tissue types, or specify list
fold='all' # 'fold1', 'fold2', 'fold3', or 'all'
)
# Access dataloaders
train_loader = pannuke.train_dataloader()
val_loader = pannuke.val_dataloader()
test_loader = pannuke.test_dataloader()
# Batch structure
for batch in train_loader:
images = batch['image'] # Shape: (B, 3, 256, 256)
inst_map = batch['inst_map'] # Instance segmentation map
type_map = batch['type_map'] # Cell type map
np_map = batch['np_map'] # Nuclear pixel map
hv_map = batch['hv_map'] # Horizontal-vertical distance maps
tissue_type = batch['tissue_type'] # Tissue category
Tissue types available: Breast, Colon, Prostate, Lung, Kidney, Stomach, Bladder, Esophagus, Cervix, Liver, Thyroid, Head & Neck, Testis, Adrenal, Pancreas, Bile Duct, Ovary, Skin, Uterus
TCGA Datasets
Access The Cancer Genome Atlas datasets:
from pathml.ml.datasets import TCGADataModule
# Load TCGA dataset
tcga = TCGADataModule(
data_dir='path/to/tcga',
cancer_type='BRCA', # Breast cancer
batch_size=32,
tile_size=224
)
Custom Dataset Integration
Create custom datasets for PathML workflows:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from pathlib import Path
class CustomPathologyDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.data_dir = Path(data_dir)
self.image_paths = list(self.data_dir.glob('images/*.png'))
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# Load image
image_path = self.image_paths[idx]
image = np.array(Image.open(image_path))
# Load corresponding annotation
annot_path = self.data_dir / 'annotations' / f'{image_path.stem}.npy'
annotation = np.load(annot_path)
# Apply transforms
if self.transform:
image = self.transform(image)
return {
'image': torch.from_numpy(image).permute(2, 0, 1).float(),
'annotation': torch.from_numpy(annotation).long(),
'path': str(image_path)
}
# Use in PathML workflow
dataset = CustomPathologyDataset('path/to/data')
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
Data Augmentation
Apply augmentations to improve model generalization:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Define augmentation pipeline
train_transform = A.Compose([
A.RandomRotate90(p=0.5),
A.Flip(p=0.5),
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
A.GaussianBlur(blur_limit=(3, 7), p=0.3),
A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
val_transform = A.Compose([
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
# Apply to dataset
train_dataset = TileDataset(slide_dataset, transform=train_transform)
val_dataset = TileDataset(val_slide_dataset, transform=val_transform)
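For segmentation targets, geometric transforms must be applied to the image and its mask together. Albumentations handles this when the mask is passed alongside the image; a standalone sketch with placeholder arrays (independent of how TileDataset applies transforms internally):
import numpy as np
image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)  # placeholder H&E tile
inst_map = np.zeros((256, 256), dtype=np.int64)                   # placeholder instance mask
augmented = train_transform(image=image, mask=inst_map)
aug_image = augmented['image']  # tensor; color and geometric transforms applied, then normalized
aug_mask = augmented['mask']    # tensor; only the geometric transforms applied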
Model Evaluation
Metrics
Evaluate model performance with pathology-specific metrics:
from pathml.ml.metrics import (
dice_coefficient,
aggregated_jaccard_index,
panoptic_quality
)
# Dice coefficient for segmentation
dice = dice_coefficient(pred_mask, true_mask)
# Aggregated Jaccard Index (AJI) for instance segmentation
aji = aggregated_jaccard_index(pred_inst, true_inst)
# Panoptic Quality (PQ) for joint segmentation and classification
pq, sq, rq = panoptic_quality(pred_inst, true_inst, pred_types, true_types)
print(f"Dice: {dice:.4f}")
print(f"AJI: {aji:.4f}")
print(f"PQ: {pq:.4f}, SQ: {sq:.4f}, RQ: {rq:.4f}")
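For reference, on binary masks the Dice coefficient reduces to 2|A ∩ B| / (|A| + |B|). A minimal NumPy version (not the PathML implementation):
import numpy as np
def dice_binary(pred_mask, true_mask, eps=1e-7):
    pred = pred_mask.astype(bool)
    true = true_mask.astype(bool)
    intersection = np.logical_and(pred, true).sum()
    return (2.0 * intersection + eps) / (pred.sum() + true.sum() + eps)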
Evaluation Loop
from pathml.ml.metrics import evaluate_hovernet
# Comprehensive HoVer-Net evaluation
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
for batch in test_loader:
images = batch['image'].to(device)
outputs = model(images)
# Post-process predictions
for i in range(len(images)):
inst_pred, type_pred = hovernet_postprocess(
outputs['np'][i],
outputs['hv'][i],
outputs['nc'][i]
)
all_preds.append({'inst': inst_pred, 'type': type_pred})
all_targets.append({
'inst': batch['inst_map'][i],
'type': batch['type_map'][i]
})
# Compute metrics
results = evaluate_hovernet(all_preds, all_targets)
print(f"Detection F1: {results['detection_f1']:.4f}")
print(f"Classification Accuracy: {results['classification_acc']:.4f}")
print(f"Panoptic Quality: {results['pq']:.4f}")
ONNX Inference
Deploy models using ONNX for production inference:
Export to ONNX
import torch
from pathml.ml import HoVerNet
# Load trained model
model = HoVerNet(num_types=5, pretrained=True)
model.eval()
# Create dummy input
dummy_input = torch.randn(1, 3, 256, 256)
# Export to ONNX
torch.onnx.export(
model,
dummy_input,
'hovernet_model.onnx',
export_params=True,
opset_version=11,
input_names=['input'],
output_names=['np_output', 'hv_output', 'nc_output'],
dynamic_axes={
'input': {0: 'batch_size'},
'np_output': {0: 'batch_size'},
'hv_output': {0: 'batch_size'},
'nc_output': {0: 'batch_size'}
}
)
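Before deploying, it is worth verifying the exported graph and comparing ONNX Runtime outputs against PyTorch. A sketch, assuming the model returns a dict keyed 'np', 'hv', 'nc' and that the exported outputs come back in the order declared in output_names above (tolerances are illustrative):
import onnx
import onnxruntime as ort
import numpy as np
# Structural check of the exported graph
onnx.checker.check_model(onnx.load('hovernet_model.onnx'))
# Numerical comparison on the same dummy input
session = ort.InferenceSession('hovernet_model.onnx')
ort_outputs = session.run(None, {'input': dummy_input.numpy()})
with torch.no_grad():
    torch_outputs = model(dummy_input)
for ort_out, key in zip(ort_outputs, ['np', 'hv', 'nc']):
    np.testing.assert_allclose(ort_out, torch_outputs[key].numpy(), rtol=1e-3, atol=1e-5)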
ONNX Runtime Inference
import onnxruntime as ort
import numpy as np
# Load ONNX model
session = ort.InferenceSession('hovernet_model.onnx')
# Prepare input
input_name = session.get_inputs()[0].name
tile_image = preprocess_tile(tile.image)  # Normalize, transpose to (1, 3, H, W)
# Run inference
outputs = session.run(None, {input_name: tile_image})
np_output, hv_output, nc_output = outputs
# Post-process
inst_map, type_map = hovernet_postprocess(np_output, hv_output, nc_output)
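preprocess_tile above is a user-supplied helper. A minimal sketch matching the comment; the normalization constants are illustrative and should match whatever was used at training time:
import numpy as np
def preprocess_tile(tile_image):
    # tile_image: H x W x 3 uint8 RGB array
    arr = tile_image.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std
    return arr.transpose(2, 0, 1)[np.newaxis, ...]  # -> (1, 3, H, W)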
Batch Inference Pipeline
from pathml.core import SlideData
from pathml.preprocessing import Pipeline
import onnxruntime as ort
def run_onnx_inference_pipeline(slide_path, onnx_model_path):
# Load slide
wsi = SlideData.from_slide(slide_path)
wsi.generate_tiles(level=1, tile_size=256, stride=256)
# Load ONNX model
session = ort.InferenceSession(onnx_model_path)
input_name = session.get_inputs()[0].name
# Inference on all tiles
results = []
for tile in wsi.tiles:
# Preprocess
tile_array = preprocess_tile(tile.image)
# Inference
outputs = session.run(None, {input_name: tile_array})
# Post-process
inst_map, type_map = hovernet_postprocess(*outputs)
results.append({
'coords': tile.coords,
'instance_map': inst_map,
'type_map': type_map
})
return results
# Run on slide
results = run_onnx_inference_pipeline('slide.svs', 'hovernet_model.onnx')
Transfer Learning
Fine-tune pre-trained models on custom datasets:
from pathml.ml import HoVerNet
# Load pre-trained model
model = HoVerNet(num_types=5, pretrained=True)
# Freeze encoder layers for initial training
for name, param in model.named_parameters():
if 'encoder' in name:
param.requires_grad = False
# Fine-tune only decoder and classification heads
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-4
)
# Train for a few epochs
train_for_n_epochs(model, train_loader, optimizer, num_epochs=10)
# Unfreeze all layers for full fine-tuning
for param in model.parameters():
param.requires_grad = True
# Continue training with lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
train_for_n_epochs(model, train_loader, optimizer, num_epochs=50)
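train_for_n_epochs is a placeholder for whatever training loop you already use; a compact sketch consistent with the HoVerNetLoss example from the training workflow above:
import torch
def train_for_n_epochs(model, train_loader, optimizer, num_epochs,
                       device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    criterion = HoVerNetLoss()  # loss class defined in the training workflow above
    model = model.to(device)
    model.train()
    for epoch in range(num_epochs):
        for batch in train_loader:
            images = batch['image'].to(device)
            targets = {'np': batch['np_map'].to(device),
                       'hv': batch['hv_map'].to(device),
                       'nc': batch['type_map'].to(device)}
            optimizer.zero_grad()
            loss, _ = criterion(model(images), targets)
            loss.backward()
            optimizer.step()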
Best Practices
- Use pre-trained models when available:
  - Start with pretrained=True for better initialization
  - Fine-tune on domain-specific data
- Apply appropriate data augmentation:
  - Rotate and flip for orientation invariance
  - Color jitter to handle staining variations
  - Elastic deformation for biological variability
- Monitor multiple metrics:
  - Track detection, segmentation, and classification separately
  - Use domain-specific metrics (AJI, PQ) beyond standard accuracy
- Handle class imbalance (see the weighted-loss sketch after this list):
  - Weighted loss functions for rare cell types
  - Oversampling minority classes
  - Focal loss for hard examples
- Validate on diverse tissue types:
  - Ensure generalization across different tissues
  - Test on held-out anatomical sites
- Optimize for inference:
  - Export to ONNX for faster deployment
  - Batch tiles for efficient GPU utilization
  - Use mixed precision (FP16) when possible
- Save checkpoints regularly:
  - Keep the best model based on validation metrics
  - Save optimizer state for training resumption
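For the class-imbalance point above, a class-weighted cross-entropy on the classification branch is a simple starting point. The weights below are placeholders; in practice they are typically derived from inverse class frequencies in the training set:
import torch
import torch.nn as nn
class_weights = torch.tensor([1.0, 2.5, 1.2, 4.0, 1.5])  # placeholder per-class weights
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)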
Common Issues and Solutions
Issue: Poor segmentation at nucleus boundaries
- Use HV maps (horizontal-vertical) to separate touching nuclei
- Increase weight of HV loss term
- Apply morphological post-processing
Issue: Misclassification of similar cell types
- Increase classification loss weight
- Add hierarchical classification (HACTNet)
- Augment training data for confused classes
Issue: Training unstable or not converging
- Reduce learning rate
- Use gradient clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- Check for data preprocessing issues
Issue: Out of memory during training
- Reduce batch size
- Use gradient accumulation
- Enable mixed precision training with torch.cuda.amp (see the sketch below)
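A minimal mixed-precision training step, reusing the model, optimizer, criterion, device, and data module from the training workflow above:
scaler = torch.cuda.amp.GradScaler()
for batch in data_module.train_dataloader():
    images = batch['image'].to(device)
    targets = {'np': batch['np_map'].to(device),
               'hv': batch['hv_map'].to(device),
               'nc': batch['type_map'].to(device)}
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss, _ = criterion(model(images), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)
    scaler.update()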
Issue: Model overfits to training data
- Increase data augmentation
- Add dropout layers
- Reduce model capacity
- Use early stopping based on validation loss
Additional Resources
- PathML ML API: https://pathml.readthedocs.io/en/latest/api_ml_reference.html
- HoVer-Net Paper: Graham et al., "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images," Medical Image Analysis, 2019
- PanNuke Dataset: https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
- PyTorch Lightning: https://www.pytorchlightning.ai/
- ONNX Runtime: https://onnxruntime.ai/