Trainer - Comprehensive Guide
Overview
The Trainer automates training workflows after organizing PyTorch code into a LightningModule. It handles loop details, device management, callbacks, gradient operations, checkpointing, and distributed training automatically.
Core Purpose
The Trainer manages:
- Automatically enabling/disabling gradients
- Running training, validation, and test dataloaders
- Calling callbacks at appropriate times
- Placing batches on correct devices
- Orchestrating distributed training
- Progress bars and logging
- Checkpointing and early stopping
Main Methods
fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None)
Runs the full training routine including optional validation.
Parameters:
- model - LightningModule to train
- train_dataloaders - Training DataLoader(s)
- val_dataloaders - Optional validation DataLoader(s)
- datamodule - Optional LightningDataModule (replaces dataloaders)
Examples:
# With DataLoaders
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# With DataModule
trainer.fit(model, datamodule=dm)
# Continue training from checkpoint
trainer.fit(model, train_loader, ckpt_path="checkpoint.ckpt")
validate(model=None, dataloaders=None, datamodule=None)
Run validation loop without training.
Example:
trainer = L.Trainer()
trainer.validate(model, val_loader)
test(model=None, dataloaders=None, datamodule=None)
Run the test loop. Use it only once, right before reporting final results, so the test set is never used for model selection.
Example:
trainer = L.Trainer()
trainer.test(model, test_loader)
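If a ModelCheckpoint callback was attached during fit(), the best checkpoint can be loaded for testing via the "best" shortcut. A short sketch, assuming fit() has already run and test_loader is an existing DataLoader:
# Load the best checkpoint tracked by ModelCheckpoint before testing
trainer.test(dataloaders=test_loader, ckpt_path="best")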
predict(model=None, dataloaders=None, datamodule=None)
Run inference on data and return predictions.
Example:
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
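predict() returns a list with one entry per batch (whatever the model's predict_step returns). A minimal sketch of collecting the results, assuming each entry is a tensor:
import torch
predictions = trainer.predict(model, predict_loader)
# One entry per batch; concatenate into a single tensor
all_preds = torch.cat(predictions, dim=0)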
Essential Parameters
Training Duration
max_epochs (int)
Maximum number of epochs to train. Default: None (falls back to 1000 when max_steps is also unset)
trainer = L.Trainer(max_epochs=100)
min_epochs (int)
Minimum number of epochs to train. Default: None
trainer = L.Trainer(min_epochs=10, max_epochs=100)
max_steps (int)
Maximum number of optimizer steps. Training stops when either max_steps or max_epochs is reached, whichever comes first. Default: -1 (disabled)
trainer = L.Trainer(max_steps=10000)
max_time (str or dict)
Maximum training time. Useful for time-limited clusters.
# String format
trainer = L.Trainer(max_time="00:12:00:00") # 12 hours
# Dictionary format
trainer = L.Trainer(max_time={"days": 1, "hours": 6})
Hardware Configuration
accelerator (str or Accelerator)
Hardware to use: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto". Default: "auto"
trainer = L.Trainer(accelerator="gpu")
trainer = L.Trainer(accelerator="auto") # Auto-detect available hardware
devices (int, list, or str)
Number or list of device indices to use.
# Use 2 GPUs
trainer = L.Trainer(devices=2, accelerator="gpu")
# Use specific GPUs
trainer = L.Trainer(devices=[0, 2], accelerator="gpu")
# Use all available devices
trainer = L.Trainer(devices="auto", accelerator="gpu")
# CPU with 4 cores
trainer = L.Trainer(devices=4, accelerator="cpu")
strategy (str or Strategy)
Distributed training strategy: "ddp", "ddp_spawn", "fsdp", "deepspeed", etc. Default: "auto"
# Distributed Data Parallel
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)
# Fully Sharded Data Parallel
trainer = L.Trainer(strategy="fsdp", accelerator="gpu", devices=4)
# DeepSpeed
trainer = L.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=4)
precision (str or int)
Floating point precision: "32-true", "16-mixed", "bf16-mixed", "64-true", etc.
# Mixed precision (FP16)
trainer = L.Trainer(precision="16-mixed")
# BFloat16 mixed precision
trainer = L.Trainer(precision="bf16-mixed")
# Full precision
trainer = L.Trainer(precision="32-true")
Optimization Configuration
gradient_clip_val (float)
Gradient clipping value. Default: None
# Clip gradients by norm
trainer = L.Trainer(gradient_clip_val=0.5)
gradient_clip_algorithm (str)
Gradient clipping algorithm: "norm" or "value". Default: "norm"
trainer = L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
accumulate_grad_batches (int)
Accumulate gradients over N batches before each optimizer step.
# Accumulate over 4 batches
trainer = L.Trainer(accumulate_grad_batches=4)
# Different accumulation per epoch (use the GradientAccumulationScheduler callback)
from lightning.pytorch.callbacks import GradientAccumulationScheduler
trainer = L.Trainer(callbacks=[GradientAccumulationScheduler(scheduling={0: 4, 5: 2, 10: 1})])
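With accumulation, the optimizer effectively sees per-device batch size x accumulation factor x number of devices samples per step. A quick illustration with hypothetical numbers:
# Hypothetical values for illustration only
per_device_batch_size = 32
accumulation = 4
num_devices = 2
effective_batch_size = per_device_batch_size * accumulation * num_devices  # 256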
Validation Configuration
check_val_every_n_epoch (int)
Run validation every N epochs. Default: 1
trainer = L.Trainer(check_val_every_n_epoch=10)
val_check_interval (int or float)
How often to check validation within a training epoch.
# Check validation every 0.25 of training epoch
trainer = L.Trainer(val_check_interval=0.25)
# Check validation every 100 training batches
trainer = L.Trainer(val_check_interval=100)
limit_val_batches (int or float)
Limit validation batches.
# Use only 10% of validation data
trainer = L.Trainer(limit_val_batches=0.1)
# Use only 50 validation batches
trainer = L.Trainer(limit_val_batches=50)
# Disable validation
trainer = L.Trainer(limit_val_batches=0)
num_sanity_val_steps (int)
Number of validation batches to run before training starts. Default: 2
# Skip sanity check
trainer = L.Trainer(num_sanity_val_steps=0)
# Run 5 sanity validation steps
trainer = L.Trainer(num_sanity_val_steps=5)
Logging and Progress
logger (Logger or list or bool)
Logger(s) to use for experiment tracking.
from lightning.pytorch import loggers as pl_loggers
# TensorBoard logger
tb_logger = pl_loggers.TensorBoardLogger("logs/")
trainer = L.Trainer(logger=tb_logger)
# Multiple loggers
wandb_logger = pl_loggers.WandbLogger(project="my-project")
trainer = L.Trainer(logger=[tb_logger, wandb_logger])
# Disable logging
trainer = L.Trainer(logger=False)
log_every_n_steps (int)
How often to log within training steps. Default: 50
trainer = L.Trainer(log_every_n_steps=10)
enable_progress_bar (bool)
Show progress bar. Default: True
trainer = L.Trainer(enable_progress_bar=False)
Callbacks
callbacks (list)
List of callbacks to use during training.
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
save_top_k=3,
mode="min"
)
early_stop_callback = EarlyStopping(
monitor="val_loss",
patience=5,
mode="min"
)
trainer = L.Trainer(callbacks=[checkpoint_callback, early_stop_callback])
Checkpointing
default_root_dir (str)
Default directory for logs and checkpoints. Default: current working directory
trainer = L.Trainer(default_root_dir="./experiments/")
enable_checkpointing (bool)
Enable automatic checkpointing. Default: True
trainer = L.Trainer(enable_checkpointing=True)
Debugging
fast_dev_run (bool or int)
Run a single batch (or N batches) through train/val/test for debugging.
# Run 1 batch of train/val/test
trainer = L.Trainer(fast_dev_run=True)
# Run 5 batches of train/val/test
trainer = L.Trainer(fast_dev_run=5)
limit_train_batches (int or float)
Limit training batches.
# Use only 25% of training data
trainer = L.Trainer(limit_train_batches=0.25)
# Use only 100 training batches
trainer = L.Trainer(limit_train_batches=100)
limit_test_batches (int or float)
Limit test batches.
trainer = L.Trainer(limit_test_batches=0.5)
overfit_batches (int or float)
Overfit on a subset of data for debugging.
# Overfit on 10 batches
trainer = L.Trainer(overfit_batches=10)
# Overfit on 1% of data
trainer = L.Trainer(overfit_batches=0.01)
detect_anomaly (bool)
Enable PyTorch anomaly detection for debugging NaNs. Default: False
trainer = L.Trainer(detect_anomaly=True)
Reproducibility
deterministic (bool or str)
Control deterministic behavior. Default: False
import lightning as L
# Seed everything
L.seed_everything(42, workers=True)
# Fully deterministic (may impact performance)
trainer = L.Trainer(deterministic=True)
# Warn if non-deterministic operations detected
trainer = L.Trainer(deterministic="warn")
benchmark (bool)
Enable cudnn benchmarking, which can speed up training when input sizes do not change. Default: None (the current torch.backends.cudnn.benchmark value is used)
trainer = L.Trainer(benchmark=True)
Miscellaneous
enable_model_summary (bool)
Print model summary before training. Default: True
trainer = L.Trainer(enable_model_summary=False)
inference_mode (bool)
Use torch.inference_mode() instead of torch.no_grad() for validation/test. Default: True
trainer = L.Trainer(inference_mode=True)
profiler (str or Profiler)
Profile code for performance optimization. Options: "simple", "advanced", or custom Profiler.
# Simple profiler
trainer = L.Trainer(profiler="simple")
# Advanced profiler
trainer = L.Trainer(profiler="advanced")
Common Configurations
Basic Training
trainer = L.Trainer(
max_epochs=100,
accelerator="auto",
devices="auto"
)
trainer.fit(model, train_loader, val_loader)
Multi-GPU Training
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=4,
strategy="ddp",
precision="16-mixed"
)
trainer.fit(model, datamodule=dm)
Production Training with Checkpoints
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="{epoch}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_top_k=3,
save_last=True
)
early_stop = EarlyStopping(
monitor="val_loss",
patience=10,
mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=2,
strategy="ddp",
precision="16-mixed",
callbacks=[checkpoint_callback, early_stop, lr_monitor],
log_every_n_steps=10,
gradient_clip_val=1.0
)
trainer.fit(model, datamodule=dm)
Debug Configuration
trainer = L.Trainer(
fast_dev_run=True, # Run 1 batch
accelerator="cpu",
enable_progress_bar=True,
log_every_n_steps=1,
detect_anomaly=True
)
trainer.fit(model, train_loader, val_loader)
Research Configuration (Reproducibility)
import lightning as L
L.seed_everything(42, workers=True)
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=1,
deterministic=True,
benchmark=False,
precision="32-true"
)
trainer.fit(model, datamodule=dm)
Time-Limited Training (Cluster)
from lightning.pytorch.callbacks import ModelCheckpoint
trainer = L.Trainer(
max_time={"hours": 23, "minutes": 30}, # SLURM time limit
max_epochs=1000,
callbacks=[ModelCheckpoint(save_last=True)]
)
trainer.fit(model, datamodule=dm)
# Resume from checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
Large Model Training (FSDP)
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=8,
strategy=FSDPStrategy(
activation_checkpointing_policy={nn.TransformerEncoderLayer},
cpu_offload=False
),
precision="bf16-mixed",
accumulate_grad_batches=4
)
trainer.fit(model, datamodule=dm)
Resuming Training
From Checkpoint
# Resume from specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")
# Resume from last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
Finding Last Checkpoint
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(save_last=True)
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)
# Get path to last checkpoint
last_checkpoint = checkpoint_callback.last_model_path
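That path can be fed straight back into fit() via ckpt_path to resume; a short sketch reusing the names above (max_epochs=200 is an arbitrary placeholder):
# Resume from the recorded checkpoint with a fresh Trainer
resume_trainer = L.Trainer(max_epochs=200, callbacks=[ModelCheckpoint(save_last=True)])
resume_trainer.fit(model, datamodule=dm, ckpt_path=last_checkpoint)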
Accessing Trainer from LightningModule
Inside a LightningModule, access the Trainer via self.trainer:
from lightning.pytorch.callbacks import ModelCheckpoint

class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # Access trainer properties
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        max_epochs = self.trainer.max_epochs
        # Access callbacks
        for callback in self.trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                print(f"Best model: {callback.best_model_path}")
        # Access logger
        self.trainer.logger.log_metrics({"custom": float(global_step)})
Trainer Attributes
| Attribute | Description |
|---|---|
| trainer.current_epoch | Current epoch (0-indexed) |
| trainer.global_step | Total optimizer steps taken |
| trainer.max_epochs | Maximum epochs configured |
| trainer.max_steps | Maximum steps configured |
| trainer.callbacks | List of callbacks |
| trainer.logger | Logger instance |
| trainer.strategy | Training strategy in use |
| trainer.estimated_stepping_batches | Estimated total optimizer steps for the full training run |
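estimated_stepping_batches is particularly useful for schedulers that need the total step count up front, such as OneCycleLR. A sketch of configure_optimizers using it (the learning rates are placeholder values):
import torch
import lightning as L

class MyModel(L.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # The Trainer knows the total number of optimizer steps for the run
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }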
Best Practices
1. Start with Fast Dev Run
Always test with fast_dev_run=True before full training:
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
2. Use Gradient Clipping
Prevent gradient explosions:
trainer = L.Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
3. Enable Mixed Precision
Speed up training on modern GPUs:
trainer = L.Trainer(precision="16-mixed") # or "bf16-mixed" for A100+
4. Save Checkpoints Properly
Always save the last checkpoint for resuming:
checkpoint_callback = ModelCheckpoint(
save_top_k=3,
save_last=True,
monitor="val_loss"
)
5. Monitor Learning Rate
Track LR changes with LearningRateMonitor:
from lightning.pytorch.callbacks import LearningRateMonitor
trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])
6. Use DataModule for Reproducibility
Encapsulate data logic in a DataModule:
# Better than passing DataLoaders directly
trainer.fit(model, datamodule=dm)
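For reference, a minimal LightningDataModule sketch; the random TensorDataset stands in for real data:
import torch
import lightning as L
from torch.utils.data import DataLoader, TensorDataset, random_split

class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Placeholder dataset; replace with real loading/splitting logic
        dataset = TensorDataset(torch.randn(1000, 16), torch.randint(0, 2, (1000,)))
        self.train_set, self.val_set = random_split(dataset, [800, 200])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)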
7. Set Deterministic for Research
Ensure reproducibility for publications:
L.seed_everything(42, workers=True)
trainer = L.Trainer(deterministic=True)