# Trainer - Comprehensive Guide

## Overview

Once your PyTorch code is organized into a LightningModule, the Trainer automates the rest of the training workflow. It handles loop details, device management, callbacks, gradient operations, checkpointing, and distributed training.

## Core Purpose

The Trainer manages:

- Automatically enabling and disabling gradients
- Running the training, validation, and test dataloaders
- Calling callbacks at the appropriate times
- Placing batches on the correct device
- Orchestrating distributed training
- Progress bars and logging
- Checkpointing and early stopping

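To make the division of labor concrete, here is a minimal, self-contained sketch. The `LitClassifier` model and the random dataset are hypothetical stand-ins; the point is that the module defines *what* to compute, and the Trainer runs the loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import lightning as L

# Hypothetical module: a tiny classifier, just to make the example concrete.
class LitClassifier(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(32, 4)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.net(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Random tensors stand in for a real dataset.
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 32), torch.randint(0, 4, (256,))),
    batch_size=32,
)

# The Trainer supplies the loop, device placement, logging, and checkpointing.
trainer = L.Trainer(max_epochs=2, accelerator="auto", devices="auto")
trainer.fit(LitClassifier(), train_loader)
```
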
## Main Methods

### `fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)`

Runs the full training routine, including optional validation.

**Parameters:**

- `model` - LightningModule to train
- `train_dataloaders` - Training DataLoader(s)
- `val_dataloaders` - Optional validation DataLoader(s)
- `datamodule` - Optional LightningDataModule (replaces the dataloaders)
- `ckpt_path` - Optional path to a checkpoint to resume from

**Examples:**

```python
# With DataLoaders
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# With a DataModule
trainer.fit(model, datamodule=dm)

# Continue training from a checkpoint
trainer.fit(model, train_loader, ckpt_path="checkpoint.ckpt")
```

### `validate(model=None, dataloaders=None, datamodule=None)`

Runs the validation loop without training.

**Example:**

```python
trainer = L.Trainer()
trainer.validate(model, val_loader)
```

### `test(model=None, dataloaders=None, datamodule=None)`

Runs the test loop. To avoid overfitting to the test set, run it sparingly, ideally once, just before publishing results.

**Example:**

```python
trainer = L.Trainer()
trainer.test(model, test_loader)
```

### `predict(model=None, dataloaders=None, datamodule=None)`

Runs inference on data and returns the predictions as a list with one entry per batch, collected from `predict_step` (which calls `forward` by default).

**Example:**

```python
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
```

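To control what `predict` returns per batch, override `predict_step`. A hypothetical sketch (the model architecture and batch structure are illustrative, not from this guide):

```python
import torch.nn as nn
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(32, 4)  # hypothetical architecture

    def forward(self, x):
        return self.net(x)

    def predict_step(self, batch, batch_idx):
        # Return class indices instead of raw logits.
        x, _ = batch
        return self(x).argmax(dim=-1)
```
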
## Essential Parameters

### Training Duration

#### `max_epochs` (int)

Maximum number of epochs to train. Default: None; if neither `max_epochs` nor `max_steps` is set, training runs for 1000 epochs. Set `max_epochs=-1` for infinite training.

```python
trainer = L.Trainer(max_epochs=100)
```

#### `min_epochs` (int)

Minimum number of epochs to train. Default: None

```python
trainer = L.Trainer(min_epochs=10, max_epochs=100)
```

#### `max_steps` (int)

Maximum number of optimizer steps. Default: -1 (disabled). When both are set, training stops at whichever of `max_steps` or `max_epochs` is reached first.

```python
trainer = L.Trainer(max_steps=10000)
```

#### `max_time` (str or dict)

Maximum training time. Useful on time-limited clusters.

```python
# String format (DD:HH:MM:SS)
trainer = L.Trainer(max_time="00:12:00:00")  # 12 hours

# Dictionary format
trainer = L.Trainer(max_time={"days": 1, "hours": 6})
```

### Hardware Configuration

#### `accelerator` (str or Accelerator)

Hardware to use: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto". Default: "auto"

```python
trainer = L.Trainer(accelerator="gpu")
trainer = L.Trainer(accelerator="auto")  # Auto-detect available hardware
```

#### `devices` (int, list, or str)

Number of devices, or a list of device indices, to use.

```python
# Use 2 GPUs
trainer = L.Trainer(devices=2, accelerator="gpu")

# Use specific GPUs
trainer = L.Trainer(devices=[0, 2], accelerator="gpu")

# Use all available devices
trainer = L.Trainer(devices="auto", accelerator="gpu")

# CPU with 4 processes
trainer = L.Trainer(devices=4, accelerator="cpu")
```

#### `strategy` (str or Strategy)

Distributed training strategy: "ddp", "ddp_spawn", "fsdp", "deepspeed", etc. Default: "auto"

```python
# Distributed Data Parallel
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)

# Fully Sharded Data Parallel
trainer = L.Trainer(strategy="fsdp", accelerator="gpu", devices=4)

# DeepSpeed
trainer = L.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=4)
```

#### `precision` (str or int)

Floating-point precision: "32-true", "16-mixed", "bf16-mixed", "64-true", etc.

```python
# Mixed precision (FP16)
trainer = L.Trainer(precision="16-mixed")

# BFloat16 mixed precision
trainer = L.Trainer(precision="bf16-mixed")

# Full precision
trainer = L.Trainer(precision="32-true")
```

### Optimization Configuration

#### `gradient_clip_val` (float)

Gradient clipping value. Default: None (no clipping)

```python
# Clip gradients by norm
trainer = L.Trainer(gradient_clip_val=0.5)
```

#### `gradient_clip_algorithm` (str)

Gradient clipping algorithm: "norm" or "value". Default: "norm"

```python
trainer = L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
```

#### `accumulate_grad_batches` (int)

Accumulate gradients over N batches before each optimizer step. With `accumulate_grad_batches=4`, a per-device batch size of 32, and 2 devices, the effective batch size is 4 × 32 × 2 = 256. In Lightning 2.x the Trainer argument accepts an int; to change the accumulation factor per epoch, use the `GradientAccumulationScheduler` callback.

```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate over 4 batches
trainer = L.Trainer(accumulate_grad_batches=4)

# Different accumulation per epoch: 4 from epoch 0, 2 from epoch 5, 1 from epoch 10
scheduler = GradientAccumulationScheduler(scheduling={0: 4, 5: 2, 10: 1})
trainer = L.Trainer(callbacks=[scheduler])
```

### Validation Configuration

#### `check_val_every_n_epoch` (int)

Run validation every N epochs. Default: 1

```python
trainer = L.Trainer(check_val_every_n_epoch=10)
```

#### `val_check_interval` (int or float)

How often to run validation within a training epoch. A float is a fraction of the epoch; an int is a number of training batches.

```python
# Validate every 25% of the training epoch
trainer = L.Trainer(val_check_interval=0.25)

# Validate every 100 training batches
trainer = L.Trainer(val_check_interval=100)
```

#### `limit_val_batches` (int or float)

Limit the number of validation batches. A float is a fraction of the validation set; an int is a batch count.

```python
# Use only 10% of the validation data
trainer = L.Trainer(limit_val_batches=0.1)

# Use only 50 validation batches
trainer = L.Trainer(limit_val_batches=50)

# Disable validation
trainer = L.Trainer(limit_val_batches=0)
```

#### `num_sanity_val_steps` (int)

Number of validation batches to run before training starts, to catch validation bugs early. Default: 2

```python
# Skip the sanity check
trainer = L.Trainer(num_sanity_val_steps=0)

# Run 5 sanity validation batches
trainer = L.Trainer(num_sanity_val_steps=5)
```

### Logging and Progress

#### `logger` (Logger or list or bool)

Logger(s) to use for experiment tracking.

```python
from lightning.pytorch import loggers as pl_loggers

# TensorBoard logger
tb_logger = pl_loggers.TensorBoardLogger("logs/")
trainer = L.Trainer(logger=tb_logger)

# Multiple loggers
wandb_logger = pl_loggers.WandbLogger(project="my-project")
trainer = L.Trainer(logger=[tb_logger, wandb_logger])

# Disable logging
trainer = L.Trainer(logger=False)
```

#### `log_every_n_steps` (int)

How often to write metrics recorded with `self.log` to the logger, in training steps. Default: 50

```python
trainer = L.Trainer(log_every_n_steps=10)
```

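`log_every_n_steps` only controls how often values reach the logger; the values themselves come from `self.log` calls in your module. A hypothetical sketch (`self.net` is an assumed submodule, not defined here):

```python
import torch.nn.functional as F
import lightning as L

class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.net(x), y)  # `self.net` assumed defined in __init__
        # on_step=True values reach the logger every `log_every_n_steps` steps;
        # on_epoch=True also logs the epoch-level aggregate.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
```
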
#### `enable_progress_bar` (bool)

Show a progress bar. Default: True

```python
trainer = L.Trainer(enable_progress_bar=False)
```

### Callbacks

#### `callbacks` (list)

List of callbacks to use during training.

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=3,
    mode="min",
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
)

trainer = L.Trainer(callbacks=[checkpoint_callback, early_stop_callback])
```

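Custom callbacks subclass `Callback` and override only the hooks they care about; the Trainer invokes each hook at the corresponding point in the loop. A minimal hypothetical sketch:

```python
import lightning as L
from lightning.pytorch.callbacks import Callback

class PrintEpochEnd(Callback):
    """Hypothetical callback: report logged metrics after each training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        # trainer.callback_metrics holds the latest values recorded via self.log(...)
        metrics = {name: float(value) for name, value in trainer.callback_metrics.items()}
        print(f"epoch {trainer.current_epoch} finished: {metrics}")

trainer = L.Trainer(callbacks=[PrintEpochEnd()])
```
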
### Checkpointing

#### `default_root_dir` (str)

Default directory for logs and checkpoints. Default: current working directory

```python
trainer = L.Trainer(default_root_dir="./experiments/")
```

#### `enable_checkpointing` (bool)

Enable automatic checkpointing. Default: True

```python
trainer = L.Trainer(enable_checkpointing=True)
```

### Debugging

#### `fast_dev_run` (bool or int)

Run a single batch (or N batches) through train/val/test to smoke-test the code. Loggers and checkpointing are disabled in this mode.

```python
# Run 1 batch of train/val/test
trainer = L.Trainer(fast_dev_run=True)

# Run 5 batches of train/val/test
trainer = L.Trainer(fast_dev_run=5)
```

#### `limit_train_batches` (int or float)

Limit the number of training batches. A float is a fraction of the training set; an int is a batch count.

```python
# Use only 25% of the training data
trainer = L.Trainer(limit_train_batches=0.25)

# Use only 100 training batches
trainer = L.Trainer(limit_train_batches=100)
```

#### `limit_test_batches` (int or float)

Limit the number of test batches.

```python
trainer = L.Trainer(limit_test_batches=0.5)
```

#### `overfit_batches` (int or float)

Overfit on a small subset of the data for debugging; if the model cannot drive the loss to near zero on a handful of batches, something is wrong.

```python
# Overfit on 10 batches
trainer = L.Trainer(overfit_batches=10)

# Overfit on 1% of the data
trainer = L.Trainer(overfit_batches=0.01)
```

#### `detect_anomaly` (bool)

Enable PyTorch autograd anomaly detection for debugging NaNs. Default: False

```python
trainer = L.Trainer(detect_anomaly=True)
```

### Reproducibility

#### `deterministic` (bool or str)

Control deterministic behavior. Default: False

```python
import lightning as L

# Seed everything
L.seed_everything(42, workers=True)

# Fully deterministic (may impact performance)
trainer = L.Trainer(deterministic=True)

# Warn if non-deterministic operations are detected
trainer = L.Trainer(deterministic="warn")
```

#### `benchmark` (bool)

Set `torch.backends.cudnn.benchmark` to speed up training when input sizes do not vary.

```python
trainer = L.Trainer(benchmark=True)
```

### Miscellaneous

#### `enable_model_summary` (bool)

Print a model summary before training. Default: True

```python
trainer = L.Trainer(enable_model_summary=False)
```

#### `inference_mode` (bool)

Use `torch.inference_mode()` instead of `torch.no_grad()` during validation, testing, and prediction. Default: True

```python
trainer = L.Trainer(inference_mode=True)
```

#### `profiler` (str or Profiler)

Profile the code to find performance bottlenecks. Options: "simple", "advanced", or a custom Profiler.

```python
# Simple profiler
trainer = L.Trainer(profiler="simple")

# Advanced profiler
trainer = L.Trainer(profiler="advanced")
```

## Common Configurations

### Basic Training

```python
trainer = L.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices="auto",
)
trainer.fit(model, train_loader, val_loader)
```

### Multi-GPU Training

```python
trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="16-mixed",
)
trainer.fit(model, datamodule=dm)
```

### Production Training with Checkpoints

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
)

early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
)

lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    precision="16-mixed",
    callbacks=[checkpoint_callback, early_stop, lr_monitor],
    log_every_n_steps=10,
    gradient_clip_val=1.0,
)

trainer.fit(model, datamodule=dm)
```

### Debug Configuration

```python
trainer = L.Trainer(
    fast_dev_run=True,         # Run 1 batch of train/val/test
    accelerator="cpu",
    enable_progress_bar=True,
    log_every_n_steps=1,
    detect_anomaly=True,
)
trainer.fit(model, train_loader, val_loader)
```

### Research Configuration (Reproducibility)

```python
import lightning as L

L.seed_everything(42, workers=True)

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    deterministic=True,
    benchmark=False,
    precision="32-true",
)
trainer.fit(model, datamodule=dm)
```

### Time-Limited Training (Cluster)

```python
from lightning.pytorch.callbacks import ModelCheckpoint

trainer = L.Trainer(
    max_time={"hours": 23, "minutes": 30},  # Stay under the SLURM time limit
    max_epochs=1000,
    callbacks=[ModelCheckpoint(save_last=True)],
)
trainer.fit(model, datamodule=dm)

# Resume from the last checkpoint on the next job
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```

### Large Model Training (FSDP)

```python
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=8,
    strategy=FSDPStrategy(
        # Wrap and checkpoint activations at the transformer-layer level
        activation_checkpointing_policy={nn.TransformerEncoderLayer},
        cpu_offload=False,
    ),
    precision="bf16-mixed",
    accumulate_grad_batches=4,
)
trainer.fit(model, datamodule=dm)
```

## Resuming Training

### From Checkpoint

```python
# Resume from a specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")

# Resume from the last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```

### Finding Last Checkpoint

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_last=True)
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)

# Get the path to the last checkpoint
last_checkpoint = checkpoint_callback.last_model_path
```

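Checkpoints can also be restored outside the Trainer, for example to run inference later. `load_from_checkpoint` is the standard LightningModule classmethod; `MyModel` stands in for your own subclass:

```python
# Restore the model weights and hyperparameters from a saved checkpoint.
model = MyModel.load_from_checkpoint("checkpoints/last.ckpt")
model.eval()
```
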
## Accessing Trainer from LightningModule

Inside a LightningModule, access the Trainer via `self.trainer`:

```python
from lightning.pytorch.callbacks import ModelCheckpoint

class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # Access trainer properties
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        max_epochs = self.trainer.max_epochs

        # Access callbacks
        for callback in self.trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                print(f"Best model: {callback.best_model_path}")

        # Access the logger directly
        self.trainer.logger.log_metrics({"custom_epoch": float(current_epoch)})
```

## Trainer Attributes

| Attribute | Description |
|-----------|-------------|
| `trainer.current_epoch` | Current epoch (0-indexed) |
| `trainer.global_step` | Total optimizer steps taken so far |
| `trainer.max_epochs` | Maximum epochs configured |
| `trainer.max_steps` | Maximum steps configured |
| `trainer.callbacks` | List of callbacks |
| `trainer.logger` | Logger instance |
| `trainer.strategy` | Training strategy |
| `trainer.estimated_stepping_batches` | Estimated total optimizer steps for the run |

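`trainer.estimated_stepping_batches` is particularly handy for schedulers that need the total step count up front. A sketch of the common pattern (the optimizer choice and learning rate are illustrative):

```python
import torch
import lightning as L

class LitModel(L.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # The trainer estimates the total number of optimizer steps,
        # accounting for max_epochs, accumulation, and dataloader length.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```
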
## Best Practices

### 1. Start with Fast Dev Run

Always test with `fast_dev_run=True` before launching a full training run:

```python
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
```

### 2. Use Gradient Clipping

Prevent exploding gradients:

```python
trainer = L.Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
```

### 3. Enable Mixed Precision

Speed up training on modern GPUs:

```python
trainer = L.Trainer(precision="16-mixed")  # or "bf16-mixed" on A100 and newer
```

### 4. Save Checkpoints Properly

Always save the last checkpoint so training can resume:

```python
checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    save_last=True,
    monitor="val_loss",
)
```

### 5. Monitor the Learning Rate

Track LR changes with LearningRateMonitor:

```python
from lightning.pytorch.callbacks import LearningRateMonitor

trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])
```

### 6. Use a DataModule for Reproducibility

Encapsulate data logic in a DataModule:

```python
# Better than passing DataLoaders directly
trainer.fit(model, datamodule=dm)
```

### 7. Set Deterministic for Research

Ensure reproducibility for publications:

```python
L.seed_everything(42, workers=True)
trainer = L.Trainer(deterministic=True)
```