Initial commit

Zhongwei Li
2025-11-30 08:30:10 +08:00
commit f0bd18fb4e
824 changed files with 331919 additions and 0 deletions

# Best Practices - PyTorch Lightning
## Code Organization
### 1. Separate Research from Engineering
**Good:**
```python
class MyModel(L.LightningModule):
# Research code (what the model does)
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
return loss
# Engineering code (how to train) - in Trainer
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=4,
strategy="ddp"
)
```
**Bad:**
```python
# Mixing research and engineering logic
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Don't do device management manually
loss = loss.cuda()
# Don't do optimizer steps manually (unless manual optimization)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss
```
### 2. Use LightningDataModule
**Good:**
```python
class MyDataModule(L.LightningDataModule):
def __init__(self, data_dir, batch_size):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def prepare_data(self):
# Download data once
download_data(self.data_dir)
def setup(self, stage):
# Load data per-process
self.train_dataset = MyDataset(self.data_dir, split='train')
self.val_dataset = MyDataset(self.data_dir, split='val')
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
# Reusable and shareable
dm = MyDataModule("./data", batch_size=32)
trainer.fit(model, datamodule=dm)
```
**Bad:**
```python
# Scattered data logic
train_dataset = load_data()
val_dataset = load_data()
train_loader = DataLoader(train_dataset, ...)
val_loader = DataLoader(val_dataset, ...)
trainer.fit(model, train_loader, val_loader)
```
### 3. Keep Models Modular
```python
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(...)
def forward(self, x):
return self.layers(x)
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(...)
def forward(self, x):
return self.layers(x)
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
z = self.encoder(x)
return self.decoder(z)
```
## Device Agnosticism
### 1. Never Use Explicit CUDA Calls
**Bad:**
```python
x = x.cuda()
model = model.cuda()
torch.cuda.set_device(0)
```
**Good:**
```python
# Inside LightningModule
x = x.to(self.device)
# Or let Lightning handle it automatically
def training_step(self, batch, batch_idx):
x, y = batch # Already on correct device
return loss
```
### 2. Use `self.device` Property
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# Create tensors on correct device
noise = torch.randn(batch.size(0), 100).to(self.device)
# Or use type_as
noise = torch.randn(batch.size(0), 100).type_as(batch)
```
### 3. Register Buffers for Non-Parameters
```python
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
# Register buffers (automatically moved to correct device)
self.register_buffer("running_mean", torch.zeros(100))
def forward(self, x):
# self.running_mean is automatically on correct device
return x - self.running_mean
```
## Hyperparameter Management
### 1. Always Use `save_hyperparameters()`
**Good:**
```python
class MyModel(L.LightningModule):
def __init__(self, learning_rate, hidden_dim, dropout):
super().__init__()
self.save_hyperparameters() # Saves all arguments
# Access via self.hparams
self.model = nn.Linear(self.hparams.hidden_dim, 10)
# Load from checkpoint with saved hparams
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
print(model.hparams.learning_rate) # Original value preserved
```
**Bad:**
```python
class MyModel(L.LightningModule):
def __init__(self, learning_rate, hidden_dim, dropout):
super().__init__()
self.learning_rate = learning_rate # Manual tracking
self.hidden_dim = hidden_dim
```
### 2. Ignore Specific Arguments
```python
class MyModel(L.LightningModule):
def __init__(self, lr, model, dataset):
super().__init__()
# Don't save 'model' and 'dataset' (not serializable)
self.save_hyperparameters(ignore=['model', 'dataset'])
self.model = model
self.dataset = dataset
```
### 3. Use Hyperparameters in `configure_optimizers()`
```python
def configure_optimizers(self):
# Use saved hyperparameters
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay
)
return optimizer
```
## Logging Best Practices
### 1. Log Both Step and Epoch Metrics
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Log per-step for detailed monitoring
# Log per-epoch for aggregated view
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
```
### 2. Use Structured Logging
```python
def training_step(self, batch, batch_idx):
# Organize with prefixes
self.log("train/loss", loss)
self.log("train/acc", acc)
self.log("train/f1", f1)
def validation_step(self, batch, batch_idx):
self.log("val/loss", loss)
self.log("val/acc", acc)
self.log("val/f1", f1)
```
### 3. Sync Metrics in Distributed Training
```python
def validation_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# IMPORTANT: sync_dist=True for proper aggregation across GPUs
self.log("val_loss", loss, sync_dist=True)
```
### 4. Monitor Learning Rate
```python
from lightning.pytorch.callbacks import LearningRateMonitor
trainer = L.Trainer(
callbacks=[LearningRateMonitor(logging_interval="step")]
)
```
## Reproducibility
### 1. Seed Everything
```python
import lightning as L
# Set seed for reproducibility
L.seed_everything(42, workers=True)
trainer = L.Trainer(
deterministic=True, # Use deterministic algorithms
benchmark=False # Disable cudnn benchmarking
)
```
### 2. Avoid Non-Deterministic Operations
```python
# Bad: Non-deterministic
torch.use_deterministic_algorithms(False)
# Good: Deterministic
torch.use_deterministic_algorithms(True)
```
### 3. Log Random State
```python
import random
import numpy as np
import torch

def on_save_checkpoint(self, checkpoint):
# Save random states
checkpoint['rng_state'] = {
'torch': torch.get_rng_state(),
'numpy': np.random.get_state(),
'python': random.getstate()
}
def on_load_checkpoint(self, checkpoint):
# Restore random states
if 'rng_state' in checkpoint:
torch.set_rng_state(checkpoint['rng_state']['torch'])
np.random.set_state(checkpoint['rng_state']['numpy'])
random.setstate(checkpoint['rng_state']['python'])
```
## Debugging
### 1. Use `fast_dev_run`
```python
# Test with 1 batch before full training
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
```
### 2. Limit Training Data
```python
# Use only 10% of data for quick iteration
trainer = L.Trainer(
limit_train_batches=0.1,
limit_val_batches=0.1
)
```
### 3. Enable Anomaly Detection
```python
# Detect NaN/Inf in gradients
trainer = L.Trainer(detect_anomaly=True)
```
### 4. Overfit on Small Batch
```python
# Overfit on 10 batches to verify model capacity
trainer = L.Trainer(overfit_batches=10)
```
### 5. Profile Code
```python
# Find performance bottlenecks
trainer = L.Trainer(profiler="simple") # or "advanced"
```
## Memory Optimization
### 1. Use Mixed Precision
```python
# FP16/BF16 mixed precision for memory savings and speed
trainer = L.Trainer(
    precision="16-mixed"        # V100, T4
    # precision="bf16-mixed"    # A100, H100
)
```
### 2. Gradient Accumulation
```python
# Simulate larger batch size without memory increase
trainer = L.Trainer(
accumulate_grad_batches=4 # Accumulate over 4 batches
)
```
### 3. Gradient Checkpointing
```python
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
        self.model = transformers.AutoModel.from_pretrained("bert-base-uncased")
# Enable gradient checkpointing
self.model.gradient_checkpointing_enable()
```
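For models that do not expose a `gradient_checkpointing_enable()` helper, here is a minimal sketch using `torch.utils.checkpoint`; the block layout and sizes are illustrative assumptions:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    def __init__(self, dim=1024, num_blocks=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(num_blocks)
        )

    def forward(self, x):
        for block in self.blocks:
            # Recompute this block's activations during backward instead of storing them
            x = checkpoint(block, x, use_reentrant=False)
        return x
```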
### 4. Clear Cache
```python
def on_train_epoch_end(self):
# Clear collected outputs to free memory
self.training_step_outputs.clear()
# Clear CUDA cache if needed
if torch.cuda.is_available():
torch.cuda.empty_cache()
```
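The `training_step_outputs` list above is state you maintain yourself (recent Lightning versions no longer collect step outputs automatically). A minimal sketch of that pattern, with `compute_loss` as an assumed helper:
```python
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []  # manually collected per-step losses

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # assumed helper
        self.training_step_outputs.append(loss.detach())
        return loss

    def on_train_epoch_end(self):
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("train_loss_epoch", epoch_mean)
        self.training_step_outputs.clear()  # free the stored tensors
```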
### 5. Use Efficient Data Types
```python
# Use appropriate precision
# FP32 for stability, FP16/BF16 for speed/memory
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
# Use bfloat16 for better numerical stability than fp16
self.model = MyTransformer().to(torch.bfloat16)
```
## Training Stability
### 1. Gradient Clipping
```python
# Prevent gradient explosion
trainer = L.Trainer(
gradient_clip_val=1.0,
gradient_clip_algorithm="norm" # or "value"
)
```
### 2. Learning Rate Warmup
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=1e-2,
total_steps=self.trainer.estimated_stepping_batches,
pct_start=0.1 # 10% warmup
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step"
}
}
```
### 3. Monitor Gradients
```python
class MyModel(L.LightningModule):
def on_after_backward(self):
# Log gradient norms
for name, param in self.named_parameters():
if param.grad is not None:
self.log(f"grad_norm/{name}", param.grad.norm())
```
### 4. Use EarlyStopping
```python
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor="val_loss",
patience=10,
mode="min",
verbose=True
)
trainer = L.Trainer(callbacks=[early_stop])
```
## Checkpointing
### 1. Save Top-K and Last
```python
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="{epoch}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_top_k=3, # Keep best 3
save_last=True # Always save last for resuming
)
trainer = L.Trainer(callbacks=[checkpoint_callback])
```
### 2. Resume Training
```python
# Resume from last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
# Resume from specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")
```
### 3. Custom Checkpoint State
```python
def on_save_checkpoint(self, checkpoint):
# Add custom state
checkpoint['custom_data'] = self.custom_data
checkpoint['epoch_metrics'] = self.metrics
def on_load_checkpoint(self, checkpoint):
# Restore custom state
self.custom_data = checkpoint.get('custom_data', {})
self.metrics = checkpoint.get('epoch_metrics', [])
```
## Testing
### 1. Separate Train and Test
```python
# Train
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, datamodule=dm)
# Test ONLY ONCE before publishing
trainer.test(model, datamodule=dm)
```
### 2. Use Validation for Model Selection
```python
# Use validation for hyperparameter tuning
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)
# Load best model
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)
# Test only once with best model
trainer.test(best_model, datamodule=dm)
```
## Code Quality
### 1. Type Hints
```python
from typing import Any, Dict, Tuple
import torch
from torch import Tensor
class MyModel(L.LightningModule):
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
x, y = batch
loss = self.compute_loss(x, y)
return loss
def configure_optimizers(self) -> Dict[str, Any]:
optimizer = torch.optim.Adam(self.parameters())
return {"optimizer": optimizer}
```
### 2. Docstrings
```python
class MyModel(L.LightningModule):
"""
My awesome model for image classification.
Args:
num_classes: Number of output classes
learning_rate: Learning rate for optimizer
hidden_dim: Hidden dimension size
"""
def __init__(self, num_classes: int, learning_rate: float, hidden_dim: int):
super().__init__()
self.save_hyperparameters()
```
### 3. Property Methods
```python
class MyModel(L.LightningModule):
@property
def learning_rate(self) -> float:
"""Current learning rate."""
return self.hparams.learning_rate
@property
def num_parameters(self) -> int:
"""Total number of parameters."""
return sum(p.numel() for p in self.parameters())
```
## Common Pitfalls
### 1. Forgetting to Return Loss
**Bad:**
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
self.log("train_loss", loss)
# FORGOT TO RETURN LOSS!
```
**Good:**
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
self.log("train_loss", loss)
return loss # MUST return loss
```
### 2. Not Syncing Metrics in DDP
**Bad:**
```python
def validation_step(self, batch, batch_idx):
self.log("val_acc", acc) # Wrong value with multi-GPU!
```
**Good:**
```python
def validation_step(self, batch, batch_idx):
self.log("val_acc", acc, sync_dist=True) # Correct aggregation
```
### 3. Manual Device Management
**Bad:**
```python
def training_step(self, batch, batch_idx):
x = x.cuda() # Don't do this
y = y.cuda()
```
**Good:**
```python
def training_step(self, batch, batch_idx):
# Lightning handles device placement
x, y = batch # Already on correct device
```
### 4. Not Using `self.log()`
**Bad:**
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
self.training_losses.append(loss) # Manual tracking
return loss
```
**Good:**
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
self.log("train_loss", loss) # Automatic logging
return loss
```
### 5. Modifying Batch In-Place
**Bad:**
```python
def training_step(self, batch, batch_idx):
x, y = batch
x[:] = self.augment(x) # In-place modification can cause issues
```
**Good:**
```python
def training_step(self, batch, batch_idx):
x, y = batch
x = self.augment(x) # Create new tensor
```
## Performance Tips
### 1. Use DataLoader Workers
```python
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=32,
num_workers=4, # Use multiple workers
pin_memory=True, # Faster GPU transfer
persistent_workers=True # Keep workers alive
)
```
### 2. Enable Benchmark Mode (if fixed input size)
```python
trainer = L.Trainer(benchmark=True)
```
### 3. Use Automatic Batch Size Finding
```python
from lightning.pytorch.tuner import Tuner
trainer = L.Trainer()
tuner = Tuner(trainer)
# Find optimal batch size
tuner.scale_batch_size(model, datamodule=dm, mode="power")
# Then train
trainer.fit(model, datamodule=dm)
```
### 4. Optimize Data Loading
```python
# Use faster image decoding
import torch
import torchvision.transforms as T
transforms = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Use PIL-SIMD for faster image loading
# pip install pillow-simd
```

# Callbacks - Comprehensive Guide
## Overview
Callbacks enable adding arbitrary self-contained programs to training without cluttering your LightningModule research code. They execute custom logic at specific hooks during the training lifecycle.
## Architecture
Lightning organizes training logic across three components:
- **Trainer** - Engineering infrastructure
- **LightningModule** - Research code
- **Callbacks** - Non-essential functionality (monitoring, checkpointing, custom behaviors)
## Creating Custom Callbacks
Basic structure:
```python
from lightning.pytorch.callbacks import Callback
class MyCustomCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training is starting!")
def on_train_end(self, trainer, pl_module):
print("Training is done!")
# Use with Trainer
trainer = L.Trainer(callbacks=[MyCustomCallback()])
```
## Built-in Callbacks
### ModelCheckpoint
Save models based on monitored metrics.
**Key Parameters:**
- `dirpath` - Directory to save checkpoints
- `filename` - Checkpoint filename pattern
- `monitor` - Metric to monitor
- `mode` - "min" or "max" for monitored metric
- `save_top_k` - Number of best models to keep
- `save_last` - Save last epoch checkpoint
- `every_n_epochs` - Save every N epochs
- `save_on_train_epoch_end` - Save at train epoch end vs validation end
**Examples:**
```python
from lightning.pytorch.callbacks import ModelCheckpoint
# Save top 3 models based on validation loss
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="model-{epoch:02d}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_top_k=3,
save_last=True
)
# Save every 10 epochs
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="model-{epoch:02d}",
every_n_epochs=10,
save_top_k=-1 # Save all
)
# Save best model based on accuracy
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="best-model",
monitor="val_acc",
mode="max",
save_top_k=1
)
trainer = L.Trainer(callbacks=[checkpoint_callback])
```
**Accessing Saved Checkpoints:**
```python
# Get best model path
best_model_path = checkpoint_callback.best_model_path
# Get last checkpoint path
last_checkpoint = checkpoint_callback.last_model_path
# Get all checkpoint paths
all_checkpoints = checkpoint_callback.best_k_models
```
### EarlyStopping
Stop training when a monitored metric stops improving.
**Key Parameters:**
- `monitor` - Metric to monitor
- `patience` - Number of epochs with no improvement after which training stops
- `mode` - "min" or "max" for monitored metric
- `min_delta` - Minimum change to qualify as an improvement
- `verbose` - Print messages
- `strict` - Crash if monitored metric not found
**Examples:**
```python
from lightning.pytorch.callbacks import EarlyStopping
# Stop when validation loss stops improving
early_stop = EarlyStopping(
monitor="val_loss",
patience=10,
mode="min",
verbose=True
)
# Stop when accuracy plateaus
early_stop = EarlyStopping(
monitor="val_acc",
patience=5,
mode="max",
min_delta=0.001 # Must improve by at least 0.001
)
trainer = L.Trainer(callbacks=[early_stop])
```
### LearningRateMonitor
Track learning rate changes from schedulers.
**Key Parameters:**
- `logging_interval` - When to log: "step" or "epoch"
- `log_momentum` - Also log momentum values
**Example:**
```python
from lightning.pytorch.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = L.Trainer(callbacks=[lr_monitor])
# Logs learning rate automatically as "lr-{optimizer_name}"
```
### DeviceStatsMonitor
Log device performance metrics (GPU/CPU/TPU).
**Key Parameters:**
- `cpu_stats` - Log CPU stats
**Example:**
```python
from lightning.pytorch.callbacks import DeviceStatsMonitor
device_stats = DeviceStatsMonitor(cpu_stats=True)
trainer = L.Trainer(callbacks=[device_stats])
# Logs: gpu_utilization, gpu_memory_usage, etc.
```
### ModelSummary / RichModelSummary
Display model architecture and parameter count.
**Example:**
```python
from lightning.pytorch.callbacks import ModelSummary, RichModelSummary
# Basic summary
summary = ModelSummary(max_depth=2)
# Rich formatted summary (prettier)
rich_summary = RichModelSummary(max_depth=3)
trainer = L.Trainer(callbacks=[rich_summary])
```
### Timer
Track and limit training duration.
**Key Parameters:**
- `duration` - Maximum training time (timedelta or dict)
- `interval` - Check interval: "step" or "epoch"
**Example:**
```python
from lightning.pytorch.callbacks import Timer
from datetime import timedelta
# Limit training to 1 hour
timer = Timer(duration=timedelta(hours=1))
# Or using dict
timer = Timer(duration={"hours": 23, "minutes": 30})
trainer = L.Trainer(callbacks=[timer])
```
### BatchSizeFinder
Automatically find the optimal batch size.
**Example:**
```python
from lightning.pytorch.callbacks import BatchSizeFinder
batch_finder = BatchSizeFinder(mode="power", steps_per_trial=3)
trainer = L.Trainer(callbacks=[batch_finder])
trainer.fit(model, datamodule=dm)
# Optimal batch size is set automatically
```
### GradientAccumulationScheduler
Schedule gradient accumulation steps dynamically.
**Example:**
```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler
# Accumulate 4 batches for first 5 epochs, then 2 batches
accumulator = GradientAccumulationScheduler(scheduling={0: 4, 5: 2})
trainer = L.Trainer(callbacks=[accumulator])
```
### StochasticWeightAveraging (SWA)
Apply stochastic weight averaging for better generalization.
**Example:**
```python
from lightning.pytorch.callbacks import StochasticWeightAveraging
swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8)
trainer = L.Trainer(callbacks=[swa])
```
## Custom Callback Examples
### Simple Logging Callback
```python
class MetricsLogger(Callback):
def __init__(self):
self.metrics = []
def on_validation_end(self, trainer, pl_module):
# Access logged metrics
metrics = trainer.callback_metrics
self.metrics.append(dict(metrics))
print(f"Validation metrics: {metrics}")
```
### Gradient Monitoring Callback
```python
class GradientMonitor(Callback):
def on_after_backward(self, trainer, pl_module):
# Log gradient norms
for name, param in pl_module.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
pl_module.log(f"grad_norm/{name}", grad_norm)
```
### Custom Checkpointing Callback
```python
class CustomCheckpoint(Callback):
def __init__(self, save_dir):
self.save_dir = save_dir
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
if epoch % 5 == 0: # Save every 5 epochs
filepath = f"{self.save_dir}/custom-{epoch}.ckpt"
trainer.save_checkpoint(filepath)
print(f"Saved checkpoint: {filepath}")
```
### Model Freezing Callback
```python
class FreezeUnfreeze(Callback):
def __init__(self, freeze_until_epoch=10):
self.freeze_until_epoch = freeze_until_epoch
def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
if epoch < self.freeze_until_epoch:
# Freeze backbone
for param in pl_module.backbone.parameters():
param.requires_grad = False
else:
# Unfreeze backbone
for param in pl_module.backbone.parameters():
param.requires_grad = True
```
### Learning Rate Finder Callback
```python
class LRFinder(Callback):
def __init__(self, min_lr=1e-5, max_lr=1e-1, num_steps=100):
self.min_lr = min_lr
self.max_lr = max_lr
self.num_steps = num_steps
self.lrs = []
self.losses = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx >= self.num_steps:
trainer.should_stop = True
return
# Exponential LR schedule
lr = self.min_lr * (self.max_lr / self.min_lr) ** (batch_idx / self.num_steps)
optimizer = trainer.optimizers[0]
for param_group in optimizer.param_groups:
param_group['lr'] = lr
self.lrs.append(lr)
self.losses.append(outputs['loss'].item())
def on_train_end(self, trainer, pl_module):
# Plot LR vs Loss
import matplotlib.pyplot as plt
plt.plot(self.lrs, self.losses)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Loss')
plt.savefig('lr_finder.png')
```
### Prediction Saver Callback
```python
class PredictionSaver(Callback):
def __init__(self, save_path):
self.save_path = save_path
self.predictions = []
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.predictions.append(outputs)
def on_predict_end(self, trainer, pl_module):
# Save all predictions
torch.save(self.predictions, self.save_path)
print(f"Predictions saved to {self.save_path}")
```
## Available Hooks
### Setup and Teardown
- `setup(trainer, pl_module, stage)` - Called at beginning of fit/test/predict
- `teardown(trainer, pl_module, stage)` - Called at end of fit/test/predict
### Training Lifecycle
- `on_fit_start(trainer, pl_module)` - Called at start of fit
- `on_fit_end(trainer, pl_module)` - Called at end of fit
- `on_train_start(trainer, pl_module)` - Called at start of training
- `on_train_end(trainer, pl_module)` - Called at end of training
### Epoch Boundaries
- `on_train_epoch_start(trainer, pl_module)` - Called at start of training epoch
- `on_train_epoch_end(trainer, pl_module)` - Called at end of training epoch
- `on_validation_epoch_start(trainer, pl_module)` - Called at start of validation
- `on_validation_epoch_end(trainer, pl_module)` - Called at end of validation
- `on_test_epoch_start(trainer, pl_module)` - Called at start of test
- `on_test_epoch_end(trainer, pl_module)` - Called at end of test
### Batch Boundaries
- `on_train_batch_start(trainer, pl_module, batch, batch_idx)` - Before training batch
- `on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)` - After training batch
- `on_validation_batch_start(trainer, pl_module, batch, batch_idx)` - Before validation batch
- `on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx)` - After validation batch
### Gradient Events
- `on_before_backward(trainer, pl_module, loss)` - Before loss.backward()
- `on_after_backward(trainer, pl_module)` - After loss.backward()
- `on_before_optimizer_step(trainer, pl_module, optimizer)` - Before optimizer.step()
### Checkpoint Events
- `on_save_checkpoint(trainer, pl_module, checkpoint)` - When saving checkpoint
- `on_load_checkpoint(trainer, pl_module, checkpoint)` - When loading checkpoint
### Exception Handling
- `on_exception(trainer, pl_module, exception)` - When an exception occurs (see the example below)
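A minimal sketch combining lifecycle and exception hooks: a callback that records the run duration and saves a rescue checkpoint if training crashes (the filename is illustrative):
```python
import time
from lightning.pytorch.callbacks import Callback

class CrashGuard(Callback):
    def on_fit_start(self, trainer, pl_module):
        self.start_time = time.time()

    def on_exception(self, trainer, pl_module, exception):
        elapsed = time.time() - self.start_time
        print(f"Training failed after {elapsed:.0f}s: {exception!r}")
        # Persist a checkpoint so the run can be resumed with ckpt_path=...
        trainer.save_checkpoint("crash_rescue.ckpt")
```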
## State Management
For callbacks requiring persistence across checkpoints:
```python
class StatefulCallback(Callback):
def __init__(self):
self.counter = 0
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.counter += 1
def state_dict(self):
return {"counter": self.counter}
def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]
@property
def state_key(self):
# Unique identifier for this callback
return "my_stateful_callback"
```
## Best Practices
### 1. Keep Callbacks Isolated
Each callback should be self-contained and independent:
```python
# Good: Self-contained
class MyCallback(Callback):
def __init__(self):
self.data = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.data.append(outputs['loss'].item())
# Bad: Depends on external state
global_data = []
class BadCallback(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
global_data.append(outputs['loss'].item()) # External dependency
```
### 2. Avoid Inter-Callback Dependencies
Callbacks should not depend on other callbacks:
```python
# Bad: Callback B depends on Callback A
class CallbackA(Callback):
def __init__(self):
self.value = 0
class CallbackB(Callback):
def __init__(self, callback_a):
self.callback_a = callback_a # Tight coupling
# Good: Independent callbacks
class CallbackA(Callback):
def __init__(self):
self.value = 0
class CallbackB(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Access trainer state instead
value = trainer.callback_metrics.get('metric')
```
### 3. Never Manually Invoke Callback Methods
Let Lightning call callbacks automatically:
```python
# Bad: Manual invocation
callback = MyCallback()
callback.on_train_start(trainer, model) # Don't do this
# Good: Let Trainer handle it
trainer = L.Trainer(callbacks=[MyCallback()])
```
### 4. Design for Any Execution Order
Callbacks may execute in any order, so don't rely on specific ordering:
```python
# Good: Order-independent
class GoodCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Use trainer state, not other callbacks
metrics = trainer.callback_metrics
self.log_metrics(metrics)
```
### 5. Use Callbacks for Non-Essential Logic
Keep core research code in LightningModule, use callbacks for auxiliary functionality:
```python
# Good separation
class MyModel(L.LightningModule):
# Core research logic here
def training_step(self, batch, batch_idx):
return loss
# Non-essential monitoring in callback
class MonitorCallback(Callback):
def on_validation_end(self, trainer, pl_module):
# Monitoring logic
pass
```
## Common Patterns
### Combining Multiple Callbacks
```python
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
DeviceStatsMonitor
)
callbacks = [
ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3),
EarlyStopping(monitor="val_loss", patience=10, mode="min"),
LearningRateMonitor(logging_interval="step"),
DeviceStatsMonitor()
]
trainer = L.Trainer(callbacks=callbacks)
```
### Conditional Callback Activation
```python
class ConditionalCallback(Callback):
def __init__(self, activate_after_epoch=10):
self.activate_after_epoch = activate_after_epoch
def on_train_epoch_end(self, trainer, pl_module):
if trainer.current_epoch >= self.activate_after_epoch:
# Only active after specified epoch
self.do_something(trainer, pl_module)
```
### Multi-Stage Training Callback
```python
class MultiStageTraining(Callback):
def __init__(self, stage_epochs=[10, 20, 30]):
self.stage_epochs = stage_epochs
self.current_stage = 0
def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
if epoch in self.stage_epochs:
self.current_stage += 1
print(f"Entering stage {self.current_stage}")
# Adjust learning rate for new stage
for optimizer in trainer.optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
```

# LightningDataModule - Comprehensive Guide
## Overview
A LightningDataModule is a reusable, shareable class that encapsulates all data processing steps in PyTorch Lightning. It solves the problem of scattered data preparation logic by standardizing how datasets are managed and shared across projects.
## Core Problem It Solves
In traditional PyTorch workflows, data handling is fragmented across multiple files, making it difficult to answer questions like:
- "What splits did you use?"
- "What transforms were applied?"
- "How was the data prepared?"
DataModules centralize this information for reproducibility and reusability.
## Five Processing Steps
A DataModule organizes data handling into five phases (the skeleton after this list shows where each phase typically lives):
1. **Download/tokenize/process** - Initial data acquisition
2. **Clean and save** - Persist processed data to disk
3. **Load into Dataset** - Create PyTorch Dataset objects
4. **Apply transforms** - Data augmentation, normalization, etc.
5. **Wrap in DataLoader** - Configure batching and loading
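A minimal skeleton mapping the five phases onto the DataModule methods; `MyDataset`, the paths, and the transform are placeholders, not part of the Lightning API:
```python
import lightning as L
from torch.utils.data import DataLoader, random_split

class SkeletonDataModule(L.LightningDataModule):
    def prepare_data(self):
        # Steps 1-2: download/tokenize, clean, and save to disk (runs once, no state)
        ...

    def setup(self, stage):
        # Steps 3-4: load into Dataset objects and attach transforms (runs on every process)
        full = MyDataset("data/processed/", transform=my_transform)  # placeholders
        self.train_set, self.val_set = random_split(full, [0.8, 0.2])

    def train_dataloader(self):
        # Step 5: wrap in a DataLoader that handles batching and shuffling
        return DataLoader(self.train_set, batch_size=32, shuffle=True)
```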
## Main Methods
### `prepare_data()`
Downloads and processes data. Runs on a single process only (by default, once on local rank 0 of each node), so it is never executed in parallel across workers.
**Use for:**
- Downloading datasets
- Tokenizing text
- Saving processed data to disk
**Important:** Do not set state here (e.g., self.x = y). State is not transferred to other processes.
**Example:**
```python
def prepare_data(self):
# Download data (runs once)
download_dataset("http://example.com/data.zip", "data/")
# Tokenize and save (runs once)
tokenize_and_save("data/raw/", "data/processed/")
```
### `setup(stage)`
Creates datasets and applies transforms. Runs on every process in distributed training.
**Parameters:**
- `stage` - 'fit', 'validate', 'test', or 'predict'
**Use for:**
- Creating train/val/test splits
- Building Dataset objects
- Applying transforms
- Setting state (self.train_dataset = ...)
**Example:**
```python
def setup(self, stage):
if stage == 'fit':
full_dataset = MyDataset("data/processed/")
self.train_dataset, self.val_dataset = random_split(
full_dataset, [0.8, 0.2]
)
if stage == 'test':
self.test_dataset = MyDataset("data/processed/test/")
if stage == 'predict':
self.predict_dataset = MyDataset("data/processed/predict/")
```
### `train_dataloader()`
Returns the training DataLoader.
**Example:**
```python
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True
)
```
### `val_dataloader()`
Returns the validation DataLoader(s).
**Example:**
```python
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True
)
```
### `test_dataloader()`
Returns the test DataLoader(s).
**Example:**
```python
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers
)
```
### `predict_dataloader()`
Returns the prediction DataLoader(s).
**Example:**
```python
def predict_dataloader(self):
return DataLoader(
self.predict_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers
)
```
## Complete Example
```python
import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split
import torch
class MyDataset(Dataset):
def __init__(self, data_path, transform=None):
self.data_path = data_path
self.transform = transform
self.data = self._load_data()
def _load_data(self):
# Load your data here
return torch.randn(1000, 3, 224, 224)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
if self.transform:
sample = self.transform(sample)
return sample
class MyDataModule(L.LightningDataModule):
def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
# Transforms
self.train_transform = self._get_train_transforms()
self.test_transform = self._get_test_transforms()
def _get_train_transforms(self):
# Define training transforms
return lambda x: x # Placeholder
def _get_test_transforms(self):
# Define test/val transforms
return lambda x: x # Placeholder
def prepare_data(self):
# Download data (runs once on single process)
# download_data(self.data_dir)
pass
def setup(self, stage=None):
# Create datasets (runs on every process)
if stage == 'fit' or stage is None:
full_dataset = MyDataset(
self.data_dir,
transform=self.train_transform
)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
self.train_dataset, self.val_dataset = random_split(
full_dataset, [train_size, val_size]
)
if stage == 'test' or stage is None:
self.test_dataset = MyDataset(
self.data_dir,
transform=self.test_transform
)
if stage == 'predict':
self.predict_dataset = MyDataset(
self.data_dir,
transform=self.test_transform
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True if self.num_workers > 0 else False
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True if self.num_workers > 0 else False
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers
)
def predict_dataloader(self):
return DataLoader(
self.predict_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers
)
```
## Usage
```python
# Create DataModule
dm = MyDataModule(data_dir="./data", batch_size=64, num_workers=8)
# Use with Trainer
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)
# Test
trainer.test(model, datamodule=dm)
# Predict
predictions = trainer.predict(model, datamodule=dm)
# Or use standalone in PyTorch
dm.prepare_data()
dm.setup(stage='fit')
train_loader = dm.train_dataloader()
for batch in train_loader:
# Your training code
pass
```
## Additional Hooks
### `transfer_batch_to_device(batch, device, dataloader_idx)`
Custom logic for moving batches to devices.
**Example:**
```python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
# Custom transfer logic
if isinstance(batch, dict):
return {k: v.to(device) for k, v in batch.items()}
return super().transfer_batch_to_device(batch, device, dataloader_idx)
```
### `on_before_batch_transfer(batch, dataloader_idx)`
Augment or modify batch before transferring to device (runs on CPU).
**Example:**
```python
def on_before_batch_transfer(self, batch, dataloader_idx):
# Apply CPU-based augmentations
batch['image'] = apply_augmentation(batch['image'])
return batch
```
### `on_after_batch_transfer(batch, dataloader_idx)`
Augment or modify batch after transferring to device (runs on GPU).
**Example:**
```python
def on_after_batch_transfer(self, batch, dataloader_idx):
# Apply GPU-based augmentations
batch['image'] = gpu_augmentation(batch['image'])
return batch
```
### `state_dict()` / `load_state_dict(state_dict)`
Save and restore DataModule state for checkpointing.
**Example:**
```python
def state_dict(self):
return {"current_fold": self.current_fold}
def load_state_dict(self, state_dict):
self.current_fold = state_dict["current_fold"]
```
### `teardown(stage)`
Cleanup operations after training/testing/prediction.
**Example:**
```python
def teardown(self, stage):
# Clean up resources
if stage == 'fit':
self.train_dataset = None
self.val_dataset = None
```
## Advanced Patterns
### Multiple Validation/Test DataLoaders
Return a list or dictionary of DataLoaders:
```python
def val_dataloader(self):
return [
DataLoader(self.val_dataset_1, batch_size=32),
DataLoader(self.val_dataset_2, batch_size=32)
]
# Or with names (for logging)
def val_dataloader(self):
return {
"val_easy": DataLoader(self.val_easy, batch_size=32),
"val_hard": DataLoader(self.val_hard, batch_size=32)
}
# In LightningModule
def validation_step(self, batch, batch_idx, dataloader_idx=0):
if dataloader_idx == 0:
# Handle val_dataset_1
pass
else:
# Handle val_dataset_2
pass
```
### Cross-Validation
```python
class CrossValidationDataModule(L.LightningDataModule):
def __init__(self, data_dir, batch_size, num_folds=5):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_folds = num_folds
self.current_fold = 0
def setup(self, stage=None):
full_dataset = MyDataset(self.data_dir)
fold_size = len(full_dataset) // self.num_folds
# Create fold indices
indices = list(range(len(full_dataset)))
val_start = self.current_fold * fold_size
val_end = val_start + fold_size
val_indices = indices[val_start:val_end]
train_indices = indices[:val_start] + indices[val_end:]
self.train_dataset = Subset(full_dataset, train_indices)
self.val_dataset = Subset(full_dataset, val_indices)
def set_fold(self, fold):
self.current_fold = fold
def state_dict(self):
return {"current_fold": self.current_fold}
def load_state_dict(self, state_dict):
self.current_fold = state_dict["current_fold"]
# Usage
dm = CrossValidationDataModule("./data", batch_size=32, num_folds=5)
for fold in range(5):
dm.set_fold(fold)
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)
```
### Hyperparameter Saving
```python
class MyDataModule(L.LightningDataModule):
def __init__(self, data_dir, batch_size=32, num_workers=4):
super().__init__()
# Save hyperparameters
self.save_hyperparameters()
def setup(self, stage=None):
# Access via self.hparams
print(f"Batch size: {self.hparams.batch_size}")
```
## Best Practices
### 1. Separate prepare_data and setup
- `prepare_data()` - Downloads/processes (single process, no state)
- `setup()` - Creates datasets (every process, set state)
### 2. Use stage Parameter
Check the stage in `setup()` to avoid unnecessary work:
```python
def setup(self, stage):
if stage == 'fit':
# Only load train/val data when fitting
self.train_dataset = ...
self.val_dataset = ...
elif stage == 'test':
# Only load test data when testing
self.test_dataset = ...
```
### 3. Pin Memory for GPU Training
Enable `pin_memory=True` in DataLoaders for faster GPU transfer:
```python
def train_dataloader(self):
return DataLoader(..., pin_memory=True)
```
### 4. Use Persistent Workers
Prevent worker restarts between epochs:
```python
def train_dataloader(self):
return DataLoader(
...,
num_workers=4,
persistent_workers=True
)
```
### 5. Avoid Shuffle in Validation/Test
Never shuffle validation or test data:
```python
def val_dataloader(self):
return DataLoader(..., shuffle=False) # Never True
```
### 6. Make DataModules Reusable
Accept configuration parameters in `__init__`:
```python
class MyDataModule(L.LightningDataModule):
def __init__(self, data_dir, batch_size, num_workers, augment=True):
super().__init__()
self.save_hyperparameters()
```
### 7. Document Data Structure
Add docstrings explaining data format and expectations:
```python
class MyDataModule(L.LightningDataModule):
"""
DataModule for XYZ dataset.
Data format: (image, label) tuples
- image: torch.Tensor of shape (C, H, W)
- label: int in range [0, num_classes)
Args:
data_dir: Path to data directory
batch_size: Batch size for dataloaders
num_workers: Number of data loading workers
"""
```
## Common Pitfalls
### 1. Setting State in prepare_data
**Wrong:**
```python
def prepare_data(self):
self.dataset = load_data() # State not transferred to other processes!
```
**Correct:**
```python
def prepare_data(self):
download_data() # Only download, no state
def setup(self, stage):
self.dataset = load_data() # Set state here
```
### 2. Not Using stage Parameter
**Inefficient:**
```python
def setup(self, stage):
self.train_dataset = load_train()
self.val_dataset = load_val()
self.test_dataset = load_test() # Loads even when just fitting
```
**Efficient:**
```python
def setup(self, stage):
if stage == 'fit':
self.train_dataset = load_train()
self.val_dataset = load_val()
elif stage == 'test':
self.test_dataset = load_test()
```
### 3. Forgetting to Return DataLoaders
**Wrong:**
```python
def train_dataloader(self):
DataLoader(self.train_dataset, ...) # Forgot return!
```
**Correct:**
```python
def train_dataloader(self):
return DataLoader(self.train_dataset, ...)
```
## Integration with Trainer
```python
# Initialize DataModule
dm = MyDataModule(data_dir="./data", batch_size=64)
# All data loading is handled by DataModule
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)
# DataModule handles validation too
trainer.validate(model, datamodule=dm)
# And testing
trainer.test(model, datamodule=dm)
# And prediction
predictions = trainer.predict(model, datamodule=dm)
```

# Distributed Training - Comprehensive Guide
## Overview
PyTorch Lightning provides several strategies for training large models efficiently across multiple GPUs, nodes, and machines. Choose the right strategy based on model size and hardware configuration; the sketch after the selection guide below shows how each choice maps to Trainer arguments.
## Strategy Selection Guide
### When to Use Each Strategy
**Regular Training (Single Device)**
- Model size: Any size that fits in single GPU memory
- Use case: Prototyping, small models, debugging
**DDP (Distributed Data Parallel)**
- Model size: <500M parameters (e.g., ResNet-50 at roughly 25M parameters)
- When: Weights, activations, optimizer states, and gradients all fit in GPU memory
- Goal: Scale batch size and speed across multiple GPUs
- Best for: Most standard deep learning models
**FSDP (Fully Sharded Data Parallel)**
- Model size: 500M+ parameters (e.g., large transformers like BERT-Large, GPT)
- When: Model doesn't fit in single GPU memory
- Recommended for: Users new to model parallelism or migrating from DDP
- Features: Activation checkpointing, CPU parameter offloading
**DeepSpeed**
- Model size: 500M+ parameters
- When: Need cutting-edge features or already familiar with DeepSpeed
- Features: CPU/disk parameter offloading, distributed checkpoints, fine-grained control
- Trade-off: More complex configuration
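A rough translation of this guidance into Trainer arguments; the parameter thresholds are heuristics taken from the lists above and the best practices later in this guide, not hard limits:
```python
import lightning as L

def make_trainer(num_params: int) -> L.Trainer:
    if num_params < 500_000_000:
        strategy = "ddp"                # everything fits on one GPU; scale batch size and speed
    elif num_params < 10_000_000_000:
        strategy = "fsdp"               # shard weights, gradients, and optimizer states
    else:
        strategy = "deepspeed_stage_3"  # requires `pip install deepspeed`
    return L.Trainer(strategy=strategy, accelerator="gpu", devices=4, precision="bf16-mixed")
```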
## DDP (Distributed Data Parallel)
### Basic Usage
```python
# Single GPU
trainer = L.Trainer(accelerator="gpu", devices=1)
# Multi-GPU on single node (automatic DDP)
trainer = L.Trainer(accelerator="gpu", devices=4)
# Explicit DDP strategy
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)
```
### Multi-Node DDP
```python
# On each node, run:
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4, # GPUs per node
num_nodes=4 # Total nodes
)
```
### DDP Configuration
```python
from lightning.pytorch.strategies import DDPStrategy
trainer = L.Trainer(
strategy=DDPStrategy(
process_group_backend="nccl", # "nccl" for GPU, "gloo" for CPU
find_unused_parameters=False, # Set True if model has unused parameters
gradient_as_bucket_view=True # More memory efficient
),
accelerator="gpu",
devices=4
)
```
### DDP Spawn
Use when `ddp` causes issues (slower but more compatible):
```python
trainer = L.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
```
### Best Practices for DDP
1. **Batch size:** Multiply by number of GPUs
```python
# If using 4 GPUs, effective batch size = batch_size * 4
dm = MyDataModule(batch_size=32) # 32 * 4 = 128 effective batch size
```
2. **Learning rate:** Often scaled with batch size
```python
# Linear scaling rule
base_lr = 0.001
num_gpus = 4
lr = base_lr * num_gpus
```
3. **Synchronization:** Use `sync_dist=True` for metrics
```python
self.log("val_loss", loss, sync_dist=True)
```
4. **Rank-specific operations:** Use decorators for main process only
```python
from lightning.pytorch.utilities import rank_zero_only
@rank_zero_only
def save_results(self):
# Only runs on main process (rank 0)
torch.save(self.results, "results.pt")
```
## FSDP (Fully Sharded Data Parallel)
### Basic Usage
```python
trainer = L.Trainer(
strategy="fsdp",
accelerator="gpu",
devices=4
)
```
### FSDP Configuration
```python
from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn
trainer = L.Trainer(
strategy=FSDPStrategy(
# Sharding strategy
sharding_strategy="FULL_SHARD", # or "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"
# Activation checkpointing (save memory)
activation_checkpointing_policy={nn.TransformerEncoderLayer},
# CPU offloading (save GPU memory, slower)
cpu_offload=False,
        # Mixed precision is driven by the Trainer's `precision` argument (set below)
# Wrap policy (auto-wrap layers)
auto_wrap_policy=None
),
accelerator="gpu",
devices=8,
precision="bf16-mixed"
)
```
### Sharding Strategies
**FULL_SHARD (default)**
- Shards optimizer states, gradients, and parameters
- Maximum memory savings
- More communication overhead
**SHARD_GRAD_OP**
- Shards optimizer states and gradients only
- Parameters kept on all devices
- Less memory savings but faster
**NO_SHARD**
- No sharding (equivalent to DDP)
- For comparison or when sharding not needed
**HYBRID_SHARD**
- Combines FULL_SHARD within nodes and NO_SHARD across nodes
- Good for multi-node setups (see the sketch below)
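Selecting a strategy is a single constructor argument. A sketch of HYBRID_SHARD for a two-node job; the device counts are illustrative:
```python
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    strategy=FSDPStrategy(sharding_strategy="HYBRID_SHARD"),  # full shard within a node, replicate across nodes
    accelerator="gpu",
    devices=8,
    num_nodes=2,
    precision="bf16-mixed"
)
```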
### Activation Checkpointing
Trade computation for memory:
```python
from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn
# Checkpoint specific layer types
trainer = L.Trainer(
strategy=FSDPStrategy(
activation_checkpointing_policy={
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer
}
)
)
```
### CPU Offloading
Offload parameters to CPU when not in use:
```python
trainer = L.Trainer(
strategy=FSDPStrategy(
cpu_offload=True # Slower but saves GPU memory
),
accelerator="gpu",
devices=4
)
```
### FSDP with Large Models
```python
from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn
class LargeTransformer(L.LightningModule):
def __init__(self):
super().__init__()
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=4096, nhead=32),
num_layers=48
)
    def configure_model(self):
        # Optional hook in recent Lightning versions: build very large submodules here
        # so they are created in sharded form instead of materializing fully on each rank
        pass
# Train
trainer = L.Trainer(
strategy=FSDPStrategy(
activation_checkpointing_policy={nn.TransformerEncoderLayer},
cpu_offload=False,
sharding_strategy="FULL_SHARD"
),
accelerator="gpu",
devices=8,
precision="bf16-mixed",
max_epochs=10
)
model = LargeTransformer()
trainer.fit(model, datamodule=dm)
```
## DeepSpeed
### Installation
```bash
pip install deepspeed
```
### Basic Usage
```python
trainer = L.Trainer(
strategy="deepspeed_stage_2", # or "deepspeed_stage_3"
accelerator="gpu",
devices=4,
precision="16-mixed"
)
```
### DeepSpeed Stages
**Stage 1: Optimizer State Sharding**
- Shards optimizer states
- Moderate memory savings
```python
trainer = L.Trainer(strategy="deepspeed_stage_1")
```
**Stage 2: Optimizer + Gradient Sharding**
- Shards optimizer states and gradients
- Good memory savings
```python
trainer = L.Trainer(strategy="deepspeed_stage_2")
```
**Stage 3: Full Model Sharding (ZeRO-3)**
- Shards optimizer states, gradients, and model parameters
- Maximum memory savings
- Can train very large models
```python
trainer = L.Trainer(strategy="deepspeed_stage_3")
```
**Stage 2/3 with Offloading**
- Offload optimizer states and parameters to CPU or NVMe
```python
trainer = L.Trainer(strategy="deepspeed_stage_2_offload")
trainer = L.Trainer(strategy="deepspeed_stage_3_offload")
```
### DeepSpeed Configuration File
For fine-grained control:
```python
from lightning.pytorch.strategies import DeepSpeedStrategy
# Create config file: ds_config.json
config = {
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True
},
"offload_param": {
"device": "cpu",
"pin_memory": True
},
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_clipping": 1.0,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}
trainer = L.Trainer(
strategy=DeepSpeedStrategy(config=config),
accelerator="gpu",
devices=8,
precision="16-mixed"
)
```
### DeepSpeed Best Practices
1. **Use Stage 2 for models <10B parameters**
2. **Use Stage 3 for models >10B parameters**
3. **Enable offloading if GPU memory is insufficient**
4. **Tune `reduce_bucket_size` for communication efficiency** (see the sketch below)
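A sketch combining these practices for a very large model; the offload flags and bucket size below are illustrative starting points, not tuned values:
```python
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,        # move optimizer states to CPU when GPU memory is tight
    offload_parameters=True,       # likewise for parameters (Stage 3 only)
    reduce_bucket_size=500_000_000 # larger buckets mean fewer, larger reduction calls
)
trainer = L.Trainer(strategy=strategy, accelerator="gpu", devices=8, precision="16-mixed")
```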
## Comparison Table
| Feature | DDP | FSDP | DeepSpeed |
|---------|-----|------|-----------|
| Model Size | <500M params | 500M+ params | 500M+ params |
| Memory Efficiency | Low | High | Very High |
| Speed | Fastest | Fast | Fast |
| Setup Complexity | Simple | Medium | Complex |
| Offloading | No | CPU | CPU + Disk |
| Best For | Standard models | Large models | Very large models |
| Configuration | Minimal | Moderate | Extensive |
## Mixed Precision Training
Use mixed precision to speed up training and save memory:
```python
# FP16 mixed precision
trainer = L.Trainer(precision="16-mixed")
# BFloat16 mixed precision (A100, H100)
trainer = L.Trainer(precision="bf16-mixed")
# Full precision (default)
trainer = L.Trainer(precision="32-true")
# Double precision
trainer = L.Trainer(precision="64-true")
```
### Mixed Precision with Different Strategies
```python
# DDP + FP16
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
precision="16-mixed"
)
# FSDP + BFloat16
trainer = L.Trainer(
strategy="fsdp",
accelerator="gpu",
devices=8,
precision="bf16-mixed"
)
# DeepSpeed + FP16
trainer = L.Trainer(
strategy="deepspeed_stage_2",
accelerator="gpu",
devices=4,
precision="16-mixed"
)
```
## Multi-Node Training
### SLURM
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=4
#SBATCH --time=24:00:00
srun python train.py
```
```python
# train.py
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
num_nodes=4
)
```
### Manual Multi-Node Setup
Node 0 (master):
```bash
python train.py --num_nodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=12345
```
Node 1:
```bash
python train.py --num_nodes=2 --node_rank=1 --master_addr=192.168.1.1 --master_port=12345
```
```python
# train.py
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--node_rank", type=int, default=0)
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=int, default=12345)
args = parser.parse_args()

# Expose the rendezvous settings so Lightning's cluster environment can pick them up
os.environ["MASTER_ADDR"] = args.master_addr
os.environ["MASTER_PORT"] = str(args.master_port)
os.environ["NODE_RANK"] = str(args.node_rank)
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
num_nodes=args.num_nodes
)
```
## Common Patterns
### Gradient Accumulation with DDP
```python
# Simulate larger batch size
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
accumulate_grad_batches=4 # Effective batch size = batch_size * devices * 4
)
```
### Model Checkpoint with Distributed Training
```python
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
save_top_k=3,
mode="min"
)
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
callbacks=[checkpoint_callback]
)
```
### Reproducibility in Distributed Training
```python
import lightning as L
L.seed_everything(42, workers=True)
trainer = L.Trainer(
strategy="ddp",
accelerator="gpu",
devices=4,
deterministic=True
)
```
## Troubleshooting
### NCCL Timeout
Increase timeout for slow networks:
```python
from datetime import timedelta
from lightning.pytorch.strategies import DDPStrategy

# Raise the collective-operation timeout passed to init_process_group
trainer = L.Trainer(
    strategy=DDPStrategy(timeout=timedelta(hours=1)),
    accelerator="gpu",
    devices=4
)
```
### CUDA Out of Memory
Solutions:
1. Enable gradient checkpointing
2. Reduce batch size
3. Use FSDP or DeepSpeed
4. Enable CPU offloading
5. Use mixed precision
```python
# Option 1: Gradient checkpointing
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyTransformer()
self.model.gradient_checkpointing_enable()
# Option 2: Smaller batch size
dm = MyDataModule(batch_size=16) # Reduce from 32
# Option 3: FSDP with offloading
trainer = L.Trainer(
strategy=FSDPStrategy(cpu_offload=True),
precision="bf16-mixed"
)
# Option 4: Gradient accumulation
trainer = L.Trainer(accumulate_grad_batches=4)
```
### Distributed Sampler Issues
Lightning handles DistributedSampler automatically:
```python
# Don't do this
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(dataset) # Lightning does this automatically
# Just use shuffle
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
```
### Communication Overhead
Set `find_unused_parameters=False` to avoid the extra overhead of searching for unused parameters on every backward pass:
```python
trainer = L.Trainer(
strategy=DDPStrategy(find_unused_parameters=False),
accelerator="gpu",
devices=4
)
```
## Best Practices
### 1. Start with Single GPU
Test your code on single GPU before scaling:
```python
# Debug on single GPU
trainer = L.Trainer(accelerator="gpu", devices=1, fast_dev_run=True)
# Then scale to multiple GPUs
trainer = L.Trainer(accelerator="gpu", devices=4, strategy="ddp")
```
### 2. Use Appropriate Strategy
- <500M params: Use DDP
- 500M-10B params: Use FSDP
- >10B params: Use DeepSpeed Stage 3
### 3. Enable Mixed Precision
Always use mixed precision for modern GPUs:
```python
trainer = L.Trainer(precision="bf16-mixed") # A100, H100
trainer = L.Trainer(precision="16-mixed") # V100, T4
```
### 4. Scale Hyperparameters
Adjust learning rate and batch size when scaling:
```python
# Linear scaling rule
lr = base_lr * num_gpus
```
### 5. Sync Metrics
Always sync metrics in distributed training:
```python
self.log("val_loss", loss, sync_dist=True)
```
### 6. Use Rank-Zero Operations
File I/O and expensive operations on main process only:
```python
from lightning.pytorch.utilities import rank_zero_only
@rank_zero_only
def save_predictions(self):
torch.save(self.predictions, "predictions.pt")
```
### 7. Checkpoint Regularly
Save checkpoints to resume from failures:
```python
checkpoint_callback = ModelCheckpoint(
save_top_k=3,
save_last=True, # Always save last for resuming
every_n_epochs=5
)
```

# LightningModule - Comprehensive Guide
## Overview
A `LightningModule` organizes PyTorch code into six logical sections without abstraction. The code remains pure PyTorch, just better organized. The Trainer handles device management, distributed sampling, and infrastructure while preserving full model control.
## Core Structure
```python
import lightning as L
import torch
import torch.nn.functional as F
class MyModel(L.LightningModule):
def __init__(self, learning_rate=0.001):
super().__init__()
self.save_hyperparameters() # Save init arguments
self.model = YourNeuralNetwork()
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.model(x)
loss = F.cross_entropy(logits, y)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.model(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(dim=1) == y).float().mean()
self.log("val_loss", loss)
self.log("val_acc", acc)
def test_step(self, batch, batch_idx):
x, y = batch
logits = self.model(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(dim=1) == y).float().mean()
self.log("test_loss", loss)
self.log("test_acc", acc)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss"
}
}
```
## Essential Methods
### Training Pipeline Methods
#### `training_step(batch, batch_idx)`
Computes the forward pass and returns the loss. Lightning automatically handles backward propagation and optimizer updates in automatic optimization mode.
**Parameters:**
- `batch` - Current training batch from the DataLoader
- `batch_idx` - Index of the current batch
**Returns:** Loss tensor (scalar) or dictionary with 'loss' key
**Example:**
```python
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.mse_loss(y_hat, y)
# Log training metrics
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log("learning_rate", self.optimizers().param_groups[0]['lr'])
return loss
```
#### `validation_step(batch, batch_idx)`
Evaluates the model on validation data. Runs with gradients disabled and model in eval mode automatically.
**Parameters:**
- `batch` - Current validation batch
- `batch_idx` - Index of the current batch
**Returns:** Optional - Loss or dictionary of metrics
**Example:**
```python
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.mse_loss(y_hat, y)
# Lightning aggregates across validation batches automatically
self.log("val_loss", loss, prog_bar=True)
return loss
```
#### `test_step(batch, batch_idx)`
Evaluates the model on test data. Runs only when explicitly invoked with `trainer.test()`. Use it after training is complete, typically right before publication.
**Parameters:**
- `batch` - Current test batch
- `batch_idx` - Index of the current batch
**Returns:** Optional - Loss or dictionary of metrics
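**Example** (following the same pattern as the other step methods):
```python
def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.mse_loss(y_hat, y)
    # Aggregated across test batches and reported by trainer.test()
    self.log("test_loss", loss)
```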
#### `predict_step(batch, batch_idx, dataloader_idx=0)`
Runs inference on data. Called when using `trainer.predict()`.
**Parameters:**
- `batch` - Current batch
- `batch_idx` - Index of the current batch
- `dataloader_idx` - Index of dataloader (if multiple)
**Returns:** Predictions (any format you need)
**Example:**
```python
def predict_step(self, batch, batch_idx):
x, y = batch
return self.model(x) # Return raw predictions
```
### Configuration Methods
#### `configure_optimizers()`
Returns optimizer(s) and optional learning rate scheduler(s).
**Return formats:**
1. **Single optimizer:**
```python
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
```
2. **Optimizer + scheduler:**
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
return [optimizer], [scheduler]
```
3. **Advanced configuration with scheduler monitoring:**
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss", # Metric to monitor
"interval": "epoch", # When to update (epoch/step)
"frequency": 1, # How often to update
"strict": True # Crash if monitored metric not found
}
}
```
4. **Multiple optimizers (for GANs, etc.):**
```python
def configure_optimizers(self):
opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
return [opt_g, opt_d]
```
#### `forward(*args, **kwargs)`
Standard PyTorch forward method. Use for inference or as part of training_step.
**Example:**
```python
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x) # Uses forward()
return F.mse_loss(y_hat, y)
```
### Logging and Metrics
#### `log(name, value, **kwargs)`
Records metrics with automatic epoch-level reduction across devices.
**Key parameters:**
- `name` - Metric name (string)
- `value` - Metric value (tensor or number)
- `on_step` - Log at current step (default: True in training_step, False otherwise)
- `on_epoch` - Log at epoch end (default: False in training_step, True otherwise)
- `prog_bar` - Display in progress bar (default: False)
- `logger` - Send to logger backends (default: True)
- `reduce_fx` - Reduction function: "mean", "sum", "max", "min" (default: "mean")
- `sync_dist` - Synchronize across devices in distributed training (default: False)
**Examples:**
```python
# Simple logging
self.log("train_loss", loss)
# Display in progress bar
self.log("accuracy", acc, prog_bar=True)
# Log per-step and per-epoch
self.log("loss", loss, on_step=True, on_epoch=True)
# Custom reduction for distributed training
self.log("batch_size", batch.size(0), reduce_fx="sum", sync_dist=True)
```
#### `log_dict(dictionary, **kwargs)`
Log multiple metrics simultaneously.
**Example:**
```python
metrics = {"train_loss": loss, "train_acc": acc, "learning_rate": lr}
self.log_dict(metrics, on_step=True, on_epoch=True)
```
#### `save_hyperparameters(*args, **kwargs)`
Stores initialization arguments for reproducibility and checkpoint restoration. Call in `__init__()`.
**Example:**
```python
def __init__(self, learning_rate, hidden_dim, dropout):
super().__init__()
self.save_hyperparameters() # Saves all init args
# Access via self.hparams.learning_rate, self.hparams.hidden_dim, etc.
```
## Key Properties
| Property | Description |
|----------|-------------|
| `self.current_epoch` | Current epoch number (0-indexed) |
| `self.global_step` | Total optimizer steps across all epochs |
| `self.device` | Current device (cuda:0, cpu, etc.) |
| `self.global_rank` | Process rank in distributed training (0 for main) |
| `self.local_rank` | GPU rank on current node |
| `self.hparams` | Saved hyperparameters (via save_hyperparameters) |
| `self.trainer` | Reference to parent Trainer instance |
| `self.automatic_optimization` | Whether to use automatic optimization (default: True) |
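These properties are handy for conditional logic inside steps, for example rank-aware printing or warmup schedules (a sketch; `compute_loss` is a hypothetical helper):
```python
def training_step(self, batch, batch_idx):
    # Print a status line once per epoch, from the main process only
    if self.global_rank == 0 and batch_idx == 0:
        print(f"epoch {self.current_epoch}, step {self.global_step}, device {self.device}")
    return self.compute_loss(batch)  # hypothetical helper
```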
## Manual Optimization
For advanced use cases (GANs, reinforcement learning, multiple optimizers), disable automatic optimization:
```python
class GANModel(L.LightningModule):
def __init__(self):
super().__init__()
self.automatic_optimization = False
self.generator = Generator()
self.discriminator = Discriminator()
def training_step(self, batch, batch_idx):
opt_g, opt_d = self.optimizers()
# Train generator
opt_g.zero_grad()
g_loss = self.compute_generator_loss(batch)
self.manual_backward(g_loss)
opt_g.step()
# Train discriminator
opt_d.zero_grad()
d_loss = self.compute_discriminator_loss(batch)
self.manual_backward(d_loss)
opt_d.step()
self.log_dict({"g_loss": g_loss, "d_loss": d_loss})
def configure_optimizers(self):
opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
return [opt_g, opt_d]
```
## Important Lifecycle Hooks
### Setup and Teardown
#### `setup(stage)`
Called at the beginning of fit, validate, test, or predict. Useful for stage-specific setup.
**Parameters:**
- `stage` - 'fit', 'validate', 'test', or 'predict'
**Example:**
```python
def setup(self, stage):
if stage == 'fit':
# Setup training-specific components
self.train_dataset = load_train_data()
elif stage == 'test':
# Setup test-specific components
self.test_dataset = load_test_data()
```
#### `teardown(stage)`
Called at the end of fit, validate, test, or predict. Cleanup resources.
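**Example** (a minimal sketch; the attribute name is illustrative):
```python
def teardown(self, stage):
    # Release references created in setup() so memory is freed between stages
    if stage == 'fit':
        self.train_dataset = None
```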
### Epoch Boundaries
#### `on_train_epoch_start()` / `on_train_epoch_end()`
Called at the beginning/end of each training epoch.
**Example:**
```python
def on_train_epoch_end(self):
    # Compute epoch-level metrics; self.training_step_outputs is a list
    # you populate in training_step
    all_preds = torch.cat(self.training_step_outputs)
epoch_metric = compute_custom_metric(all_preds)
self.log("epoch_metric", epoch_metric)
self.training_step_outputs.clear() # Free memory
```
#### `on_validation_epoch_start()` / `on_validation_epoch_end()`
Called at the beginning/end of validation epoch.
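A common pattern is to collect per-batch outputs in `validation_step` and reduce them here (a sketch; the list attribute is your own, not a Lightning built-in):
```python
def __init__(self):
    super().__init__()
    self.validation_step_outputs = []

def validation_step(self, batch, batch_idx):
    x, y = batch
    self.validation_step_outputs.append(self.model(x))

def on_validation_epoch_end(self):
    all_preds = torch.cat(self.validation_step_outputs)
    # Compute epoch-level metrics from all_preds here
    self.validation_step_outputs.clear()  # Free memory
```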
#### `on_test_epoch_start()` / `on_test_epoch_end()`
Called at the beginning/end of test epoch.
### Gradient Hooks
#### `on_before_backward(loss)`
Called before loss.backward().
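**Example** (a minimal sketch):
```python
def on_before_backward(self, loss):
    # Inspect the loss right before the backward pass
    self.log("loss_before_backward", loss)
```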
#### `on_after_backward()`
Called after loss.backward() but before optimizer step.
**Example - Gradient inspection:**
```python
def on_after_backward(self):
    # Compute the total gradient norm without modifying gradients
    grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=float("inf"))
self.log("grad_norm", grad_norm)
```
### Checkpoint Hooks
#### `on_save_checkpoint(checkpoint)`
Customize checkpoint saving. Add extra state to save.
**Example:**
```python
def on_save_checkpoint(self, checkpoint):
checkpoint['custom_state'] = self.custom_data
```
#### `on_load_checkpoint(checkpoint)`
Customize checkpoint loading. Restore extra state.
**Example:**
```python
def on_load_checkpoint(self, checkpoint):
self.custom_data = checkpoint.get('custom_state', default_value)
```
## Best Practices
### 1. Device Agnosticism
Never use explicit `.cuda()` or `.cpu()` calls. Lightning handles device placement automatically.
**Bad:**
```python
x = x.cuda()
model = model.cuda()
```
**Good:**
```python
x = x.to(self.device) # Inside LightningModule
# Or let Lightning handle it automatically
```
### 2. Distributed Training Safety
Don't manually create `DistributedSampler`. Lightning handles this automatically.
**Bad:**
```python
sampler = DistributedSampler(dataset)
DataLoader(dataset, sampler=sampler)
```
**Good:**
```python
DataLoader(dataset, shuffle=True) # Lightning converts to DistributedSampler
```
### 3. Metric Aggregation
Use `self.log()` for automatic cross-device reduction rather than manual collection.
**Bad:**
```python
self.validation_outputs.append(loss)
def on_validation_epoch_end(self):
avg_loss = torch.stack(self.validation_outputs).mean()
```
**Good:**
```python
self.log("val_loss", loss) # Automatic aggregation
```
### 4. Hyperparameter Tracking
Always use `self.save_hyperparameters()` for easy model reloading.
**Example:**
```python
def __init__(self, learning_rate, hidden_dim):
super().__init__()
self.save_hyperparameters()
# Later: Load from checkpoint
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
print(model.hparams.learning_rate)
```
### 5. Validation Placement
Run validation on a single device to ensure each sample is evaluated exactly once. Lightning handles this automatically with proper strategy configuration.
## Loading from Checkpoint
```python
# Load model with saved hyperparameters
model = MyModel.load_from_checkpoint("path/to/checkpoint.ckpt")
# Override hyperparameters if needed
model = MyModel.load_from_checkpoint(
"path/to/checkpoint.ckpt",
learning_rate=0.0001 # Override saved value
)
# Use for inference
model.eval()
predictions = model(data)
```
## Common Patterns
### Gradient Accumulation
Let Lightning handle gradient accumulation:
```python
trainer = L.Trainer(accumulate_grad_batches=4)
```
### Gradient Clipping
Configure in Trainer:
```python
trainer = L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
```
### Mixed Precision Training
Configure precision in Trainer:
```python
trainer = L.Trainer(precision="16-mixed") # or "bf16-mixed", "32-true"
```
### Learning Rate Warmup
Implement in configure_optimizers:
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=0.01,
total_steps=self.trainer.estimated_stepping_batches
),
"interval": "step"
}
return [optimizer], [scheduler]
```
# Logging - Comprehensive Guide
## Overview
PyTorch Lightning supports multiple logging integrations for experiment tracking and visualization. By default, Lightning uses TensorBoard, but you can easily switch to or combine multiple loggers.
## Supported Loggers
### TensorBoardLogger (Default)
Logs to local or remote file system in TensorBoard format.
**Installation:**
```bash
pip install tensorboard
```
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger(
save_dir="logs/",
name="my_model",
version="version_1",
default_hp_metric=False
)
trainer = L.Trainer(logger=tb_logger)
```
**View logs:**
```bash
tensorboard --logdir logs/
```
### WandbLogger
Weights & Biases integration for cloud-based experiment tracking.
**Installation:**
```bash
pip install wandb
```
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
wandb_logger = pl_loggers.WandbLogger(
project="my-project",
name="experiment-1",
save_dir="logs/",
log_model=True # Log model checkpoints to W&B
)
trainer = L.Trainer(logger=wandb_logger)
```
**Features:**
- Cloud-based experiment tracking
- Model versioning
- Artifact management
- Collaborative features
- Hyperparameter sweeps
### MLFlowLogger
MLflow tracking integration.
**Installation:**
```bash
pip install mlflow
```
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
mlflow_logger = pl_loggers.MLFlowLogger(
experiment_name="my_experiment",
tracking_uri="http://localhost:5000",
run_name="run_1"
)
trainer = L.Trainer(logger=mlflow_logger)
```
### CometLogger
Comet.ml experiment tracking.
**Installation:**
```bash
pip install comet-ml
```
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
comet_logger = pl_loggers.CometLogger(
api_key="YOUR_API_KEY",
project_name="my-project",
experiment_name="experiment-1"
)
trainer = L.Trainer(logger=comet_logger)
```
### NeptuneLogger
Neptune.ai integration.
**Installation:**
```bash
pip install neptune
```
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
neptune_logger = pl_loggers.NeptuneLogger(
api_key="YOUR_API_KEY",
project="username/project-name",
name="experiment-1"
)
trainer = L.Trainer(logger=neptune_logger)
```
### CSVLogger
Logs metrics to the local file system in CSV format and hyperparameters in YAML format.
**Usage:**
```python
from lightning.pytorch import loggers as pl_loggers
csv_logger = pl_loggers.CSVLogger(
save_dir="logs/",
name="my_model",
version="1"
)
trainer = L.Trainer(logger=csv_logger)
```
**Output files:**
- `metrics.csv` - All logged metrics
- `hparams.yaml` - Hyperparameters
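Because the output is plain CSV, you can inspect results after training with pandas (a sketch; assumes pandas is installed, and the exact path depends on the `save_dir`, `name`, and `version` you configured):
```python
import pandas as pd

# Path follows the save_dir/name/version layout of CSVLogger
metrics = pd.read_csv("logs/my_model/1/metrics.csv")
print(metrics[["epoch", "val_loss"]].dropna().tail())
```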
## Logging Metrics
### Basic Logging
Use `self.log()` within your LightningModule:
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# Log metric
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
# Log multiple metrics
self.log("val_loss", loss)
self.log("val_acc", acc)
```
### Logging Parameters
#### `on_step` (bool)
Log at current step. Default: True in training_step, False otherwise.
```python
self.log("loss", loss, on_step=True)
```
#### `on_epoch` (bool)
Accumulate and log at epoch end. Default: False in training_step, True otherwise.
```python
self.log("loss", loss, on_epoch=True)
```
#### `prog_bar` (bool)
Display in progress bar. Default: False.
```python
self.log("train_loss", loss, prog_bar=True)
```
#### `logger` (bool)
Send to logger backends. Default: True.
```python
self.log("internal_metric", value, logger=False) # Don't log to external logger
```
#### `reduce_fx` (str or callable)
Reduction function: "mean", "sum", "max", "min". Default: "mean".
```python
self.log("batch_size", batch.size(0), reduce_fx="sum")
```
#### `sync_dist` (bool)
Synchronize metric across devices in distributed training. Default: False.
```python
self.log("loss", loss, sync_dist=True)
```
#### `rank_zero_only` (bool)
Only log from rank 0 process. Default: False.
```python
self.log("debug_metric", value, rank_zero_only=True)
```
### Complete Example
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Log per-step and per-epoch, display in progress bar
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
acc = self.compute_accuracy(batch)
# Log epoch-level metrics
self.log("val_loss", loss, on_epoch=True)
self.log("val_acc", acc, on_epoch=True, prog_bar=True)
```
### Logging Multiple Metrics
Use `log_dict()` to log multiple metrics at once:
```python
def training_step(self, batch, batch_idx):
loss, acc, f1 = self.compute_metrics(batch)
metrics = {
"train_loss": loss,
"train_acc": acc,
"train_f1": f1
}
self.log_dict(metrics, on_step=True, on_epoch=True)
return loss
```
## Logging Hyperparameters
### Automatic Hyperparameter Logging
Use `save_hyperparameters()` in your model:
```python
class MyModel(L.LightningModule):
def __init__(self, learning_rate, hidden_dim, dropout):
super().__init__()
# Automatically save and log hyperparameters
self.save_hyperparameters()
```
### Manual Hyperparameter Logging
```python
# In LightningModule
class MyModel(L.LightningModule):
def __init__(self, learning_rate):
super().__init__()
self.save_hyperparameters()
# Or manually with logger
trainer.logger.log_hyperparams({
"learning_rate": 0.001,
"batch_size": 32
})
```
## Logging Frequency
By default, Lightning logs every 50 training steps. Adjust with `log_every_n_steps`:
```python
trainer = L.Trainer(log_every_n_steps=10)
```
## Multiple Loggers
Use multiple loggers simultaneously:
```python
from lightning.pytorch import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger("logs/")
wandb_logger = pl_loggers.WandbLogger(project="my-project")
csv_logger = pl_loggers.CSVLogger("logs/")
trainer = L.Trainer(logger=[tb_logger, wandb_logger, csv_logger])
```
## Advanced Logging
### Logging Images
```python
import torchvision
from lightning.pytorch import loggers as pl_loggers
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
# Log first batch of images once per epoch
if batch_idx == 0:
# Create image grid
grid = torchvision.utils.make_grid(x[:8])
# Log to TensorBoard
self.logger.experiment.add_image("val_images", grid, self.current_epoch)
# Log to Wandb
if isinstance(self.logger, pl_loggers.WandbLogger):
import wandb
self.logger.experiment.log({
"val_images": [wandb.Image(img) for img in x[:8]]
})
```
### Logging Histograms
```python
def on_train_epoch_end(self):
# Log parameter histograms
for name, param in self.named_parameters():
self.logger.experiment.add_histogram(name, param, self.current_epoch)
if param.grad is not None:
self.logger.experiment.add_histogram(
f"{name}_grad", param.grad, self.current_epoch
)
```
### Logging Model Graph
```python
def on_train_start(self):
# Log model architecture
sample_input = torch.randn(1, 3, 224, 224).to(self.device)
self.logger.experiment.add_graph(self.model, sample_input)
```
### Logging Custom Plots
```python
import matplotlib.pyplot as plt
def on_validation_epoch_end(self):
# Create custom plot
fig, ax = plt.subplots()
ax.plot(self.validation_losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
# Log to TensorBoard
self.logger.experiment.add_figure("loss_curve", fig, self.current_epoch)
plt.close(fig)
```
### Logging Text
```python
def validation_step(self, batch, batch_idx):
# Generate predictions
predictions = self.generate_text(batch)
# Log to TensorBoard
self.logger.experiment.add_text(
"predictions",
f"Batch {batch_idx}: {predictions}",
self.current_epoch
)
```
### Logging Audio
```python
def validation_step(self, batch, batch_idx):
audio = self.generate_audio(batch)
# Log to TensorBoard (audio is tensor of shape [1, samples])
self.logger.experiment.add_audio(
"generated_audio",
audio,
self.current_epoch,
sample_rate=22050
)
```
## Accessing Logger in LightningModule
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# Access logger experiment object
logger = self.logger.experiment
# For TensorBoard
if isinstance(self.logger, pl_loggers.TensorBoardLogger):
logger.add_scalar("custom_metric", value, self.global_step)
# For Wandb
if isinstance(self.logger, pl_loggers.WandbLogger):
logger.log({"custom_metric": value})
# For MLflow
if isinstance(self.logger, pl_loggers.MLFlowLogger):
logger.log_metric("custom_metric", value)
```
## Custom Logger
Create a custom logger by inheriting from `Logger`:
```python
from lightning.pytorch.loggers import Logger
from lightning.pytorch.utilities import rank_zero_only
class MyCustomLogger(Logger):
def __init__(self, save_dir):
super().__init__()
self.save_dir = save_dir
self._name = "my_logger"
self._version = "0.1"
@property
def name(self):
return self._name
@property
def version(self):
return self._version
@rank_zero_only
def log_metrics(self, metrics, step):
# Log metrics to your backend
print(f"Step {step}: {metrics}")
@rank_zero_only
def log_hyperparams(self, params):
# Log hyperparameters
print(f"Hyperparameters: {params}")
@rank_zero_only
def save(self):
# Save logger state
pass
@rank_zero_only
def finalize(self, status):
# Cleanup when training ends
pass
# Usage
custom_logger = MyCustomLogger(save_dir="logs/")
trainer = L.Trainer(logger=custom_logger)
```
## Best Practices
### 1. Log Both Step and Epoch Metrics
```python
# Good: Track both granular and aggregate metrics
self.log("train_loss", loss, on_step=True, on_epoch=True)
```
### 2. Use Progress Bar for Key Metrics
```python
# Show important metrics in progress bar
self.log("val_acc", acc, prog_bar=True)
```
### 3. Synchronize Metrics in Distributed Training
```python
# Ensure correct aggregation across GPUs
self.log("val_loss", loss, sync_dist=True)
```
### 4. Log Learning Rate
```python
from lightning.pytorch.callbacks import LearningRateMonitor
trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])
```
### 5. Log Gradient Norms
```python
def on_after_backward(self):
# Monitor gradient flow
grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=float('inf'))
self.log("grad_norm", grad_norm)
```
### 6. Use Descriptive Metric Names
```python
# Good: Clear naming convention
self.log("train/loss", loss)
self.log("train/accuracy", acc)
self.log("val/loss", val_loss)
self.log("val/accuracy", val_acc)
```
### 7. Log Hyperparameters
```python
# Always save hyperparameters for reproducibility
class MyModel(L.LightningModule):
def __init__(self, **kwargs):
super().__init__()
self.save_hyperparameters()
```
### 8. Don't Log Too Frequently
```python
# Avoid logging every step for expensive operations
if batch_idx % 100 == 0:
self.log_images(batch)
```
## Common Patterns
### Structured Logging
```python
def training_step(self, batch, batch_idx):
loss, metrics = self.compute_loss_and_metrics(batch)
# Organize logs with prefixes
self.log("train/loss", loss)
self.log_dict({f"train/{k}": v for k, v in metrics.items()})
return loss
def validation_step(self, batch, batch_idx):
loss, metrics = self.compute_loss_and_metrics(batch)
self.log("val/loss", loss)
self.log_dict({f"val/{k}": v for k, v in metrics.items()})
```
### Conditional Logging
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Log expensive metrics less frequently
if self.global_step % 100 == 0:
expensive_metric = self.compute_expensive_metric(batch)
self.log("expensive_metric", expensive_metric)
self.log("train_loss", loss)
return loss
```
### Multi-Task Logging
```python
def training_step(self, batch, batch_idx):
x, y_task1, y_task2 = batch
loss_task1 = self.compute_task1_loss(x, y_task1)
loss_task2 = self.compute_task2_loss(x, y_task2)
total_loss = loss_task1 + loss_task2
# Log per-task metrics
self.log_dict({
"train/loss_task1": loss_task1,
"train/loss_task2": loss_task2,
"train/loss_total": total_loss
})
return total_loss
```
## Troubleshooting
### Metric Not Found Error
If you get "metric not found" errors with schedulers:
```python
# Make sure metric is logged with logger=True
self.log("val_loss", loss, logger=True)
# And configure scheduler to monitor it
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss" # Must match logged metric name
}
}
```
### Metrics Not Syncing in Distributed Training
```python
# Enable sync_dist for proper aggregation
self.log("val_acc", acc, sync_dist=True)
```
### Logger Not Saving
```python
# Ensure logger has write permissions
trainer = L.Trainer(
logger=pl_loggers.TensorBoardLogger("logs/"),
default_root_dir="outputs/" # Ensure directory exists and is writable
)
```
# Trainer - Comprehensive Guide
## Overview
The Trainer automates the training workflow once your PyTorch code is organized into a LightningModule. It handles loop details, device management, callbacks, gradient operations, checkpointing, and distributed training automatically.
## Core Purpose
The Trainer manages:
- Automatically enabling/disabling gradients
- Running training, validation, and test dataloaders
- Calling callbacks at appropriate times
- Placing batches on correct devices
- Orchestrating distributed training
- Progress bars and logging
- Checkpointing and early stopping
## Main Methods
### `fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None)`
Runs the full training routine including optional validation.
**Parameters:**
- `model` - LightningModule to train
- `train_dataloaders` - Training DataLoader(s)
- `val_dataloaders` - Optional validation DataLoader(s)
- `datamodule` - Optional LightningDataModule (replaces dataloaders)
**Examples:**
```python
# With DataLoaders
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# With DataModule
trainer.fit(model, datamodule=dm)
# Continue training from checkpoint
trainer.fit(model, train_loader, ckpt_path="checkpoint.ckpt")
```
### `validate(model=None, dataloaders=None, datamodule=None)`
Run validation loop without training.
**Example:**
```python
trainer = L.Trainer()
trainer.validate(model, val_loader)
```
### `test(model=None, dataloaders=None, datamodule=None)`
Run test loop. Only use before publishing results.
**Example:**
```python
trainer = L.Trainer()
trainer.test(model, test_loader)
```
### `predict(model=None, dataloaders=None, datamodule=None)`
Run inference on data and return predictions.
**Example:**
```python
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
```
## Essential Parameters
### Training Duration
#### `max_epochs` (int)
Maximum number of epochs to train. Default: 1000
```python
trainer = L.Trainer(max_epochs=100)
```
#### `min_epochs` (int)
Minimum number of epochs to train. Default: None
```python
trainer = L.Trainer(min_epochs=10, max_epochs=100)
```
#### `max_steps` (int)
Maximum number of optimizer steps. Training stops when either `max_steps` or `max_epochs` is reached first. Default: -1 (disabled)
```python
trainer = L.Trainer(max_steps=10000)
```
#### `max_time` (str or dict)
Maximum training time. Useful for time-limited clusters.
```python
# String format
trainer = L.Trainer(max_time="00:12:00:00") # 12 hours
# Dictionary format
trainer = L.Trainer(max_time={"days": 1, "hours": 6})
```
### Hardware Configuration
#### `accelerator` (str or Accelerator)
Hardware to use: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto". Default: "auto"
```python
trainer = L.Trainer(accelerator="gpu")
trainer = L.Trainer(accelerator="auto") # Auto-detect available hardware
```
#### `devices` (int, list, or str)
Number or list of device indices to use.
```python
# Use 2 GPUs
trainer = L.Trainer(devices=2, accelerator="gpu")
# Use specific GPUs
trainer = L.Trainer(devices=[0, 2], accelerator="gpu")
# Use all available devices
trainer = L.Trainer(devices="auto", accelerator="gpu")
# CPU with 4 cores
trainer = L.Trainer(devices=4, accelerator="cpu")
```
#### `strategy` (str or Strategy)
Distributed training strategy: "ddp", "ddp_spawn", "fsdp", "deepspeed", etc. Default: "auto"
```python
# Distributed Data Parallel
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)
# Fully Sharded Data Parallel
trainer = L.Trainer(strategy="fsdp", accelerator="gpu", devices=4)
# DeepSpeed
trainer = L.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=4)
```
#### `precision` (str or int)
Floating point precision: "32-true", "16-mixed", "bf16-mixed", "64-true", etc.
```python
# Mixed precision (FP16)
trainer = L.Trainer(precision="16-mixed")
# BFloat16 mixed precision
trainer = L.Trainer(precision="bf16-mixed")
# Full precision
trainer = L.Trainer(precision="32-true")
```
### Optimization Configuration
#### `gradient_clip_val` (float)
Gradient clipping value. Default: None
```python
# Clip gradients by norm
trainer = L.Trainer(gradient_clip_val=0.5)
```
#### `gradient_clip_algorithm` (str)
Gradient clipping algorithm: "norm" or "value". Default: "norm"
```python
trainer = L.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
```
#### `accumulate_grad_batches` (int)
Accumulate gradients over N batches before the optimizer step.
```python
# Accumulate over 4 batches
trainer = L.Trainer(accumulate_grad_batches=4)
# Different accumulation per epoch (via the GradientAccumulationScheduler callback)
from lightning.pytorch.callbacks import GradientAccumulationScheduler
trainer = L.Trainer(callbacks=[GradientAccumulationScheduler(scheduling={0: 4, 5: 2, 10: 1})])
```
### Validation Configuration
#### `check_val_every_n_epoch` (int)
Run validation every N epochs. Default: 1
```python
trainer = L.Trainer(check_val_every_n_epoch=10)
```
#### `val_check_interval` (int or float)
How often to check validation within a training epoch.
```python
# Check validation every 0.25 of training epoch
trainer = L.Trainer(val_check_interval=0.25)
# Check validation every 100 training batches
trainer = L.Trainer(val_check_interval=100)
```
#### `limit_val_batches` (int or float)
Limit validation batches.
```python
# Use only 10% of validation data
trainer = L.Trainer(limit_val_batches=0.1)
# Use only 50 validation batches
trainer = L.Trainer(limit_val_batches=50)
# Disable validation
trainer = L.Trainer(limit_val_batches=0)
```
#### `num_sanity_val_steps` (int)
Number of validation batches to run before training starts. Default: 2
```python
# Skip sanity check
trainer = L.Trainer(num_sanity_val_steps=0)
# Run 5 sanity validation steps
trainer = L.Trainer(num_sanity_val_steps=5)
```
### Logging and Progress
#### `logger` (Logger or list or bool)
Logger(s) to use for experiment tracking.
```python
from lightning.pytorch import loggers as pl_loggers
# TensorBoard logger
tb_logger = pl_loggers.TensorBoardLogger("logs/")
trainer = L.Trainer(logger=tb_logger)
# Multiple loggers
wandb_logger = pl_loggers.WandbLogger(project="my-project")
trainer = L.Trainer(logger=[tb_logger, wandb_logger])
# Disable logging
trainer = L.Trainer(logger=False)
```
#### `log_every_n_steps` (int)
How often to log within training steps. Default: 50
```python
trainer = L.Trainer(log_every_n_steps=10)
```
#### `enable_progress_bar` (bool)
Show progress bar. Default: True
```python
trainer = L.Trainer(enable_progress_bar=False)
```
### Callbacks
#### `callbacks` (list)
List of callbacks to use during training.
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
save_top_k=3,
mode="min"
)
early_stop_callback = EarlyStopping(
monitor="val_loss",
patience=5,
mode="min"
)
trainer = L.Trainer(callbacks=[checkpoint_callback, early_stop_callback])
```
### Checkpointing
#### `default_root_dir` (str)
Default directory for logs and checkpoints. Default: current working directory
```python
trainer = L.Trainer(default_root_dir="./experiments/")
```
#### `enable_checkpointing` (bool)
Enable automatic checkpointing. Default: True
```python
trainer = L.Trainer(enable_checkpointing=True)
```
### Debugging
#### `fast_dev_run` (bool or int)
Run a single batch (or N batches) through train/val/test for debugging.
```python
# Run 1 batch of train/val/test
trainer = L.Trainer(fast_dev_run=True)
# Run 5 batches of train/val/test
trainer = L.Trainer(fast_dev_run=5)
```
#### `limit_train_batches` (int or float)
Limit training batches.
```python
# Use only 25% of training data
trainer = L.Trainer(limit_train_batches=0.25)
# Use only 100 training batches
trainer = L.Trainer(limit_train_batches=100)
```
#### `limit_test_batches` (int or float)
Limit test batches.
```python
trainer = L.Trainer(limit_test_batches=0.5)
```
#### `overfit_batches` (int or float)
Overfit on a subset of data for debugging.
```python
# Overfit on 10 batches
trainer = L.Trainer(overfit_batches=10)
# Overfit on 1% of data
trainer = L.Trainer(overfit_batches=0.01)
```
#### `detect_anomaly` (bool)
Enable PyTorch anomaly detection for debugging NaNs. Default: False
```python
trainer = L.Trainer(detect_anomaly=True)
```
### Reproducibility
#### `deterministic` (bool or str)
Control deterministic behavior. Default: False
```python
import lightning as L
# Seed everything
L.seed_everything(42, workers=True)
# Fully deterministic (may impact performance)
trainer = L.Trainer(deterministic=True)
# Warn if non-deterministic operations detected
trainer = L.Trainer(deterministic="warn")
```
#### `benchmark` (bool)
Enable cudnn benchmarking for performance. Default: False
```python
trainer = L.Trainer(benchmark=True)
```
### Miscellaneous
#### `enable_model_summary` (bool)
Print model summary before training. Default: True
```python
trainer = L.Trainer(enable_model_summary=False)
```
#### `inference_mode` (bool)
Use torch.inference_mode() instead of torch.no_grad() for validation/test. Default: True
```python
trainer = L.Trainer(inference_mode=True)
```
#### `profiler` (str or Profiler)
Profile code for performance optimization. Options: "simple", "advanced", or custom Profiler.
```python
# Simple profiler
trainer = L.Trainer(profiler="simple")
# Advanced profiler
trainer = L.Trainer(profiler="advanced")
```
## Common Configurations
### Basic Training
```python
trainer = L.Trainer(
max_epochs=100,
accelerator="auto",
devices="auto"
)
trainer.fit(model, train_loader, val_loader)
```
### Multi-GPU Training
```python
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=4,
strategy="ddp",
precision="16-mixed"
)
trainer.fit(model, datamodule=dm)
```
### Production Training with Checkpoints
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="{epoch}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_top_k=3,
save_last=True
)
early_stop = EarlyStopping(
monitor="val_loss",
patience=10,
mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=2,
strategy="ddp",
precision="16-mixed",
callbacks=[checkpoint_callback, early_stop, lr_monitor],
log_every_n_steps=10,
gradient_clip_val=1.0
)
trainer.fit(model, datamodule=dm)
```
### Debug Configuration
```python
trainer = L.Trainer(
fast_dev_run=True, # Run 1 batch
accelerator="cpu",
enable_progress_bar=True,
log_every_n_steps=1,
detect_anomaly=True
)
trainer.fit(model, train_loader, val_loader)
```
### Research Configuration (Reproducibility)
```python
import lightning as L
L.seed_everything(42, workers=True)
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=1,
deterministic=True,
benchmark=False,
precision="32-true"
)
trainer.fit(model, datamodule=dm)
```
### Time-Limited Training (Cluster)
```python
trainer = L.Trainer(
max_time={"hours": 23, "minutes": 30}, # SLURM time limit
max_epochs=1000,
callbacks=[ModelCheckpoint(save_last=True)]
)
trainer.fit(model, datamodule=dm)
# Resume from checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```
### Large Model Training (FSDP)
```python
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy
trainer = L.Trainer(
max_epochs=100,
accelerator="gpu",
devices=8,
strategy=FSDPStrategy(
activation_checkpointing_policy={nn.TransformerEncoderLayer},
cpu_offload=False
),
precision="bf16-mixed",
accumulate_grad_batches=4
)
trainer.fit(model, datamodule=dm)
```
## Resuming Training
### From Checkpoint
```python
# Resume from specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")
# Resume from last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
```
### Finding Last Checkpoint
```python
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(save_last=True)
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)
# Get path to last checkpoint
last_checkpoint = checkpoint_callback.last_model_path
```
## Accessing Trainer from LightningModule
Inside a LightningModule, access the Trainer via `self.trainer`:
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# Access trainer properties
current_epoch = self.trainer.current_epoch
global_step = self.trainer.global_step
max_epochs = self.trainer.max_epochs
# Access callbacks
for callback in self.trainer.callbacks:
if isinstance(callback, ModelCheckpoint):
print(f"Best model: {callback.best_model_path}")
# Access logger
self.trainer.logger.log_metrics({"custom": value})
```
## Trainer Attributes
| Attribute | Description |
|-----------|-------------|
| `trainer.current_epoch` | Current epoch (0-indexed) |
| `trainer.global_step` | Total optimizer steps |
| `trainer.max_epochs` | Maximum epochs configured |
| `trainer.max_steps` | Maximum steps configured |
| `trainer.callbacks` | List of callbacks |
| `trainer.logger` | Logger instance |
| `trainer.strategy` | Training strategy |
| `trainer.estimated_stepping_batches` | Estimated total steps for training |
## Best Practices
### 1. Start with Fast Dev Run
Always test with `fast_dev_run=True` before full training:
```python
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
```
### 2. Use Gradient Clipping
Prevent gradient explosions:
```python
trainer = L.Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
```
### 3. Enable Mixed Precision
Speed up training on modern GPUs:
```python
trainer = L.Trainer(precision="16-mixed") # or "bf16-mixed" for A100+
```
### 4. Save Checkpoints Properly
Always save the last checkpoint for resuming:
```python
checkpoint_callback = ModelCheckpoint(
save_top_k=3,
save_last=True,
monitor="val_loss"
)
```
### 5. Monitor Learning Rate
Track LR changes with LearningRateMonitor:
```python
from lightning.pytorch.callbacks import LearningRateMonitor
trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])
```
### 6. Use DataModule for Reproducibility
Encapsulate data logic in a DataModule:
```python
# Better than passing DataLoaders directly
trainer.fit(model, datamodule=dm)
```
### 7. Set Deterministic for Research
Ensure reproducibility for publications:
```python
L.seed_everything(42, workers=True)
trainer = L.Trainer(deterministic=True)
```