# Callbacks - Comprehensive Guide

## Overview

Callbacks let you add arbitrary, self-contained programs to training without cluttering your LightningModule research code. They execute custom logic at specific hooks during the training lifecycle.

## Architecture

Lightning organizes training logic across three components:

- **Trainer** - Engineering infrastructure
- **LightningModule** - Research code
- **Callbacks** - Non-essential functionality (monitoring, checkpointing, custom behaviors)

## Creating Custom Callbacks

Basic structure:

```python
import lightning as L
from lightning.pytorch.callbacks import Callback


class MyCustomCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")


# Use with Trainer
trainer = L.Trainer(callbacks=[MyCustomCallback()])
```

## Built-in Callbacks

### ModelCheckpoint

Save models based on monitored metrics.

**Key Parameters:**

- `dirpath` - Directory to save checkpoints
- `filename` - Checkpoint filename pattern
- `monitor` - Metric to monitor
- `mode` - "min" or "max" for the monitored metric
- `save_top_k` - Number of best models to keep
- `save_last` - Save a checkpoint for the last epoch
- `every_n_epochs` - Save every N epochs
- `save_on_train_epoch_end` - Save at train epoch end vs. validation end

**Examples:**

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Save top 3 models based on validation loss
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="model-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
)

# Save every 10 epochs
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="model-{epoch:02d}",
    every_n_epochs=10,
    save_top_k=-1,  # Keep all checkpoints
)

# Save the single best model based on accuracy
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="best-model",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)

trainer = L.Trainer(callbacks=[checkpoint_callback])
```

**Accessing Saved Checkpoints:**

```python
# Path to the best checkpoint
best_model_path = checkpoint_callback.best_model_path

# Path to the last checkpoint
last_checkpoint = checkpoint_callback.last_model_path

# Dict mapping the best k checkpoint paths to their monitored scores
all_checkpoints = checkpoint_callback.best_k_models
```
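These paths can be fed back into Lightning to restore the best weights after training. A minimal sketch, assuming a `MyModel` LightningModule and a `dm` datamodule (both illustrative names) alongside the `checkpoint_callback` above:

```python
# Sketch: train, then reload the best checkpoint (MyModel and dm are assumed)
trainer = L.Trainer(callbacks=[checkpoint_callback], max_epochs=20)
trainer.fit(MyModel(), datamodule=dm)

# Reload the best weights for evaluation or inference
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)

# Or let the Trainer resolve the best checkpoint itself
trainer.test(MyModel(), datamodule=dm, ckpt_path="best")
```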
### EarlyStopping

Stop training when a monitored metric stops improving.

**Key Parameters:**

- `monitor` - Metric to monitor
- `patience` - Number of validation checks with no improvement after which training stops
- `mode` - "min" or "max" for the monitored metric
- `min_delta` - Minimum change to qualify as an improvement
- `verbose` - Print messages
- `strict` - Crash if the monitored metric is not found

**Examples:**

```python
from lightning.pytorch.callbacks import EarlyStopping

# Stop when validation loss stops improving
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
    verbose=True,
)

# Stop when accuracy plateaus
early_stop = EarlyStopping(
    monitor="val_acc",
    patience=5,
    mode="max",
    min_delta=0.001,  # Must improve by at least 0.001
)

trainer = L.Trainer(callbacks=[early_stop])
```

### LearningRateMonitor

Track learning rate changes from schedulers.

**Key Parameters:**

- `logging_interval` - When to log: "step" or "epoch"
- `log_momentum` - Also log momentum values

**Example:**

```python
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = L.Trainer(callbacks=[lr_monitor])

# The learning rate is logged automatically as "lr-{optimizer_name}"
```

### DeviceStatsMonitor

Log device performance metrics (GPU/CPU/TPU).

**Key Parameters:**

- `cpu_stats` - Also log CPU stats

**Example:**

```python
from lightning.pytorch.callbacks import DeviceStatsMonitor

device_stats = DeviceStatsMonitor(cpu_stats=True)
trainer = L.Trainer(callbacks=[device_stats])

# Logs device statistics such as utilization and memory usage
```

### ModelSummary / RichModelSummary

Display model architecture and parameter counts.

**Example:**

```python
from lightning.pytorch.callbacks import ModelSummary, RichModelSummary

# Basic summary
summary = ModelSummary(max_depth=2)

# Rich formatted summary (prettier)
rich_summary = RichModelSummary(max_depth=3)

trainer = L.Trainer(callbacks=[rich_summary])
```

### Timer

Track and limit training duration.

**Key Parameters:**

- `duration` - Maximum training time (timedelta or dict)
- `interval` - When the duration is checked: "step" or "epoch"

**Example:**

```python
from datetime import timedelta

from lightning.pytorch.callbacks import Timer

# Limit training to 1 hour
timer = Timer(duration=timedelta(hours=1))

# Or using a dict
timer = Timer(duration={"hours": 23, "minutes": 30})

trainer = L.Trainer(callbacks=[timer])
```

### BatchSizeFinder

Automatically find the largest batch size that fits in memory.

**Example:**

```python
from lightning.pytorch.callbacks import BatchSizeFinder

batch_finder = BatchSizeFinder(mode="power", steps_per_trial=3)
trainer = L.Trainer(callbacks=[batch_finder])

trainer.fit(model, datamodule=dm)
# The found batch size is written back to the model/datamodule automatically
```

### GradientAccumulationScheduler

Schedule gradient accumulation steps dynamically.

**Example:**

```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate 4 batches starting at epoch 0, then 2 batches from epoch 5
accumulator = GradientAccumulationScheduler(scheduling={0: 4, 5: 2})
trainer = L.Trainer(callbacks=[accumulator])
```
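If the accumulation factor never needs to change, the scheduler callback is unnecessary; the Trainer accepts a constant factor directly:

```python
# Constant gradient accumulation without a callback
trainer = L.Trainer(accumulate_grad_batches=4)
```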
### StochasticWeightAveraging (SWA)

Apply stochastic weight averaging for better generalization.

**Example:**

```python
from lightning.pytorch.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8)
trainer = L.Trainer(callbacks=[swa])
```

## Custom Callback Examples

### Simple Logging Callback

```python
class MetricsLogger(Callback):
    def __init__(self):
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Access logged metrics
        metrics = trainer.callback_metrics
        self.metrics.append(dict(metrics))
        print(f"Validation metrics: {metrics}")
```

### Gradient Monitoring Callback

```python
class GradientMonitor(Callback):
    def on_after_backward(self, trainer, pl_module):
        # Log gradient norms
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                pl_module.log(f"grad_norm/{name}", grad_norm)
```

### Custom Checkpointing Callback

```python
class CustomCheckpoint(Callback):
    def __init__(self, save_dir):
        self.save_dir = save_dir

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % 5 == 0:  # Save every 5 epochs
            filepath = f"{self.save_dir}/custom-{epoch}.ckpt"
            trainer.save_checkpoint(filepath)
            print(f"Saved checkpoint: {filepath}")
```

### Model Freezing Callback

```python
class FreezeUnfreeze(Callback):
    def __init__(self, freeze_until_epoch=10):
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch < self.freeze_until_epoch:
            # Freeze backbone
            for param in pl_module.backbone.parameters():
                param.requires_grad = False
        else:
            # Unfreeze backbone
            for param in pl_module.backbone.parameters():
                param.requires_grad = True
```

### Learning Rate Finder Callback

```python
class LRFinder(Callback):
    def __init__(self, min_lr=1e-5, max_lr=1e-1, num_steps=100):
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.num_steps = num_steps
        self.lrs = []
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx >= self.num_steps:
            trainer.should_stop = True
            return

        # Exponential LR schedule between min_lr and max_lr
        lr = self.min_lr * (self.max_lr / self.min_lr) ** (batch_idx / self.num_steps)
        optimizer = trainer.optimizers[0]
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        self.lrs.append(lr)
        self.losses.append(outputs["loss"].item())

    def on_train_end(self, trainer, pl_module):
        # Plot LR vs. loss
        import matplotlib.pyplot as plt

        plt.plot(self.lrs, self.losses)
        plt.xscale("log")
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.savefig("lr_finder.png")
```

### Prediction Saver Callback

```python
import torch


class PredictionSaver(Callback):
    def __init__(self, save_path):
        self.save_path = save_path
        self.predictions = []

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.predictions.append(outputs)

    def on_predict_end(self, trainer, pl_module):
        # Save all predictions
        torch.save(self.predictions, self.save_path)
        print(f"Predictions saved to {self.save_path}")
```

## Available Hooks

### Setup and Teardown

- `setup(trainer, pl_module, stage)` - Called at the beginning of fit/validate/test/predict
- `teardown(trainer, pl_module, stage)` - Called at the end of fit/validate/test/predict

### Training Lifecycle

- `on_fit_start(trainer, pl_module)` - Called at the start of fit
- `on_fit_end(trainer, pl_module)` - Called at the end of fit
- `on_train_start(trainer, pl_module)` - Called at the start of training
- `on_train_end(trainer, pl_module)` - Called at the end of training

### Epoch Boundaries

- `on_train_epoch_start(trainer, pl_module)` - Called at the start of a training epoch
- `on_train_epoch_end(trainer, pl_module)` - Called at the end of a training epoch
- `on_validation_epoch_start(trainer, pl_module)` - Called at the start of validation
- `on_validation_epoch_end(trainer, pl_module)` - Called at the end of validation
- `on_test_epoch_start(trainer, pl_module)` - Called at the start of testing
- `on_test_epoch_end(trainer, pl_module)` - Called at the end of testing

### Batch Boundaries

- `on_train_batch_start(trainer, pl_module, batch, batch_idx)` - Before a training batch
- `on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)` - After a training batch
- `on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)` - Before a validation batch
- `on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)` - After a validation batch

### Gradient Events

- `on_before_backward(trainer, pl_module, loss)` - Before `loss.backward()`
- `on_after_backward(trainer, pl_module)` - After `loss.backward()`
- `on_before_optimizer_step(trainer, pl_module, optimizer)` - Before `optimizer.step()`

### Checkpoint Events

- `on_save_checkpoint(trainer, pl_module, checkpoint)` - When saving a checkpoint
- `on_load_checkpoint(trainer, pl_module, checkpoint)` - When loading a checkpoint

### Exception Handling

- `on_exception(trainer, pl_module, exception)` - When an exception occurs
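To see a few of these hooks together, here is a hedged sketch of a callback that times a run, stashes that figure in the checkpoint, and reacts to interruptions. The hook names come from the lists above; the `RunTracker` class and the `run_seconds` checkpoint key are illustrative choices, not Lightning conventions.

```python
import time

from lightning.pytorch.callbacks import Callback


class RunTracker(Callback):
    """Sketch: exercise setup, checkpoint, and exception hooks."""

    def __init__(self):
        self._start = None
        self._previous = 0.0  # seconds carried over from a resumed run

    def setup(self, trainer, pl_module, stage):
        self._start = time.monotonic()

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # The checkpoint is a plain dict, so extra entries can be added to it
        checkpoint["run_seconds"] = self._previous + time.monotonic() - self._start

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self._previous = checkpoint.get("run_seconds", 0.0)

    def on_exception(self, trainer, pl_module, exception):
        print(f"Run interrupted by {type(exception).__name__}")
```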
## State Management

For callbacks that need to persist state across checkpoints, implement `state_dict` and `load_state_dict`:

```python
class StatefulCallback(Callback):
    def __init__(self):
        self.counter = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]

    @property
    def state_key(self):
        # Unique identifier for this callback's state in the checkpoint
        return "my_stateful_callback"
```
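The callback's state is written into the regular Lightning checkpoint and restored when training resumes from one, provided an instance of the callback is passed to the new Trainer. A minimal sketch, assuming a `MyModel` module and an existing checkpoint file (both illustrative):

```python
# Resume training; Lightning calls load_state_dict on the callback automatically
stateful = StatefulCallback()
trainer = L.Trainer(callbacks=[stateful], max_epochs=20)
trainer.fit(MyModel(), ckpt_path="checkpoints/last.ckpt")
```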
## Best Practices

### 1. Keep Callbacks Isolated

Each callback should be self-contained and independent:

```python
# Good: self-contained
class MyCallback(Callback):
    def __init__(self):
        self.data = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.data.append(outputs["loss"].item())


# Bad: depends on external state
global_data = []


class BadCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        global_data.append(outputs["loss"].item())  # External dependency
```

### 2. Avoid Inter-Callback Dependencies

Callbacks should not depend on other callbacks:

```python
# Bad: CallbackB depends on CallbackA
class CallbackA(Callback):
    def __init__(self):
        self.value = 0


class CallbackB(Callback):
    def __init__(self, callback_a):
        self.callback_a = callback_a  # Tight coupling


# Good: independent callbacks
class CallbackA(Callback):
    def __init__(self):
        self.value = 0


class CallbackB(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Read shared state from the trainer instead
        value = trainer.callback_metrics.get("metric")
```

### 3. Never Manually Invoke Callback Methods

Let Lightning call callbacks automatically:

```python
# Bad: manual invocation
callback = MyCallback()
callback.on_train_start(trainer, model)  # Don't do this

# Good: let the Trainer handle it
trainer = L.Trainer(callbacks=[MyCallback()])
```

### 4. Design for Any Execution Order

Callbacks may execute in any order, so don't rely on a specific ordering:

```python
# Good: order-independent
class GoodCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Use trainer state, not other callbacks
        metrics = trainer.callback_metrics
        self.log_metrics(metrics)
```

### 5. Use Callbacks for Non-Essential Logic

Keep core research code in the LightningModule and use callbacks for auxiliary functionality:

```python
# Good separation
class MyModel(L.LightningModule):
    # Core research logic here
    def training_step(self, batch, batch_idx):
        return loss


# Non-essential monitoring in a callback
class MonitorCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        # Monitoring logic
        pass
```

## Common Patterns

### Combining Multiple Callbacks

```python
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    DeviceStatsMonitor,
)

callbacks = [
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3),
    EarlyStopping(monitor="val_loss", patience=10, mode="min"),
    LearningRateMonitor(logging_interval="step"),
    DeviceStatsMonitor(),
]

trainer = L.Trainer(callbacks=callbacks)
```

### Conditional Callback Activation

```python
class ConditionalCallback(Callback):
    def __init__(self, activate_after_epoch=10):
        self.activate_after_epoch = activate_after_epoch

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch >= self.activate_after_epoch:
            # Only active after the specified epoch
            self.do_something(trainer, pl_module)
```

### Multi-Stage Training Callback

```python
class MultiStageTraining(Callback):
    def __init__(self, stage_epochs=(10, 20, 30)):
        self.stage_epochs = stage_epochs
        self.current_stage = 0

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch in self.stage_epochs:
            self.current_stage += 1
            print(f"Entering stage {self.current_stage}")
            # Reduce the learning rate for the new stage
            for optimizer in trainer.optimizers:
                for param_group in optimizer.param_groups:
                    param_group["lr"] *= 0.1
```
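### Config-Driven Callback Assembly

The callbacks list often varies between experiments, so a common pattern is to assemble it from configuration rather than hard-code it. A small sketch; the `build_callbacks` helper and its config keys are illustrative, not part of Lightning:

```python
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint


def build_callbacks(cfg):
    # Always checkpoint; enable the rest via config flags
    callbacks = [ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)]
    if cfg.get("early_stopping"):
        callbacks.append(
            EarlyStopping(monitor="val_loss", patience=cfg.get("patience", 10), mode="min")
        )
    if cfg.get("log_lr"):
        callbacks.append(LearningRateMonitor(logging_interval="step"))
    return callbacks


trainer = L.Trainer(callbacks=build_callbacks({"early_stopping": True, "log_lr": True}))
```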