# Trainer - Comprehensive Guide

## Overview

The Trainer automates the training workflow once your PyTorch code is organized into a LightningModule. It handles loop details, device management, callbacks, gradient operations, checkpointing, and distributed training automatically.

## Core Purpose

The Trainer manages:

- Automatically enabling/disabling gradients
- Running training, validation, and test dataloaders
- Calling callbacks at appropriate times
- Placing batches on the correct devices
- Orchestrating distributed training
- Progress bars and logging
- Checkpointing and early stopping

## Main Methods

### `fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)`

Runs the full training routine, including optional validation.

**Parameters:**

- `model` - LightningModule to train
- `train_dataloaders` - Training DataLoader(s)
- `val_dataloaders` - Optional validation DataLoader(s)
- `datamodule` - Optional LightningDataModule (replaces dataloaders)
- `ckpt_path` - Optional path to a checkpoint to resume from

**Examples:**

```python
# With DataLoaders
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# With DataModule
trainer.fit(model, datamodule=dm)

# Continue training from checkpoint
trainer.fit(model, train_loader, ckpt_path="checkpoint.ckpt")
```

### `validate(model=None, dataloaders=None, datamodule=None)`

Runs the validation loop without training.

**Example:**

```python
trainer = L.Trainer()
trainer.validate(model, val_loader)
```

### `test(model=None, dataloaders=None, datamodule=None)`

Runs the test loop. By convention, run it only once, when you are ready to report final results.

**Example:**

```python
trainer = L.Trainer()
trainer.test(model, test_loader)
```

### `predict(model=None, dataloaders=None, datamodule=None)`

Runs inference on data and returns the predictions.

**Example:**

```python
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
```
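`predict` collects whatever each `predict_step` returns into a list, one entry per batch. A minimal sketch of the module side, assuming a toy classifier over flattened 28x28 inputs (the model and names are illustrative, not part of the Trainer API):

```python
import torch
import lightning as L


class LitClassifier(L.LightningModule):
    # Hypothetical model: a single linear layer over flattened 28x28 images.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def predict_step(self, batch, batch_idx):
        # Called by trainer.predict(); the batch is already on the right device.
        x, _ = batch
        logits = self.layer(x.view(x.size(0), -1))
        return logits.argmax(dim=-1)
```

With this, `trainer.predict(LitClassifier(), predict_loader)` would return a list with one tensor of predicted class indices per batch.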
## Essential Parameters

### Training Duration

#### `max_epochs` (int)

Maximum number of epochs to train. Default: None (falls back to 1000 if `max_steps` is also unset)

```python
trainer = L.Trainer(max_epochs=100)
```

#### `min_epochs` (int)

Minimum number of epochs to train. Default: None

```python
trainer = L.Trainer(min_epochs=10, max_epochs=100)
```

#### `max_steps` (int)

Maximum number of optimizer steps. Training stops as soon as either `max_steps` or `max_epochs` is reached. Default: -1 (unlimited)

```python
trainer = L.Trainer(max_steps=10000)
```

#### `max_time` (str or dict)

Maximum training time. Useful on time-limited clusters.

```python
# String format (DD:HH:MM:SS)
trainer = L.Trainer(max_time="00:12:00:00")  # 12 hours

# Dictionary format
trainer = L.Trainer(max_time={"days": 1, "hours": 6})
```

### Hardware Configuration

#### `accelerator` (str or Accelerator)

Hardware to use: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto". Default: "auto"

```python
trainer = L.Trainer(accelerator="gpu")
trainer = L.Trainer(accelerator="auto")  # Auto-detect available hardware
```

#### `devices` (int, list, or str)

Number of devices, or a list of device indices, to use.

```python
# Use 2 GPUs
trainer = L.Trainer(devices=2, accelerator="gpu")

# Use specific GPUs
trainer = L.Trainer(devices=[0, 2], accelerator="gpu")

# Use all available devices
trainer = L.Trainer(devices="auto", accelerator="gpu")

# 4 CPU processes
trainer = L.Trainer(devices=4, accelerator="cpu")
```

#### `strategy` (str or Strategy)

Distributed training strategy: "ddp", "ddp_spawn", "fsdp", "deepspeed", etc. Default: "auto"

```python
# Distributed Data Parallel
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)

# Fully Sharded Data Parallel
trainer = L.Trainer(strategy="fsdp", accelerator="gpu", devices=4)

# DeepSpeed
trainer = L.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=4)
```

#### `precision` (str or int)

Floating point precision: "32-true", "16-mixed", "bf16-mixed", "64-true", etc.

```python
# Mixed precision (FP16)
trainer = L.Trainer(precision="16-mixed")

# BFloat16 mixed precision
trainer = L.Trainer(precision="bf16-mixed")

# Full precision
trainer = L.Trainer(precision="32-true")
```

### Optimization Configuration

#### `gradient_clip_val` (float)

Gradient clipping value. Default: None

```python
# Clip gradients by norm
trainer = L.Trainer(gradient_clip_val=0.5)
```

#### `gradient_clip_algorithm` (str)

Gradient clipping algorithm: "norm" or "value". Default: "norm"

```python
trainer = L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
```

#### `accumulate_grad_batches` (int)

Accumulate gradients over N batches before each optimizer step. For a per-epoch schedule, use the `GradientAccumulationScheduler` callback.

```python
# Accumulate over 4 batches
trainer = L.Trainer(accumulate_grad_batches=4)

# Different accumulation per epoch via the callback
from lightning.pytorch.callbacks import GradientAccumulationScheduler

trainer = L.Trainer(callbacks=[GradientAccumulationScheduler(scheduling={0: 4, 5: 2, 10: 1})])
```

### Validation Configuration

#### `check_val_every_n_epoch` (int)

Run validation every N epochs. Default: 1

```python
trainer = L.Trainer(check_val_every_n_epoch=10)
```

#### `val_check_interval` (int or float)

How often to run validation within a training epoch.

```python
# Check validation every 0.25 of a training epoch
trainer = L.Trainer(val_check_interval=0.25)

# Check validation every 100 training batches
trainer = L.Trainer(val_check_interval=100)
```

#### `limit_val_batches` (int or float)

Limit the number of validation batches.

```python
# Use only 10% of validation data
trainer = L.Trainer(limit_val_batches=0.1)

# Use only 50 validation batches
trainer = L.Trainer(limit_val_batches=50)

# Disable validation
trainer = L.Trainer(limit_val_batches=0)
```

#### `num_sanity_val_steps` (int)

Number of validation batches to run before training starts. Default: 2

```python
# Skip the sanity check
trainer = L.Trainer(num_sanity_val_steps=0)

# Run 5 sanity validation steps
trainer = L.Trainer(num_sanity_val_steps=5)
```

### Logging and Progress

#### `logger` (Logger or list or bool)

Logger(s) to use for experiment tracking.

```python
from lightning.pytorch import loggers as pl_loggers

# TensorBoard logger
tb_logger = pl_loggers.TensorBoardLogger("logs/")
trainer = L.Trainer(logger=tb_logger)

# Multiple loggers
wandb_logger = pl_loggers.WandbLogger(project="my-project")
trainer = L.Trainer(logger=[tb_logger, wandb_logger])

# Disable logging
trainer = L.Trainer(logger=False)
```

#### `log_every_n_steps` (int)

How often to log within training steps. Default: 50

```python
trainer = L.Trainer(log_every_n_steps=10)
```

#### `enable_progress_bar` (bool)

Show the progress bar. Default: True

```python
trainer = L.Trainer(enable_progress_bar=False)
```

### Callbacks

#### `callbacks` (list)

List of callbacks to use during training.

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=3,
    mode="min"
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min"
)

trainer = L.Trainer(callbacks=[checkpoint_callback, early_stop_callback])
```
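Beyond the built-in callbacks, custom behavior can be added by subclassing `Callback` and overriding any of its hooks. A minimal sketch; the `"val_loss"` key is an assumption and must match a metric that the LightningModule actually logs:

```python
import lightning as L
from lightning.pytorch.callbacks import Callback


class PrintValLoss(Callback):
    # Hypothetical callback: print the logged validation loss after each validation epoch.
    def on_validation_epoch_end(self, trainer, pl_module):
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None:
            print(f"epoch {trainer.current_epoch}: val_loss={val_loss:.4f}")


trainer = L.Trainer(callbacks=[PrintValLoss()])
```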
### Checkpointing

#### `default_root_dir` (str)

Default directory for logs and checkpoints. Default: current working directory

```python
trainer = L.Trainer(default_root_dir="./experiments/")
```

#### `enable_checkpointing` (bool)

Enable automatic checkpointing. Default: True

```python
trainer = L.Trainer(enable_checkpointing=True)
```

### Debugging

#### `fast_dev_run` (bool or int)

Run a single batch (or N batches) through train/val/test for debugging.

```python
# Run 1 batch of train/val/test
trainer = L.Trainer(fast_dev_run=True)

# Run 5 batches of train/val/test
trainer = L.Trainer(fast_dev_run=5)
```

#### `limit_train_batches` (int or float)

Limit the number of training batches.

```python
# Use only 25% of training data
trainer = L.Trainer(limit_train_batches=0.25)

# Use only 100 training batches
trainer = L.Trainer(limit_train_batches=100)
```

#### `limit_test_batches` (int or float)

Limit the number of test batches.

```python
trainer = L.Trainer(limit_test_batches=0.5)
```

#### `overfit_batches` (int or float)

Overfit on a subset of data for debugging.

```python
# Overfit on 10 batches
trainer = L.Trainer(overfit_batches=10)

# Overfit on 1% of the data
trainer = L.Trainer(overfit_batches=0.01)
```

#### `detect_anomaly` (bool)

Enable PyTorch anomaly detection for debugging NaNs. Default: False

```python
trainer = L.Trainer(detect_anomaly=True)
```

### Reproducibility

#### `deterministic` (bool or str)

Control deterministic behavior. Default: False

```python
import lightning as L

# Seed everything
L.seed_everything(42, workers=True)

# Fully deterministic (may impact performance)
trainer = L.Trainer(deterministic=True)

# Warn if non-deterministic operations are detected
trainer = L.Trainer(deterministic="warn")
```

#### `benchmark` (bool)

Enable cudnn benchmarking; this can speed up training when input sizes do not vary. Default: None (uses the current cudnn setting)

```python
trainer = L.Trainer(benchmark=True)
```

### Miscellaneous

#### `enable_model_summary` (bool)

Print a model summary before training. Default: True

```python
trainer = L.Trainer(enable_model_summary=False)
```

#### `inference_mode` (bool)

Use torch.inference_mode() instead of torch.no_grad() for validation/test. Default: True

```python
trainer = L.Trainer(inference_mode=True)
```

#### `profiler` (str or Profiler)

Profile code for performance optimization. Options: "simple", "advanced", or a custom Profiler.

```python
# Simple profiler
trainer = L.Trainer(profiler="simple")

# Advanced profiler
trainer = L.Trainer(profiler="advanced")
```
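The string options cover most cases; passing a profiler object lets you control where the report goes. A sketch using the bundled `AdvancedProfiler` (the directory and filename here are illustrative):

```python
from lightning.pytorch.profilers import AdvancedProfiler

# Save the profiling report under profiling/ instead of only printing it.
profiler = AdvancedProfiler(dirpath="profiling/", filename="perf_report")
trainer = L.Trainer(profiler=profiler, max_epochs=1)
```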
## Common Configurations

### Basic Training

```python
trainer = L.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices="auto"
)
trainer.fit(model, train_loader, val_loader)
```

### Multi-GPU Training

```python
trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="16-mixed"
)
trainer.fit(model, datamodule=dm)
```

### Production Training with Checkpoints

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True
)

early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min"
)

lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    precision="16-mixed",
    callbacks=[checkpoint_callback, early_stop, lr_monitor],
    log_every_n_steps=10,
    gradient_clip_val=1.0
)
trainer.fit(model, datamodule=dm)
```

### Debug Configuration

```python
trainer = L.Trainer(
    fast_dev_run=True,  # Run 1 batch
    accelerator="cpu",
    enable_progress_bar=True,
    log_every_n_steps=1,
    detect_anomaly=True
)
trainer.fit(model, train_loader, val_loader)
```

### Research Configuration (Reproducibility)

```python
import lightning as L

L.seed_everything(42, workers=True)

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    deterministic=True,
    benchmark=False,
    precision="32-true"
)
trainer.fit(model, datamodule=dm)
```

### Time-Limited Training (Cluster)

```python
trainer = L.Trainer(
    max_time={"hours": 23, "minutes": 30},  # SLURM time limit
    max_epochs=1000,
    callbacks=[ModelCheckpoint(save_last=True)]
)
trainer.fit(model, datamodule=dm)

# Resume from checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```

### Large Model Training (FSDP)

```python
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=8,
    strategy=FSDPStrategy(
        activation_checkpointing_policy={nn.TransformerEncoderLayer},
        cpu_offload=False
    ),
    precision="bf16-mixed",
    accumulate_grad_batches=4
)
trainer.fit(model, datamodule=dm)
```

## Resuming Training

### From Checkpoint

```python
# Resume from a specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")

# Resume from the last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```

### Finding the Last Checkpoint

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_last=True)
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)

# Get the path to the last checkpoint
last_checkpoint = checkpoint_callback.last_model_path
```

## Accessing the Trainer from a LightningModule

Inside a LightningModule, access the Trainer via `self.trainer`:

```python
from lightning.pytorch.callbacks import ModelCheckpoint


class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # Access trainer properties
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        max_epochs = self.trainer.max_epochs

        # Access callbacks
        for callback in self.trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                print(f"Best model: {callback.best_model_path}")

        # Access the logger directly
        self.trainer.logger.log_metrics({"custom_metric": 1.0}, step=global_step)
```
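Continuing the example above, a common use of these properties is sizing a learning-rate schedule from inside `configure_optimizers`. A sketch assuming a OneCycleLR schedule; the optimizer choice and hyperparameters are illustrative:

```python
import torch


class MyModel(L.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # estimated_stepping_batches accounts for max_epochs, gradient accumulation, and devices.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```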
## Trainer Attributes

| Attribute | Description |
|-----------|-------------|
| `trainer.current_epoch` | Current epoch (0-indexed) |
| `trainer.global_step` | Total optimizer steps taken |
| `trainer.max_epochs` | Maximum epochs configured |
| `trainer.max_steps` | Maximum steps configured |
| `trainer.callbacks` | List of callbacks |
| `trainer.logger` | Logger instance |
| `trainer.strategy` | Training strategy |
| `trainer.estimated_stepping_batches` | Estimated total optimizer steps for the full run |

## Best Practices

### 1. Start with a Fast Dev Run

Always test with `fast_dev_run=True` before full training:

```python
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
```

### 2. Use Gradient Clipping

Prevent gradient explosions:

```python
trainer = L.Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
```

### 3. Enable Mixed Precision

Speed up training on modern GPUs:

```python
trainer = L.Trainer(precision="16-mixed")  # or "bf16-mixed" on Ampere (A100) and newer GPUs
```

### 4. Save Checkpoints Properly

Always save the last checkpoint so training can be resumed:

```python
checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    save_last=True,
    monitor="val_loss"
)
```

### 5. Monitor the Learning Rate

Track LR changes with LearningRateMonitor:

```python
from lightning.pytorch.callbacks import LearningRateMonitor

trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])
```

### 6. Use a DataModule for Reproducibility

Encapsulate data logic in a DataModule (a minimal skeleton is sketched at the end of this section):

```python
# Better than passing DataLoaders directly
trainer.fit(model, datamodule=dm)
```

### 7. Set Deterministic Mode for Research

Ensure reproducibility for publications:

```python
L.seed_everything(42, workers=True)
trainer = L.Trainer(deterministic=True)
```
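To go with best practice 6, here is a minimal `LightningDataModule` skeleton. The MNIST dataset, the `torchvision` dependency, the split sizes, and the batch size are illustrative assumptions; substitute your own data:

```python
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir="data/", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # Download once; called on a single process.
        datasets.MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        # Split into train/val; called on every process.
        full = datasets.MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
        self.train_set, self.val_set = random_split(full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)


dm = MNISTDataModule()
```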