# Callbacks - Comprehensive Guide
## Overview
Callbacks let you add self-contained, reusable logic to training without cluttering your LightningModule research code. Each callback runs custom logic at specific hooks in the training lifecycle.
## Architecture
Lightning organizes training logic across three components:
- **Trainer** - Engineering infrastructure
- **LightningModule** - Research code
- **Callbacks** - Non-essential functionality (monitoring, checkpointing, custom behaviors)
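The division of labor above can be pictured with a toy dispatcher. This is a minimal sketch in plain Python, not Lightning's implementation: `TinyTrainer`, `TinyCallback`, and `Recorder` are hypothetical stand-ins, and only two hooks are modelled.

```python
class TinyCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_train_start(self, trainer, module):
        pass

    def on_train_end(self, trainer, module):
        pass


class Recorder(TinyCallback):
    """Records which hooks fired, without touching the training logic."""

    def __init__(self):
        self.events = []

    def on_train_start(self, trainer, module):
        self.events.append("train_start")

    def on_train_end(self, trainer, module):
        self.events.append("train_end")


class TinyTrainer:
    """Hypothetical stand-in for the engineering loop that owns the hook calls."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _call_hook(self, name, module):
        # Invoke the same hook on every registered callback, in list order.
        for callback in self.callbacks:
            getattr(callback, name)(self, module)

    def fit(self, module, steps=3):
        self._call_hook("on_train_start", module)
        for step in range(steps):
            module(step)  # the "research code" runs here
        self._call_hook("on_train_end", module)


recorder = Recorder()
seen_steps = []
TinyTrainer([recorder]).fit(seen_steps.append)
print(recorder.events)
```

The loop itself never changes when a callback is added or removed, which is the point of keeping non-essential behavior out of the research code.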
## Creating Custom Callbacks
Basic structure:
```python
import lightning as L
from lightning.pytorch.callbacks import Callback

class MyCustomCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")

# Use with Trainer
trainer = L.Trainer(callbacks=[MyCustomCallback()])
```
## Built-in Callbacks
### ModelCheckpoint
Save models based on monitored metrics.
**Key Parameters:**
- `dirpath` - Directory to save checkpoints
- `filename` - Checkpoint filename pattern
- `monitor` - Metric to monitor
- `mode` - "min" or "max" for monitored metric
- `save_top_k` - Number of best models to keep
- `save_last` - Save last epoch checkpoint
- `every_n_epochs` - Save every N epochs
- `save_on_train_epoch_end` - Save at train epoch end vs validation end
**Examples:**
```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Save top 3 models based on validation loss
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="model-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
)

# Save every 10 epochs
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="model-{epoch:02d}",
    every_n_epochs=10,
    save_top_k=-1,  # Save all
)

# Save best model based on accuracy
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="best-model",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)

trainer = L.Trainer(callbacks=[checkpoint_callback])
```
**Accessing Saved Checkpoints:**
```python
# Get best model path
best_model_path = checkpoint_callback.best_model_path
# Get last checkpoint path
last_checkpoint = checkpoint_callback.last_model_path
# Dict mapping each retained top-k checkpoint path to its monitored score
top_k_checkpoints = checkpoint_callback.best_k_models
```
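`best_k_models` is keyed by checkpoint path, with the monitored score as the value. A simplified picture of how `mode` and `save_top_k` interact is sketched below, assuming a hypothetical `update_top_k` helper rather than Lightning's actual bookkeeping (and ignoring `save_top_k=-1`, which keeps every checkpoint). The scores are made-up values for illustration.

```python
def update_top_k(best_k, path, score, k=3, mode="min"):
    """Return a new {path: score} dict keeping only the k best entries."""
    candidates = dict(best_k)
    candidates[path] = score
    # For "min" lower is better; for "max" higher is better.
    ranked = sorted(
        candidates.items(), key=lambda item: item[1], reverse=(mode == "max")
    )
    return dict(ranked[:k])


best_k = {}
# Illustrative validation losses, one per epoch.
for epoch, val_loss in enumerate([0.90, 0.70, 0.80, 0.60, 0.95]):
    best_k = update_top_k(best_k, f"model-{epoch:02d}.ckpt", val_loss, k=3, mode="min")

# With mode="min", the best checkpoint is the one with the lowest score.
best_path = min(best_k, key=best_k.get)
print(best_k)
print(best_path)
```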
### EarlyStopping
Stop training when a monitored metric stops improving.
**Key Parameters:**
- `monitor` - Metric to monitor
- `patience` - Number of validation checks with no improvement after which training stops (this equals the number of epochs only when validation runs once per epoch)
- `mode` - "min" or "max" for monitored metric
- `min_delta` - Minimum change to qualify as an improvement
- `verbose` - Print messages
- `strict` - Crash if monitored metric not found
**Examples:**
```python
from lightning.pytorch.callbacks import EarlyStopping

# Stop when validation loss stops improving
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
    verbose=True,
)

# Stop when accuracy plateaus
early_stop = EarlyStopping(
    monitor="val_acc",
    patience=5,
    mode="max",
    min_delta=0.001,  # Must improve by at least 0.001
)

trainer = L.Trainer(callbacks=[early_stop])
```
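How `patience`, `min_delta`, and `mode` combine is easier to see in a stripped-down form. The sketch below is plain Python under simplifying assumptions; `PatienceTracker` is a hypothetical helper, not Lightning's `EarlyStopping`, and the loss history is invented for illustration.

```python
class PatienceTracker:
    """Hypothetical helper mirroring the monitor/patience/min_delta idea."""

    def __init__(self, patience=3, min_delta=0.0, mode="min"):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.wait = 0

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, value):
        """Feed one validation result; return True once patience is exhausted."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


tracker = PatienceTracker(patience=2, min_delta=0.01, mode="min")
history = [0.50, 0.40, 0.395, 0.41, 0.39]  # illustrative validation losses
stopped_at = None
for check, val_loss in enumerate(history):
    if tracker.should_stop(val_loss):
        stopped_at = check
        break
print(stopped_at, tracker.best)
```

In this run the drop from 0.40 to 0.395 is smaller than `min_delta`, so it counts as a check without improvement and contributes to the patience counter.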
### LearningRateMonitor
Track learning rate changes from schedulers.
**Key Parameters:**
- `logging_interval` - When to log: "step" or "epoch"
- `log_momentum` - Also log momentum values
**Example:**
```python
from lightning.pytorch.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = L.Trainer(callbacks=[lr_monitor])
# Logs learning rate automatically as "lr-{optimizer_name}"
```
### DeviceStatsMonitor
Log device performance metrics (GPU/CPU/TPU).
**Key Parameters:**
- `cpu_stats` - Log CPU stats
**Example:**
```python
from lightning.pytorch.callbacks import DeviceStatsMonitor
device_stats = DeviceStatsMonitor(cpu_stats=True)
trainer = L.Trainer(callbacks=[device_stats])
# Logs accelerator statistics (for example GPU memory usage); the exact metric names depend on the device
```
### ModelSummary / RichModelSummary
Display model architecture and parameter count.
**Example:**
```python
from lightning.pytorch.callbacks import ModelSummary, RichModelSummary
# Basic summary
summary = ModelSummary(max_depth=2)
# Rich formatted summary (prettier)
rich_summary = RichModelSummary(max_depth=3)
trainer = L.Trainer(callbacks=[rich_summary])
```
### Timer
Track and limit training duration.
**Key Parameters:**
- `duration` - Maximum training time (a `timedelta`, a dict of `timedelta` keyword arguments, or a `"DD:HH:MM:SS"` string)
- `interval` - When to check the time limit: "step" or "epoch"
**Example:**
```python
from lightning.pytorch.callbacks import Timer
from datetime import timedelta
# Limit training to 1 hour
timer = Timer(duration=timedelta(hours=1))
# Or using dict
timer = Timer(duration={"hours": 23, "minutes": 30})
trainer = L.Trainer(callbacks=[timer])
```
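The dict form of `duration` carries the same keyword arguments as `timedelta`. As a rough illustration in plain Python (the `out_of_time` helper is hypothetical and is not part of Lightning), a time budget check reduces to a comparison in seconds:

```python
from datetime import timedelta

# The dict from the example above, expanded into a timedelta.
budget = timedelta(**{"hours": 23, "minutes": 30})


def out_of_time(elapsed_seconds, budget):
    """Hypothetical helper: True once the elapsed time reaches the budget."""
    return elapsed_seconds >= budget.total_seconds()


print(budget.total_seconds())
print(out_of_time(3600, budget))
```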
### BatchSizeFinder
Find the largest batch size that fits in memory by increasing it until an out-of-memory error occurs.
**Example:**
```python
from lightning.pytorch.callbacks import BatchSizeFinder
batch_finder = BatchSizeFinder(mode="power", steps_per_trial=3)
trainer = L.Trainer(callbacks=[batch_finder])
trainer.fit(model, datamodule=dm)
# The batch size found is written to the `batch_size` attribute of the model or datamodule
```
### GradientAccumulationScheduler
Schedule gradient accumulation steps dynamically.
**Example:**
```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler
# Accumulate 4 batches for first 5 epochs, then 2 batches
accumulator = GradientAccumulationScheduler(scheduling={0: 4, 5: 2})
trainer = L.Trainer(callbacks=[accumulator])
```
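The `scheduling` dict maps a zero-indexed starting epoch to the accumulation factor that applies from that epoch onward. A minimal sketch of that lookup in plain Python follows; `accumulation_for_epoch` is a hypothetical helper, and the fallback of 1 before the first scheduled epoch is an assumption of this sketch.

```python
from bisect import bisect_right


def accumulation_for_epoch(scheduling, epoch):
    """Return the accumulation factor active at a zero-indexed epoch."""
    starts = sorted(scheduling)
    # Find the last scheduled epoch that is <= the current epoch.
    index = bisect_right(starts, epoch) - 1
    if index < 0:
        return 1  # assumption: no accumulation before the first scheduled epoch
    return scheduling[starts[index]]


scheduling = {0: 4, 5: 2}
factors = [accumulation_for_epoch(scheduling, epoch) for epoch in range(8)]
print(factors)
```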
### StochasticWeightAveraging (SWA)
Apply stochastic weight averaging for better generalization.
**Example:**
```python
from lightning.pytorch.callbacks import StochasticWeightAveraging
swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8)
trainer = L.Trainer(callbacks=[swa])
```
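The averaging idea can be illustrated without any framework. The following is a sketch of an equal-weight running mean over weight snapshots, which is the core idea behind SWA; it is not Lightning's implementation, and `update_average` and the toy snapshots are hypothetical.

```python
def update_average(averaged, current, num_averaged):
    """Fold one more snapshot into a running mean of `num_averaged` snapshots."""
    return [a + (c - a) / (num_averaged + 1) for a, c in zip(averaged, current)]


# Toy "weights" captured at the end of three epochs (illustrative values).
snapshots = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]

averaged = list(snapshots[0])
for num_averaged, weights in enumerate(snapshots[1:], start=1):
    averaged = update_average(averaged, weights, num_averaged)

print(averaged)
```

The running update avoids storing every snapshot: only the current average and a count are needed.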
## Custom Callback Examples
### Simple Logging Callback
```python
class MetricsLogger(Callback):
    def __init__(self):
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Access logged metrics
        metrics = trainer.callback_metrics
        self.metrics.append(dict(metrics))
        print(f"Validation metrics: {metrics}")
```
### Gradient Monitoring Callback
```python
class GradientMonitor(Callback):
    def on_after_backward(self, trainer, pl_module):
        # Log gradient norms
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                pl_module.log(f"grad_norm/{name}", grad_norm)
```
### Custom Checkpointing Callback
```python
class CustomCheckpoint(Callback):
    def __init__(self, save_dir):
        self.save_dir = save_dir

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % 5 == 0:  # Save every 5 epochs
            filepath = f"{self.save_dir}/custom-{epoch}.ckpt"
            trainer.save_checkpoint(filepath)
            print(f"Saved checkpoint: {filepath}")
```
### Model Freezing Callback
```python
class FreezeUnfreeze(Callback):
    def __init__(self, freeze_until_epoch=10):
        self.freeze_until_epoch = freeze_until_epoch

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch < self.freeze_until_epoch:
            # Freeze backbone
            for param in pl_module.backbone.parameters():
                param.requires_grad = False
        else:
            # Unfreeze backbone
            for param in pl_module.backbone.parameters():
                param.requires_grad = True
```
### Learning Rate Finder Callback
```python
class LRFinder(Callback):
    def __init__(self, min_lr=1e-5, max_lr=1e-1, num_steps=100):
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.num_steps = num_steps
        self.lrs = []
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx >= self.num_steps:
            trainer.should_stop = True
            return
        # Exponential LR schedule
        lr = self.min_lr * (self.max_lr / self.min_lr) ** (batch_idx / self.num_steps)
        optimizer = trainer.optimizers[0]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        self.lrs.append(lr)
        self.losses.append(outputs['loss'].item())

    def on_train_end(self, trainer, pl_module):
        # Plot LR vs Loss
        import matplotlib.pyplot as plt
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.savefig('lr_finder.png')
```
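The exponential schedule used in `on_train_batch_end` can be checked in isolation. The sketch below is plain Python with a hypothetical `exponential_lrs` helper; it reuses the formula from the callback and shows that each step multiplies the learning rate by a constant factor, and that the sweep stops just short of `max_lr` because `batch_idx / num_steps` never reaches 1.

```python
def exponential_lrs(min_lr=1e-5, max_lr=1e-1, num_steps=100):
    """Learning rates visited by an exponential sweep from min_lr toward max_lr."""
    ratio = max_lr / min_lr
    return [min_lr * ratio ** (step / num_steps) for step in range(num_steps)]


# Four steps across four decades: each step is roughly 10x the previous one.
lrs = exponential_lrs(min_lr=1e-5, max_lr=1e-1, num_steps=4)
for step, lr in enumerate(lrs):
    print(step, f"{lr:.1e}")
```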
### Prediction Saver Callback
```python
import torch

class PredictionSaver(Callback):
    def __init__(self, save_path):
        self.save_path = save_path
        self.predictions = []

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.predictions.append(outputs)

    def on_predict_end(self, trainer, pl_module):
        # Save all predictions
        torch.save(self.predictions, self.save_path)
        print(f"Predictions saved to {self.save_path}")
```
## Available Hooks
### Setup and Teardown
- `setup(trainer, pl_module, stage)` - Called at the beginning of fit, validate, test, or predict
- `teardown(trainer, pl_module, stage)` - Called at the end of fit, validate, test, or predict
### Training Lifecycle
- `on_fit_start(trainer, pl_module)` - Called at start of fit
- `on_fit_end(trainer, pl_module)` - Called at end of fit
- `on_train_start(trainer, pl_module)` - Called at start of training
- `on_train_end(trainer, pl_module)` - Called at end of training
### Epoch Boundaries
- `on_train_epoch_start(trainer, pl_module)` - Called at start of training epoch
- `on_train_epoch_end(trainer, pl_module)` - Called at end of training epoch
- `on_validation_epoch_start(trainer, pl_module)` - Called at start of validation
- `on_validation_epoch_end(trainer, pl_module)` - Called at end of validation
- `on_test_epoch_start(trainer, pl_module)` - Called at start of test
- `on_test_epoch_end(trainer, pl_module)` - Called at end of test
### Batch Boundaries
- `on_train_batch_start(trainer, pl_module, batch, batch_idx)` - Before training batch
- `on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)` - After training batch
- `on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)` - Before validation batch
- `on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)` - After validation batch
### Gradient Events
- `on_before_backward(trainer, pl_module, loss)` - Before loss.backward()
- `on_after_backward(trainer, pl_module)` - After loss.backward()
- `on_before_optimizer_step(trainer, pl_module, optimizer)` - Before optimizer.step()
### Checkpoint Events
- `on_save_checkpoint(trainer, pl_module, checkpoint)` - When saving checkpoint
- `on_load_checkpoint(trainer, pl_module, checkpoint)` - When loading checkpoint
### Exception Handling
- `on_exception(trainer, pl_module, exception)` - When exception occurs
## State Management
For callbacks requiring persistence across checkpoints:
```python
class StatefulCallback(Callback):
    def __init__(self):
        self.counter = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]

    @property
    def state_key(self):
        # Unique identifier for this callback
        return "my_stateful_callback"
```
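The save and restore round trip can be pictured without Lightning. In this sketch `CounterState` is a hypothetical plain-Python stand-in that exposes the same two methods, and the checkpoint is modelled as a nested dict keyed by the callback's identifier, which is an assumption made for illustration.

```python
class CounterState:
    """Hypothetical stand-in exposing state_dict / load_state_dict."""

    def __init__(self):
        self.counter = 0

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]


original = CounterState()
original.counter = 42

# Assumed checkpoint layout: callback states stored under their identifier.
checkpoint = {"callbacks": {"my_stateful_callback": original.state_dict()}}

# A fresh instance picks up where the original left off.
restored = CounterState()
restored.load_state_dict(checkpoint["callbacks"]["my_stateful_callback"])
print(restored.counter)
```

A stable, unique identifier matters here: if two callbacks shared the same key, one saved state would overwrite the other.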
## Best Practices
### 1. Keep Callbacks Isolated
Each callback should be self-contained and independent:
```python
# Good: Self-contained
class MyCallback(Callback):
    def __init__(self):
        self.data = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.data.append(outputs['loss'].item())

# Bad: Depends on external state
global_data = []

class BadCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        global_data.append(outputs['loss'].item())  # External dependency
```
### 2. Avoid Inter-Callback Dependencies
Callbacks should not depend on other callbacks:
```python
# Bad: Callback B depends on Callback A
class CallbackA(Callback):
    def __init__(self):
        self.value = 0

class CallbackB(Callback):
    def __init__(self, callback_a):
        self.callback_a = callback_a  # Tight coupling

# Good: Independent callbacks
class CallbackA(Callback):
    def __init__(self):
        self.value = 0

class CallbackB(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Access trainer state instead
        value = trainer.callback_metrics.get('metric')
```
### 3. Never Manually Invoke Callback Methods
Let Lightning call callbacks automatically:
```python
# Bad: Manual invocation
callback = MyCallback()
callback.on_train_start(trainer, model) # Don't do this
# Good: Let Trainer handle it
trainer = L.Trainer(callbacks=[MyCallback()])
```
### 4. Design for Any Execution Order
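Lightning runs callbacks in the order they are passed to the Trainer (checkpoint callbacks are moved to the end), but a callback should not rely on a specific position in that order: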
```python
# Good: Order-independent
class GoodCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Use trainer state, not other callbacks
        metrics = trainer.callback_metrics
        print(f"Epoch {trainer.current_epoch} metrics: {dict(metrics)}")
```
### 5. Use Callbacks for Non-Essential Logic
Keep core research code in the LightningModule and use callbacks for auxiliary functionality:
```python
# Good separation
class MyModel(L.LightningModule):
    # Core research logic here
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper defined on the model
        return loss

# Non-essential monitoring in callback
class MonitorCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        # Monitoring logic
        pass
```
## Common Patterns
### Combining Multiple Callbacks
```python
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    DeviceStatsMonitor,
)

callbacks = [
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3),
    EarlyStopping(monitor="val_loss", patience=10, mode="min"),
    LearningRateMonitor(logging_interval="step"),
    DeviceStatsMonitor(),
]

trainer = L.Trainer(callbacks=callbacks)
```
### Conditional Callback Activation
```python
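class ConditionalCallback(Callback):
    def __init__(self, activate_after_epoch=10):
        self.activate_after_epoch = activate_after_epoch

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch >= self.activate_after_epoch:
            # Only active after specified epoch
            self.do_something(trainer, pl_module)

    def do_something(self, trainer, pl_module):
        # Placeholder for the behavior to enable (replace with real logic)
        print(f"Callback active at epoch {trainer.current_epoch}")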
```
### Multi-Stage Training Callback
```python
class MultiStageTraining(Callback):
    def __init__(self, stage_epochs=(10, 20, 30)):
        # Tuple default avoids sharing one mutable list across instances
        self.stage_epochs = tuple(stage_epochs)
        self.current_stage = 0

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch in self.stage_epochs:
            self.current_stage += 1
            print(f"Entering stage {self.current_stage}")
            # Adjust learning rate for new stage
            for optimizer in trainer.optimizers:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
```