# Best Practices - PyTorch Lightning
## Code Organization
### 1. Separate Research from Engineering
**Good:**
```python
class MyModel(L.LightningModule):
    # Research code (what the model does)
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        return loss

# Engineering code (how to train) - in the Trainer
trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=4,
    strategy="ddp"
)
```
**Bad:**
```python
# Mixing research and engineering logic
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        # Don't do device management manually
        loss = loss.cuda()
        # Don't do optimizer steps manually (unless using manual optimization)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss
```
### 2. Use LightningDataModule
**Good:**
```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # Download data once (runs on a single process)
        download_data(self.data_dir)

    def setup(self, stage):
        # Load data per-process
        self.train_dataset = MyDataset(self.data_dir, split='train')
        self.val_dataset = MyDataset(self.data_dir, split='val')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

# Reusable and shareable
dm = MyDataModule("./data", batch_size=32)
trainer.fit(model, datamodule=dm)
```
**Bad:**
```python
# Scattered data logic
train_dataset = load_data()
val_dataset = load_data()
train_loader = DataLoader(train_dataset, ...)
val_loader = DataLoader(val_dataset, ...)
trainer.fit(model, train_loader, val_loader)
```
### 3. Keep Models Modular
```python
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(...)

    def forward(self, x):
        return self.layers(x)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(...)

    def forward(self, x):
        return self.layers(x)

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
```
## Device Agnosticism
### 1. Never Use Explicit CUDA Calls
**Bad:**
```python
x = x.cuda()
model = model.cuda()
torch.cuda.set_device(0)
```
**Good:**
```python
# Inside a LightningModule method
x = x.to(self.device)

# Or let Lightning handle it automatically
def training_step(self, batch, batch_idx):
    x, y = batch  # Already on the correct device
    loss = self.compute_loss(x, y)
    return loss
```
### 2. Use `self.device` Property
```python
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # Create tensors on the correct device
        noise = torch.randn(batch.size(0), 100).to(self.device)
        # Or use type_as
        noise = torch.randn(batch.size(0), 100).type_as(batch)
```
### 3. Register Buffers for Non-Parameters
```python
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Register buffers (automatically moved to the correct device)
        self.register_buffer("running_mean", torch.zeros(100))

    def forward(self, x):
        # self.running_mean is automatically on the correct device
        return x - self.running_mean
```
## Hyperparameter Management
### 1. Always Use `save_hyperparameters()`
**Good:**
```python
class MyModel(L.LightningModule):
    def __init__(self, learning_rate, hidden_dim, dropout):
        super().__init__()
        self.save_hyperparameters()  # Saves all __init__ arguments
        # Access via self.hparams
        self.model = nn.Linear(self.hparams.hidden_dim, 10)

# Load from a checkpoint with the saved hparams
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
print(model.hparams.learning_rate)  # Original value preserved
```
**Bad:**
```python
class MyModel(L.LightningModule):
    def __init__(self, learning_rate, hidden_dim, dropout):
        super().__init__()
        self.learning_rate = learning_rate  # Manual tracking
        self.hidden_dim = hidden_dim
```
### 2. Ignore Specific Arguments
```python
class MyModel(L.LightningModule):
    def __init__(self, lr, model, dataset):
        super().__init__()
        # Don't save 'model' and 'dataset' (not serializable)
        self.save_hyperparameters(ignore=['model', 'dataset'])
        self.model = model
        self.dataset = dataset
```
### 3. Use Hyperparameters in `configure_optimizers()`
```python
def configure_optimizers(self):
    # Use saved hyperparameters
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr=self.hparams.learning_rate,
        weight_decay=self.hparams.weight_decay
    )
    return optimizer
```
## Logging Best Practices
### 1. Log Both Step and Epoch Metrics
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # Log per-step for detailed monitoring
    # Log per-epoch for an aggregated view
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    return loss
```
### 2. Use Structured Logging
```python
def training_step(self, batch, batch_idx):
    # Organize metrics with prefixes
    self.log("train/loss", loss)
    self.log("train/acc", acc)
    self.log("train/f1", f1)

def validation_step(self, batch, batch_idx):
    self.log("val/loss", loss)
    self.log("val/acc", acc)
    self.log("val/f1", f1)
```
### 3. Sync Metrics in Distributed Training
```python
def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # IMPORTANT: sync_dist=True for proper aggregation across GPUs
    self.log("val_loss", loss, sync_dist=True)
```
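If you compute metrics with `torchmetrics`, the metric objects carry their own state and synchronize it across processes, so DDP aggregation is handled for you. A minimal sketch, assuming a hypothetical 10-class classifier whose `forward` returns logits:
```python
import torchmetrics

class LitClassifier(L.LightningModule):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.val_acc.update(logits, y)
        # Logging the metric object defers epoch-level reduction to Lightning
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
```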
### 4. Monitor Learning Rate
```python
from lightning.pytorch.callbacks import LearningRateMonitor
trainer = L.Trainer(
    callbacks=[LearningRateMonitor(logging_interval="step")]
)
```
## Reproducibility
### 1. Seed Everything
```python
import lightning as L
# Set seed for reproducibility
L.seed_everything(42, workers=True)
trainer = L.Trainer(
    deterministic=True,  # Use deterministic algorithms
    benchmark=False      # Disable cudnn benchmarking
)
```
### 2. Avoid Non-Deterministic Operations
```python
# Bad: Non-deterministic
torch.use_deterministic_algorithms(False)
# Good: Deterministic
torch.use_deterministic_algorithms(True)
```
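Some CUDA kernels (notably certain cuBLAS matmul paths) are only deterministic when cuBLAS is given a fixed workspace, which is controlled by an environment variable; a minimal sketch of setting it before any CUDA work starts (see PyTorch's reproducibility notes):
```python
import os

# Required by some cuBLAS ops when deterministic algorithms are enforced;
# must be set before CUDA initializes.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch
torch.use_deterministic_algorithms(True)
```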
### 3. Log Random State
```python
def on_save_checkpoint(self, checkpoint):
    # Save random states
    checkpoint['rng_state'] = {
        'torch': torch.get_rng_state(),
        'numpy': np.random.get_state(),
        'python': random.getstate()
    }

def on_load_checkpoint(self, checkpoint):
    # Restore random states
    if 'rng_state' in checkpoint:
        torch.set_rng_state(checkpoint['rng_state']['torch'])
        np.random.set_state(checkpoint['rng_state']['numpy'])
        random.setstate(checkpoint['rng_state']['python'])
```
## Debugging
### 1. Use `fast_dev_run`
```python
# Test with 1 batch before full training
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
```
### 2. Limit Training Data
```python
# Use only 10% of data for quick iteration
trainer = L.Trainer(
    limit_train_batches=0.1,
    limit_val_batches=0.1
)
```
### 3. Enable Anomaly Detection
```python
# Detect NaN/Inf in gradients
trainer = L.Trainer(detect_anomaly=True)
```
### 4. Overfit on Small Batch
```python
# Overfit on 10 batches to verify model capacity
trainer = L.Trainer(overfit_batches=10)
```
### 5. Profile Code
```python
# Find performance bottlenecks
trainer = L.Trainer(profiler="simple") # or "advanced"
```
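For a deeper, operator-level view, the `PyTorchProfiler` wraps `torch.profiler`; a sketch, where the output directory and filename are placeholders:
```python
from lightning.pytorch.profilers import PyTorchProfiler

profiler = PyTorchProfiler(dirpath="profiler_logs", filename="profile")
trainer = L.Trainer(profiler=profiler)
```
Passing `profiler="pytorch"` with default settings works as well.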
## Memory Optimization
### 1. Use Mixed Precision
```python
# FP16/BF16 mixed precision for memory savings and speed
trainer = L.Trainer(
    precision="16-mixed"          # FP16 mixed precision (V100, T4)
    # or: precision="bf16-mixed"  # BF16 mixed precision (A100, H100)
)
```
### 2. Gradient Accumulation
```python
# Simulate larger batch size without memory increase
trainer = L.Trainer(
    accumulate_grad_batches=4  # Accumulate over 4 batches
)
```
### 3. Gradient Checkpointing
```python
import transformers

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = transformers.AutoModel.from_pretrained("bert-base-uncased")
        # Enable gradient checkpointing
        self.model.gradient_checkpointing_enable()
```
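For models that are not Hugging Face transformers, `torch.utils.checkpoint` gives the same trade-off (recompute activations during backward instead of storing them). A minimal sketch with a hypothetical 8-block MLP encoder:
```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(512, 512), nn.ReLU()) for _ in range(8)]
        )

    def forward(self, x):
        for block in self.blocks:
            # Activations inside `block` are recomputed during backward
            x = checkpoint(block, x, use_reentrant=False)
        return x
```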
### 4. Clear Cache
```python
def on_train_epoch_end(self):
    # Clear collected outputs to free memory
    self.training_step_outputs.clear()
    # Clear CUDA cache if needed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```
### 5. Use Efficient Data Types
```python
# Use appropriate precision:
# FP32 for stability, FP16/BF16 for speed/memory
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Use bfloat16 for better numerical stability than fp16
        self.model = MyTransformer().to(torch.bfloat16)
```
## Training Stability
### 1. Gradient Clipping
```python
# Prevent gradient explosion
trainer = L.Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm"  # or "value"
)
```
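If you switch to manual optimization, Trainer-level clipping no longer applies; the LightningModule exposes `clip_gradients` for that case. A minimal sketch, assuming a `compute_loss` helper like the ones above:
```python
class ManualOptModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.compute_loss(batch)  # assumed helper
        opt.zero_grad()
        self.manual_backward(loss)
        # Clip explicitly before stepping (mirrors gradient_clip_val / "norm")
        self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
        opt.step()
```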
### 2. Learning Rate Warmup
```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        total_steps=self.trainer.estimated_stepping_batches,
        pct_start=0.1  # 10% warmup
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step"
        }
    }
```
### 3. Monitor Gradients
```python
class MyModel(L.LightningModule):
    def on_after_backward(self):
        # Log gradient norms
        for name, param in self.named_parameters():
            if param.grad is not None:
                self.log(f"grad_norm/{name}", param.grad.norm())
```
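Lightning also ships a `grad_norm` utility that collects the norms in one call; a sketch of an equivalent hook, assuming you want the 2-norm of every parameter in the module:
```python
from lightning.pytorch.utilities import grad_norm

class MyModel(L.LightningModule):
    def on_before_optimizer_step(self, optimizer):
        # Returns a dict of per-parameter gradient norms plus a total norm
        norms = grad_norm(self, norm_type=2)
        self.log_dict(norms)
```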
### 4. Use EarlyStopping
```python
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
    verbose=True
)
trainer = L.Trainer(callbacks=[early_stop])
```
## Checkpointing
### 1. Save Top-K and Last
```python
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,   # Keep the best 3
    save_last=True  # Always save the last checkpoint for resuming
)
trainer = L.Trainer(callbacks=[checkpoint_callback])
```
### 2. Resume Training
```python
# Resume from last checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="last.ckpt")
# Resume from specific checkpoint
trainer.fit(model, datamodule=dm, ckpt_path="epoch=10-val_loss=0.23.ckpt")
```
### 3. Custom Checkpoint State
```python
def on_save_checkpoint(self, checkpoint):
    # Add custom state
    checkpoint['custom_data'] = self.custom_data
    checkpoint['epoch_metrics'] = self.metrics

def on_load_checkpoint(self, checkpoint):
    # Restore custom state
    self.custom_data = checkpoint.get('custom_data', {})
    self.metrics = checkpoint.get('epoch_metrics', [])
```
## Testing
### 1. Separate Train and Test
```python
# Train
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, datamodule=dm)
# Test ONLY ONCE before publishing
trainer.test(model, datamodule=dm)
```
### 2. Use Validation for Model Selection
```python
# Use validation for hyperparameter tuning
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = L.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)
# Load best model
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)
# Test only once with best model
trainer.test(best_model, datamodule=dm)
```
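When a ModelCheckpoint callback is attached to the Trainer, `ckpt_path="best"` is a shorthand that reloads the best checkpoint recorded during `fit()` before testing; a minimal alternative to loading the checkpoint manually:
```python
# After trainer.fit(...), "best" resolves to checkpoint_callback.best_model_path
trainer.test(model, datamodule=dm, ckpt_path="best")
```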
## Code Quality
### 1. Type Hints
```python
from typing import Any, Dict, Tuple

import torch
from torch import Tensor

class MyModel(L.LightningModule):
    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        x, y = batch
        loss = self.compute_loss(x, y)
        return loss

    def configure_optimizers(self) -> Dict[str, Any]:
        optimizer = torch.optim.Adam(self.parameters())
        return {"optimizer": optimizer}
```
### 2. Docstrings
```python
class MyModel(L.LightningModule):
    """My awesome model for image classification.

    Args:
        num_classes: Number of output classes
        learning_rate: Learning rate for the optimizer
        hidden_dim: Hidden dimension size
    """

    def __init__(self, num_classes: int, learning_rate: float, hidden_dim: int):
        super().__init__()
        self.save_hyperparameters()
```
### 3. Property Methods
```python
class MyModel(L.LightningModule):
    @property
    def learning_rate(self) -> float:
        """Current learning rate."""
        return self.hparams.learning_rate

    @property
    def num_parameters(self) -> int:
        """Total number of parameters."""
        return sum(p.numel() for p in self.parameters())
```
## Common Pitfalls
### 1. Forgetting to Return Loss
**Bad:**
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    self.log("train_loss", loss)
    # FORGOT TO RETURN LOSS!
```
**Good:**
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    self.log("train_loss", loss)
    return loss  # MUST return the loss
```
### 2. Not Syncing Metrics in DDP
**Bad:**
```python
def validation_step(self, batch, batch_idx):
    self.log("val_acc", acc)  # Wrong value with multi-GPU!
```
**Good:**
```python
def validation_step(self, batch, batch_idx):
    self.log("val_acc", acc, sync_dist=True)  # Correct aggregation
```
### 3. Manual Device Management
**Bad:**
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    x = x.cuda()  # Don't do this
    y = y.cuda()
```
**Good:**
```python
def training_step(self, batch, batch_idx):
    # Lightning handles device placement
    x, y = batch  # Already on the correct device
```
### 4. Not Using `self.log()`
**Bad:**
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    self.training_losses.append(loss)  # Manual tracking
    return loss
```
**Good:**
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    self.log("train_loss", loss)  # Automatic logging
    return loss
```
### 5. Modifying Batch In-Place
**Bad:**
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    x[:] = self.augment(x)  # In-place modification can cause issues
```
**Good:**
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    x = self.augment(x)  # Create a new tensor
```
## Performance Tips
### 1. Use DataLoader Workers
```python
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        num_workers=4,           # Use multiple worker processes
        pin_memory=True,         # Faster host-to-GPU transfer
        persistent_workers=True  # Keep workers alive between epochs
    )
```
### 2. Enable Benchmark Mode (if fixed input size)
```python
trainer = L.Trainer(benchmark=True)
```
### 3. Use Automatic Batch Size Finding
```python
from lightning.pytorch.tuner import Tuner
trainer = L.Trainer()
tuner = Tuner(trainer)
# Find optimal batch size
tuner.scale_batch_size(model, datamodule=dm, mode="power")
# Then train
trainer.fit(model, datamodule=dm)
```
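The same Tuner can also suggest an initial learning rate; a hedged sketch, assuming the module stores its learning rate as `self.learning_rate` or in `hparams` so the finder knows which attribute to tune:
```python
# Run a short LR range test and print the suggested starting point
lr_finder = tuner.lr_find(model, datamodule=dm)
print(lr_finder.suggestion())
```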
### 4. Optimize Data Loading
```python
import torchvision.transforms as T

# Keep transforms lightweight; the values below are the standard ImageNet stats
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Use PIL-SIMD for faster image decoding:
# pip install pillow-simd
```