# Machine Learning

## Overview

PathML provides machine learning tools for computational pathology, including pre-built models for nucleus detection and segmentation, PyTorch-integrated training workflows, access to public datasets, and ONNX-based inference deployment. The framework connects image preprocessing with deep learning training and inference to enable end-to-end pathology ML pipelines.

## Pre-Built Models

PathML includes pre-trained models for nucleus analysis:

### HoVer-Net

**HoVer-Net** (Horizontal and Vertical Network) performs simultaneous nucleus instance segmentation and classification.

**Architecture:**
- Encoder-decoder structure: a shared encoder feeds three prediction branches:
  - **Nuclear Pixel (NP)** - Binary segmentation of nuclear vs. non-nuclear pixels
  - **Horizontal-Vertical (HV)** - Horizontal and vertical distances from each nuclear pixel to its nucleus centroid
  - **Classification (NC)** - Per-pixel nucleus type classification

**Nucleus types** (the five PanNuke categories):
1. Neoplastic
2. Inflammatory
3. Connective/Soft tissue
4. Dead/Necrotic
5. Non-neoplastic epithelial

Background is not counted as a nucleus type.

**Usage:**

```python
from pathml.ml import HoVerNet
import torch

# Load pre-trained model
model = HoVerNet(
    num_types=5,      # Number of nucleus types
    mode='fast',      # 'fast' or 'original'
    pretrained=True   # Load pre-trained weights
)

# Move to GPU if available and switch to inference mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

# Inference on a tile: tile.image is an (H, W, 3) RGB array -> (1, 3, H, W) tensor
tile_image = torch.from_numpy(tile.image).permute(2, 0, 1).unsqueeze(0).float()
tile_image = tile_image.to(device)

with torch.no_grad():
    output = model(tile_image)

# Output contains:
# - output['np']: Nuclear pixel predictions
# - output['hv']: Horizontal-vertical maps
# - output['nc']: Classification predictions
```

**Post-processing:**

```python
from pathml.ml import hovernet_postprocess

# Convert model outputs to instance segmentation
instance_map, type_map = hovernet_postprocess(
    np_pred=output['np'],
    hv_pred=output['hv'],
    nc_pred=output['nc']
)

# instance_map: each nucleus has a unique ID
# type_map: each nucleus is assigned a type (1-5)
```

### HACTNet

**HACTNet** (Hierarchical Cell-Type Network) performs hierarchical nucleus classification with uncertainty quantification.

**Features:**
- Hierarchical classification (coarse to fine-grained types)
- Uncertainty estimation for predictions
- Improved performance on imbalanced datasets

```python
import torch
from pathml.ml import HACTNet

# Load model
model = HACTNet(
    num_classes_coarse=3,
    num_classes_fine=8,
    pretrained=True
)
model = model.to(device)
model.eval()

# Inference (tile_image prepared as in the HoVer-Net example)
with torch.no_grad():
    output = model(tile_image)

coarse_pred = output['coarse']        # Broad categories
fine_pred = output['fine']            # Specific cell types
uncertainty = output['uncertainty']   # Prediction confidence
```

## Training Workflows

### Dataset Preparation

PathML provides PyTorch-compatible dataset classes:

**TileDataset:**

```python
from pathml.ml import TileDataset
from pathml.core import SlideDataset

# Create dataset from processed slides
# (slide_dataset: a SlideDataset whose slides have already been preprocessed)
tile_dataset = TileDataset(
    slide_dataset,
    tile_size=256,
    transform=None  # Optional augmentation transforms
)

# Access tiles
image, label = tile_dataset[0]
```

**DataModule Integration:**

```python
import pytorch_lightning as pl
from pathml.ml import PathMLDataModule

# Create train/val/test splits
data_module = PathMLDataModule(
    train_dataset=train_tile_dataset,
    val_dataset=val_tile_dataset,
    test_dataset=test_tile_dataset,
    batch_size=32,
    num_workers=4
)

# Use with PyTorch Lightning (model must be a pl.LightningModule,
# such as the HoVerNetModule in the PyTorch Lightning Integration section)
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, data_module)
```

### Training HoVer-Net

Complete workflow for training HoVer-Net, shown here with PanNuke; substitute your own data module to train on custom data:

```python
import torch
import torch.nn as nn
from pathml.ml import HoVerNet
from pathml.ml.datasets import PanNukeDataModule

# 1. Prepare data
data_module = PanNukeDataModule(
    data_dir='path/to/pannuke',
    batch_size=8,
    num_workers=4,
    tissue_types=['Breast', 'Colon']  # Specific tissue types
)

# 2. Initialize model
model = HoVerNet(
    num_types=5,
    mode='fast',
    pretrained=False  # Train from scratch, or set pretrained=True to fine-tune
)

# 3. Define loss function
class HoVerNetLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, output, target):
        # Nuclear pixel branch loss (target: float mask with the same shape as output['np'])
        np_loss = self.bce_loss(output['np'], target['np'])

        # Horizontal-vertical branch loss
        hv_loss = self.mse_loss(output['hv'], target['hv'])

        # Classification branch loss (target: integer type map)
        nc_loss = self.ce_loss(output['nc'], target['nc'])

        # Combined loss (classification term weighted 2x)
        total_loss = np_loss + hv_loss + 2.0 * nc_loss

        return total_loss, {'np': np_loss, 'hv': hv_loss, 'nc': nc_loss}

criterion = HoVerNetLoss()

# 4. Configure optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10
)

# 5. Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for batch in train_loader:
        images = batch['image'].to(device)
        targets = {
            'np': batch['np_map'].to(device),
            'hv': batch['hv_map'].to(device),
            'nc': batch['type_map'].to(device)
        }

        optimizer.zero_grad()
        outputs = model(images)
        loss, loss_dict = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)  # Mean loss per batch

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            targets = {
                'np': batch['np_map'].to(device),
                'hv': batch['hv_map'].to(device),
                'nc': batch['type_map'].to(device)
            }
            outputs = model(images)
            loss, _ = criterion(outputs, targets)
            val_loss += loss.item()
    val_loss /= len(val_loader)  # Mean loss per batch

    # Step the scheduler on the mean validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")

    # Save checkpoint
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
        }, f'hovernet_checkpoint_epoch_{epoch+1}.pth')
```

### PyTorch Lightning Integration

PathML models integrate with PyTorch Lightning for streamlined training:

```python
import torch
import pytorch_lightning as pl
from pathml.ml import HoVerNet
from pathml.ml.datasets import PanNukeDataModule

# HoVerNetLoss is the loss class defined in the Training HoVer-Net example
class HoVerNetModule(pl.LightningModule):
    def __init__(self, num_types=5, lr=1e-4):
        super().__init__()
        self.model = HoVerNet(num_types=num_types, pretrained=True)
        self.lr = lr
        self.criterion = HoVerNetLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images = batch['image']
        targets = {
            'np': batch['np_map'],
            'hv': batch['hv_map'],
            'nc': batch['type_map']
        }

        outputs = self(images)
        loss, loss_dict = self.criterion(outputs, targets)

        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        for key, val in loss_dict.items():
            self.log(f'train_{key}_loss', val)

        return loss

    def validation_step(self, batch, batch_idx):
        images = batch['image']
        targets = {
            'np': batch['np_map'],
            'hv': batch['hv_map'],
            'nc': batch['type_map']
        }

        outputs = self(images)
        loss, loss_dict = self.criterion(outputs, targets)

        self.log('val_loss', loss, prog_bar=True)
        for key, val in loss_dict.items():
            self.log(f'val_{key}_loss', val)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }

# Train with PyTorch Lightning
data_module = PanNukeDataModule(data_dir='path/to/pannuke', batch_size=8)
model = HoVerNetModule(num_types=5, lr=1e-4)

trainer = pl.Trainer(
    max_epochs=100,
    accelerator='gpu',
    devices=1,
    callbacks=[
        pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min'),
        pl.callbacks.EarlyStopping(monitor='val_loss', patience=20)
    ]
)

trainer.fit(model, data_module)
```

## Public Datasets

PathML provides convenient access to public pathology datasets:

### PanNuke Dataset

**PanNuke** contains 7,901 histology image patches (256×256 pixels) from 19 tissue types, with nucleus annotations for 5 cell types.

```python
from pathml.ml.datasets import PanNukeDataModule

# Load PanNuke dataset
pannuke = PanNukeDataModule(
    data_dir='path/to/pannuke',
    batch_size=16,
    num_workers=4,
    tissue_types=None,  # Use all tissue types, or specify a list
    fold='all'          # 'fold1', 'fold2', 'fold3', or 'all'
)

# Access dataloaders
train_loader = pannuke.train_dataloader()
val_loader = pannuke.val_dataloader()
test_loader = pannuke.test_dataloader()

# Batch structure
for batch in train_loader:
    images = batch['image']             # Shape: (B, 3, 256, 256)
    inst_map = batch['inst_map']        # Instance segmentation map
    type_map = batch['type_map']        # Cell type map
    np_map = batch['np_map']            # Nuclear pixel map
    hv_map = batch['hv_map']            # Horizontal-vertical distance maps
    tissue_type = batch['tissue_type']  # Tissue category
```

**Tissue types available:**
Breast, Colon, Prostate, Lung, Kidney, Stomach, Bladder, Esophagus, Cervix, Liver, Thyroid, Head & Neck, Testis, Adrenal, Pancreas, Bile Duct, Ovary, Skin, Uterus

### TCGA Datasets

Access The Cancer Genome Atlas (TCGA) datasets:

```python
from pathml.ml.datasets import TCGADataModule

# Load TCGA dataset
tcga = TCGADataModule(
    data_dir='path/to/tcga',
    cancer_type='BRCA',  # Breast cancer
    batch_size=32,
    tile_size=224
)
```

### Custom Dataset Integration

Create custom datasets for PathML workflows:

```python
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class CustomPathologyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = Path(data_dir)
        # Sort so the sample order is deterministic across runs
        self.image_paths = sorted(self.data_dir.glob('images/*.png'))
        self.transform = transform  # Callable taking and returning an (H, W, 3) NumPy array

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image as RGB (drops any alpha channel)
        image_path = self.image_paths[idx]
        image = np.array(Image.open(image_path).convert('RGB'))

        # Load corresponding annotation
        annot_path = self.data_dir / 'annotations' / f'{image_path.stem}.npy'
        annotation = np.load(annot_path)

        # Apply transforms (spatial transforms must also be applied to the annotation)
        if self.transform:
            image = self.transform(image)

        return {
            'image': torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float(),
            'annotation': torch.from_numpy(annotation).long(),
            'path': str(image_path)
        }

# Use in PathML workflow
dataset = CustomPathologyDataset('path/to/data')
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
```

## Data Augmentation

Apply augmentations to improve model generalization:

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathml.ml import TileDataset

# Define augmentation pipeline
train_transform = A.Compose([
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Albumentations pipelines take keyword arguments and return a dict:
#   augmented = train_transform(image=image, mask=mask)
#   image, mask = augmented['image'], augmented['mask']
# Passing mask= applies the same spatial transforms to the annotation.

# Apply to dataset
train_dataset = TileDataset(slide_dataset, transform=train_transform)
val_dataset = TileDataset(val_slide_dataset, transform=val_transform)
```

## Model Evaluation

### Metrics

Evaluate model performance with pathology-specific metrics:

```python
from pathml.ml.metrics import (
    dice_coefficient,
    aggregated_jaccard_index,
    panoptic_quality
)

# Dice coefficient for segmentation
dice = dice_coefficient(pred_mask, true_mask)

# Aggregated Jaccard Index (AJI) for instance segmentation
aji = aggregated_jaccard_index(pred_inst, true_inst)

# Panoptic Quality (PQ) for joint segmentation and classification
pq, sq, rq = panoptic_quality(pred_inst, true_inst, pred_types, true_types)

print(f"Dice: {dice:.4f}")
print(f"AJI: {aji:.4f}")
print(f"PQ: {pq:.4f}, SQ: {sq:.4f}, RQ: {rq:.4f}")
```

### Evaluation Loop

```python
import torch
from pathml.ml import hovernet_postprocess
from pathml.ml.metrics import evaluate_hovernet

# Comprehensive HoVer-Net evaluation
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for batch in test_loader:
        images = batch['image'].to(device)
        outputs = model(images)

        # Post-process predictions
        for i in range(len(images)):
            inst_pred, type_pred = hovernet_postprocess(
                outputs['np'][i],
                outputs['hv'][i],
                outputs['nc'][i]
            )
            all_preds.append({'inst': inst_pred, 'type': type_pred})
            all_targets.append({
                'inst': batch['inst_map'][i],
                'type': batch['type_map'][i]
            })

# Compute metrics
results = evaluate_hovernet(all_preds, all_targets)

print(f"Detection F1: {results['detection_f1']:.4f}")
print(f"Classification Accuracy: {results['classification_acc']:.4f}")
print(f"Panoptic Quality: {results['pq']:.4f}")
```

## ONNX Inference

Deploy models using ONNX for production inference:

### Export to ONNX

```python
import torch
from pathml.ml import HoVerNet

# Load trained model
model = HoVerNet(num_types=5, pretrained=True)
model.eval()

# Create dummy input
dummy_input = torch.randn(1, 3, 256, 256)

# Export to ONNX
# (output_names assume the model returns the np, hv, and nc outputs in that order)
torch.onnx.export(
    model,
    dummy_input,
    'hovernet_model.onnx',
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['np_output', 'hv_output', 'nc_output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'np_output': {0: 'batch_size'},
        'hv_output': {0: 'batch_size'},
        'nc_output': {0: 'batch_size'}
    }
)
```

### ONNX Runtime Inference

```python
import onnxruntime as ort
import numpy as np
from pathml.ml import hovernet_postprocess

# Load ONNX model
session = ort.InferenceSession('hovernet_model.onnx')

# Prepare input
input_name = session.get_inputs()[0].name
# preprocess_tile: user-supplied helper (not part of PathML) that normalizes the
# tile image and returns a float32 array of shape (1, 3, H, W)
tile_image = preprocess_tile(tile.image)

# Run inference
outputs = session.run(None, {input_name: tile_image})
np_output, hv_output, nc_output = outputs

# Post-process
inst_map, type_map = hovernet_postprocess(np_output, hv_output, nc_output)
```

### Batch Inference Pipeline

```python
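# preprocess_tile is not defined by PathML; the helper below is a hypothetical
# sketch. It assumes tile images are (H, W, 3) uint8 RGB arrays and that the
# model was trained on ImageNet-normalized inputs (as in val_transform in the
# Data Augmentation section); adjust scaling and statistics to your training setup.
# ONNX Runtime needs float32 here because the model was exported with a float32
# dummy input.
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_tile(image):
    """Convert an (H, W, 3) RGB tile into a float32 (1, 3, H, W) array."""
    x = np.asarray(image, dtype=np.float32) / 255.0  # Scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD           # Per-channel normalization
    x = np.transpose(x, (2, 0, 1))                   # HWC -> CHW
    return np.ascontiguousarray(x[np.newaxis])       # Add batch dim as a contiguous array
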
from pathml.core import SlideData
from pathml.ml import hovernet_postprocess
import onnxruntime as ort

def run_onnx_inference_pipeline(slide_path, onnx_model_path):
    # Load slide and tile it at pyramid level 1
    wsi = SlideData.from_slide(slide_path)
    wsi.generate_tiles(level=1, tile_size=256, stride=256)

    # Load ONNX model
    session = ort.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name

    # Inference on all tiles
    results = []
    for tile in wsi.tiles:
        # Preprocess (preprocess_tile is a user-supplied helper)
        tile_array = preprocess_tile(tile.image)

        # Inference
        outputs = session.run(None, {input_name: tile_array})

        # Post-process
        inst_map, type_map = hovernet_postprocess(*outputs)

        results.append({
            'coords': tile.coords,
            'instance_map': inst_map,
            'type_map': type_map
        })

    return results

# Run on slide
results = run_onnx_inference_pipeline('slide.svs', 'hovernet_model.onnx')
```

## Transfer Learning

Fine-tune pre-trained models on custom datasets:

```python
import torch
from pathml.ml import HoVerNet

# Load pre-trained model
model = HoVerNet(num_types=5, pretrained=True)

# Freeze encoder layers for initial training
# (parameter names depend on the implementation; inspect model.named_parameters())
for name, param in model.named_parameters():
    if 'encoder' in name:
        param.requires_grad = False

# Fine-tune only decoder and classification heads
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4
)

# Train for a few epochs
# (train_for_n_epochs is a placeholder for your own training loop)
train_for_n_epochs(model, train_loader, optimizer, num_epochs=10)

# Unfreeze all layers for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Continue training with a lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
train_for_n_epochs(model, train_loader, optimizer, num_epochs=50)
```

## Best Practices

1. **Use pre-trained models when available:**
   - Start with `pretrained=True` for better initialization
   - Fine-tune on domain-specific data

2. **Apply appropriate data augmentation:**
   - Rotations and flips for orientation invariance
   - Color jitter to handle staining variations
   - Elastic deformation for biological variability

3. **Monitor multiple metrics:**
   - Track detection, segmentation, and classification separately
   - Use domain-specific metrics (AJI, PQ) beyond standard accuracy

4. **Handle class imbalance:**
   - Weighted loss functions for rare cell types
   - Oversampling minority classes
   - Focal loss for hard examples

5. **Validate on diverse tissue types:**
   - Ensure generalization across different tissues
   - Test on held-out anatomical sites

6. **Optimize for inference:**
   - Export to ONNX for portable, framework-independent deployment
   - Batch tiles for efficient GPU utilization
   - Use mixed precision (FP16) when possible

7. **Save checkpoints regularly:**
   - Keep the best model based on validation metrics
   - Save optimizer state so training can be resumed

## Common Issues and Solutions

**Issue: Poor segmentation at nucleus boundaries**
- Use HV (horizontal-vertical) maps to separate touching nuclei
- Increase the weight of the HV loss term
- Apply morphological post-processing

**Issue: Misclassification of similar cell types**
- Increase the classification loss weight
- Add hierarchical classification (HACTNet)
- Augment training data for confused classes

**Issue: Training unstable or not converging**
- Reduce the learning rate
- Use gradient clipping: `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`
- Check for data preprocessing issues

**Issue: Out of memory during training**
- Reduce batch size
- Use gradient accumulation
- Enable mixed precision training with `torch.cuda.amp` (exposed as `torch.amp` in recent PyTorch releases)

**Issue: Model overfits to training data**
- Increase data augmentation
- Add dropout layers
- Reduce model capacity
- Use early stopping based on validation loss

## Additional Resources

- **PathML ML API:** https://pathml.readthedocs.io/en/latest/api_ml_reference.html
- **HoVer-Net Paper:** Graham et al., "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images," Medical Image Analysis, 2019
- **PanNuke Dataset:** https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
- **PyTorch Lightning:** https://www.pytorchlightning.ai/
- **ONNX Runtime:** https://onnxruntime.ai/
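The memory and stability fixes listed under Common Issues and Solutions (mixed precision, gradient accumulation, gradient clipping) are usually combined in one training loop. The following is a minimal sketch in plain PyTorch, not a PathML API: `train_one_epoch` and its parameters are hypothetical, and it assumes a loader yielding `(inputs, targets)` pairs and a loss function returning a single scalar, so the dict-based HoVer-Net batches and the tuple returned by `HoVerNetLoss` would need a small adapter.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, loss_fn, device,
                    accum_steps=4, max_grad_norm=1.0):
    """Hypothetical helper: AMP + gradient accumulation + gradient clipping."""
    use_amp = device.type == "cuda"  # Mixed precision only on GPU
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    model.train()
    optimizer.zero_grad()
    total = 0.0
    for step, (x, y) in enumerate(loader, start=1):
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = loss_fn(model(x), y)
        # Scale down so accumulated gradients match one larger batch
        scaler.scale(loss / accum_steps).backward()
        total += loss.item()
        # Update every accum_steps batches, and flush leftovers at the end
        if step % accum_steps == 0 or step == len(loader):
            scaler.unscale_(optimizer)  # Clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    return total / len(loader)  # Mean (unscaled) loss per batch
```

Dividing the loss by `accum_steps` makes the accumulated gradient approximate that of a single batch `accum_steps` times larger, and calling `scaler.unscale_` before `clip_grad_norm_` applies the clipping threshold to the real gradients rather than the loss-scaled ones. On CPU, `use_amp` is `False`, so autocast and the scaler become no-ops.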