Best Practices - PyTorch Lightning
Code Organization
1. Separate Research from Engineering
Good:
class MyModel(L.LightningModule):
    """Holds only the research logic: what the model computes per batch."""

    # Research code (what the model does)
    def training_step(self, batch, batch_idx):
        """Return the loss computed by ``self.compute_loss`` for ``batch``."""
        return self.compute_loss(batch)
# Engineering code (how to train) lives in the Trainer configuration
trainer = L.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=4,
    strategy="ddp",
)
Bad:
# Anti-pattern: research and engineering logic mixed in one method
class MyModel(L.LightningModule):
    """Counter-example: ``training_step`` also handles device and optimizer."""

    def training_step(self, batch, batch_idx):
        """Compute the loss, then move it and step the optimizer by hand."""
        loss = self.compute_loss(batch)

        # Don't do device management manually
        loss = loss.cuda()

        # Don't do optimizer steps manually (unless manual optimization)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss
2. Use LightningDataModule
Good:
class MyDataModule(L.LightningDataModule):
    """Bundle download, dataset construction and dataloaders in one place.

    Args:
        data_dir: Location passed to ``download_data`` and ``MyDataset``.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        """Download the data once."""
        download_data(self.data_dir)

    def setup(self, stage=None):
        """Build the train and validation datasets (loaded per process)."""
        self.train_dataset = MyDataset(self.data_dir, split='train')
        self.val_dataset = MyDataset(self.data_dir, split='val')

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        # Without this hook ``self.val_dataset`` is built but never consumed.
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
# Reusable and shareable: one object carries all of the data logic
dm = MyDataModule(data_dir="./data", batch_size=32)
trainer.fit(model=model, datamodule=dm)
Bad:
# Anti-pattern: data logic scattered across the training script
train_dataset, val_dataset = load_data(), load_data()
train_loader = DataLoader(train_dataset, ...)
val_loader = DataLoader(val_dataset, ...)
trainer.fit(model, train_loader, val_loader)
3. Keep Models Modular
class Encoder(nn.Module):
    """Encoder sub-network: one sequential stack applied to the input."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(...)

    def forward(self, x):
        """Return ``x`` passed through ``self.layers``."""
        encoded = self.layers(x)
        return encoded
class Decoder(nn.Module):
    """Decoder sub-network: one sequential stack applied to the input."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(...)

    def forward(self, x):
        """Return ``x`` passed through ``self.layers``."""
        decoded = self.layers(x)
        return decoded
class MyModel(L.LightningModule):
    """Composes the separate ``Encoder`` and ``Decoder`` modules."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for that encoding."""
        return self.decoder(self.encoder(x))
Device Agnosticism
1. Never Use Explicit CUDA Calls
Bad:
# Anti-pattern: explicit CUDA calls tie the code to one device type
x = x.cuda()
model = model.cuda()
torch.cuda.set_device(device=0)
Good:
# Inside a LightningModule: move tensors to the module's own device
x = x.to(device=self.device)
# Or let Lightning handle it automatically
def training_step(self, batch, batch_idx):
    """Return the loss for a batch that needs no manual device transfer.

    Args:
        batch: An ``(x, y)`` pair.
        batch_idx: Index of the batch (unused here).

    Returns:
        Whatever ``self.compute_loss(x, y)`` returns.
    """
    x, y = batch  # Already on correct device
    # The loss must actually be computed before it can be returned.
    loss = self.compute_loss(x, y)
    return loss
2. Use self.device Property
class MyModel(L.LightningModule):
    """Shows how to create new tensors on the module's device."""

    def training_step(self, batch, batch_idx):
        """Create a ``(batch_size, 100)`` noise tensor next to ``batch``."""
        # Create tensors directly on the correct device; this avoids
        # allocating on the CPU first and then copying.
        noise = torch.randn(batch.size(0), 100, device=self.device)
        # Or use type_as to match an existing tensor's device and dtype
        noise = torch.randn(batch.size(0), 100).type_as(batch)
3. Register Buffers for Non-Parameters
class MyModel(L.LightningModule):
    """Keeps non-parameter state in a registered buffer."""

    def __init__(self):
        super().__init__()
        # Register buffers (automatically moved to correct device)
        self.register_buffer("running_mean", torch.zeros(100))

    def forward(self, x):
        """Return ``x`` with ``self.running_mean`` subtracted."""
        # self.running_mean is automatically on correct device
        centered = x - self.running_mean
        return centered
Hyperparameter Management
1. Always Use save_hyperparameters()
Good:
class MyModel(L.LightningModule):
    """Stores its constructor arguments via ``save_hyperparameters``."""

    def __init__(self, learning_rate, hidden_dim, dropout):
        super().__init__()
        # Saves all arguments; they are then available via self.hparams
        self.save_hyperparameters()
        in_features = self.hparams.hidden_dim
        self.model = nn.Linear(in_features, 10)
# Load from checkpoint with saved hparams
model = MyModel.load_from_checkpoint(checkpoint_path="checkpoint.ckpt")
# Original value preserved
print(model.hparams.learning_rate)
Bad:
class MyModel(L.LightningModule):
    """Counter-example: hyperparameters tracked by hand as attributes."""

    def __init__(self, learning_rate, hidden_dim, dropout):
        super().__init__()
        # Manual tracking
        self.learning_rate, self.hidden_dim = learning_rate, hidden_dim
2. Ignore Specific Arguments
class MyModel(L.LightningModule):
    """Saves hyperparameters while excluding selected constructor arguments."""

    def __init__(self, lr, model, dataset):
        super().__init__()
        # Don't save 'model' and 'dataset' (not serializable)
        self.save_hyperparameters(ignore=['model', 'dataset'])
        self.model, self.dataset = model, dataset
3. Use Hyperparameters in configure_optimizers()
def configure_optimizers(self):
    """Build an Adam optimizer from the saved hyperparameters."""
    # Use saved hyperparameters
    hparams = self.hparams
    return torch.optim.Adam(
        self.parameters(),
        lr=hparams.learning_rate,
        weight_decay=hparams.weight_decay,
    )
Logging Best Practices
1. Log Both Step and Epoch Metrics
def training_step(self, batch, batch_idx):
    """Compute the loss, log it per step and per epoch, and return it."""
    train_loss = self.compute_loss(batch)
    # on_step=True: per-step values for detailed monitoring
    # on_epoch=True: per-epoch aggregated view
    self.log("train_loss", train_loss, on_step=True, on_epoch=True, prog_bar=True)
    return train_loss
2. Use Structured Logging
def training_step(self, batch, batch_idx):
    """Log training metrics under a shared ``train/`` prefix."""
    # NOTE(review): loss, acc and f1 are not computed in this snippet;
    # presumably they come from code omitted here - confirm.
    # Organize with prefixes
    train_metrics = {"train/loss": loss, "train/acc": acc, "train/f1": f1}
    for metric_name, metric_value in train_metrics.items():
        self.log(metric_name, metric_value)
def validation_step(self, batch, batch_idx):
    """Log validation metrics under a shared ``val/`` prefix."""
    # NOTE(review): loss, acc and f1 are not computed in this snippet;
    # presumably they come from code omitted here - confirm.
    val_metrics = {"val/loss": loss, "val/acc": acc, "val/f1": f1}
    for metric_name, metric_value in val_metrics.items():
        self.log(metric_name, metric_value)
3. Sync Metrics in Distributed Training
def validation_step(self, batch, batch_idx):
    """Compute the validation loss and log it with ``sync_dist=True``."""
    val_loss = self.compute_loss(batch)
    # IMPORTANT: sync_dist=True for proper aggregation across GPUs
    self.log("val_loss", val_loss, sync_dist=True)
4. Monitor Learning Rate
from lightning.pytorch.callbacks import LearningRateMonitor
# Log the learning rate at every step
trainer = L.Trainer(
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
    ],
)
Reproducibility
1. Seed Everything
import lightning as L
# Set seed for reproducibility
L.seed_everything(42, workers=True)

# deterministic=True: use deterministic algorithms
# benchmark=False: disable cudnn benchmarking
trainer = L.Trainer(deterministic=True, benchmark=False)
2. Avoid Non-Deterministic Operations
# Bad: Non-deterministic
torch.use_deterministic_algorithms(mode=False)

# Good: Deterministic
torch.use_deterministic_algorithms(mode=True)
3. Log Random State
def on_save_checkpoint(self, checkpoint):
    """Store the torch, NumPy, Python and (when present) CUDA RNG states.

    Args:
        checkpoint: Checkpoint dict; gains an ``'rng_state'`` entry.
    """
    # Save random states
    rng_state = {
        'torch': torch.get_rng_state(),
        'numpy': np.random.get_state(),
        'python': random.getstate(),
    }
    # torch.get_rng_state() covers only the CPU generator, so capture the
    # CUDA generators as well when a GPU is available.
    if torch.cuda.is_available():
        rng_state['cuda'] = torch.cuda.get_rng_state_all()
    checkpoint['rng_state'] = rng_state


def on_load_checkpoint(self, checkpoint):
    """Restore the RNG states stored under ``'rng_state'``, if present.

    Args:
        checkpoint: Checkpoint dict, possibly without an ``'rng_state'`` entry.
    """
    if 'rng_state' not in checkpoint:
        return
    # Restore random states
    rng_state = checkpoint['rng_state']
    torch.set_rng_state(rng_state['torch'])
    np.random.set_state(rng_state['numpy'])
    random.setstate(rng_state['python'])

    # Checkpoints written without the 'cuda' entry still load; CUDA state is
    # restored only when the visible device count matches the saved one.
    cuda_states = rng_state.get('cuda')
    if (
        cuda_states is not None
        and torch.cuda.is_available()
        and len(cuda_states) == torch.cuda.device_count()
    ):
        torch.cuda.set_rng_state_all(cuda_states)
Debugging
1. Use fast_dev_run
# Smoke-test with 1 batch before full training
trainer = L.Trainer(
    fast_dev_run=True,
)
trainer.fit(model=model, datamodule=dm)
2. Limit Training Data
# Use only 10% of the training and validation batches for quick iteration
trainer = L.Trainer(limit_train_batches=0.1, limit_val_batches=0.1)
3. Enable Anomaly Detection
# Detect NaN/Inf in gradients
trainer = L.Trainer(
    detect_anomaly=True,
)
4. Overfit on Small Batch
# Overfit on 10 batches to verify model capacity
trainer = L.Trainer(
    overfit_batches=10,
)
5. Profile Code
# Find performance bottlenecks
trainer = L.Trainer(
    profiler="simple",  # or "advanced"
)
Memory Optimization
1. Use Mixed Precision
# FP16/BF16 mixed precision for memory savings and speed.
# Pass exactly ONE precision value; repeating the keyword is a SyntaxError.
trainer = L.Trainer(
    precision="16-mixed",  # V100, T4
    # or, on A100 / H100:
    # precision="bf16-mixed",
)
2. Gradient Accumulation
# Simulate larger batch size without memory increase:
# gradients are accumulated over 4 batches
trainer = L.Trainer(accumulate_grad_batches=4)
3. Gradient Checkpointing
class MyModel(L.LightningModule):
    """Wraps a pretrained Transformers model with gradient checkpointing on."""

    def __init__(self):
        super().__init__()
        # "bert-base" is not a published checkpoint name; use a real model id.
        self.model = transformers.AutoModel.from_pretrained("bert-base-uncased")
        # Enable gradient checkpointing
        self.model.gradient_checkpointing_enable()
4. Clear Cache
def on_train_epoch_end(self):
    """Drop the collected step outputs, then empty the CUDA cache if any."""
    # Clear collected outputs to free memory
    self.training_step_outputs.clear()

    # Clear CUDA cache if needed
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
5. Use Efficient Data Types
# Use appropriate precision:
# FP32 for stability, FP16/BF16 for speed/memory
class MyModel(L.LightningModule):
    """Holds a ``MyTransformer`` converted to ``torch.bfloat16``."""

    def __init__(self):
        super().__init__()
        # Use bfloat16 for better numerical stability than fp16
        transformer = MyTransformer()
        self.model = transformer.to(torch.bfloat16)
Training Stability
1. Gradient Clipping
# Prevent gradient explosion
trainer = L.Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",  # or "value"
)
2. Learning Rate Warmup
def configure_optimizers(self):
    """Return Adam plus a per-step OneCycleLR schedule with 10% warmup."""
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    total_steps = self.trainer.estimated_stepping_batches
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        total_steps=total_steps,
        pct_start=0.1,  # 10% warmup
    )

    # "interval": "step" requests a scheduler update every step
    lr_scheduler_config = {"scheduler": scheduler, "interval": "step"}
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
3. Monitor Gradients
class MyModel(L.LightningModule):
    """Logs per-parameter gradient norms after each backward pass."""

    def on_after_backward(self):
        """Log ``grad_norm/<name>`` for every parameter that has a gradient."""
        for name, param in self.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            self.log(f"grad_norm/{name}", grad.norm())
4. Use EarlyStopping
from lightning.pytorch.callbacks import EarlyStopping
# Stop when "val_loss" has not improved (mode="min") for 10 checks
early_stop = EarlyStopping(monitor="val_loss", patience=10, mode="min", verbose=True)
trainer = L.Trainer(
    callbacks=[early_stop],
)
Checkpointing
1. Save Top-K and Last
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,  # Keep best 3
    save_last=True,  # Always save last for resuming
)
trainer = L.Trainer(
    callbacks=[checkpoint_callback],
)
2. Resume Training
# Resume from the most recent checkpoint ("last" is resolved by Lightning)
trainer.fit(model, datamodule=dm, ckpt_path="last")
# Resume from a specific checkpoint: pass the full path, including the
# ModelCheckpoint dirpath
trainer.fit(model, datamodule=dm, ckpt_path="checkpoints/epoch=10-val_loss=0.23.ckpt")
3. Custom Checkpoint State
def on_save_checkpoint(self, checkpoint):
    """Copy ``self.custom_data`` and ``self.metrics`` into ``checkpoint``."""
    # Add custom state: (checkpoint key, attribute name)
    extra_state = (("custom_data", "custom_data"), ("epoch_metrics", "metrics"))
    for key, attr in extra_state:
        checkpoint[key] = getattr(self, attr)
def on_load_checkpoint(self, checkpoint):
    """Restore ``self.custom_data`` and ``self.metrics`` from ``checkpoint``.

    Missing entries fall back to an empty dict and an empty list.
    """
    # Restore custom state: (attribute name, checkpoint key, default)
    restored = (
        ("custom_data", "custom_data", {}),
        ("metrics", "epoch_metrics", []),
    )
    for attr, key, default in restored:
        setattr(self, attr, checkpoint.get(key, default))
Testing
1. Separate Train and Test
# Train
trainer = L.Trainer(
    max_epochs=100,
)
trainer.fit(model=model, datamodule=dm)

# Test ONLY ONCE before publishing
trainer.test(model=model, datamodule=dm)
2. Use Validation for Model Selection
# Use validation for hyperparameter tuning
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
)
trainer = L.Trainer(
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model, datamodule=dm)

# Load best model
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)

# Test only once with best model
trainer.test(model=best_model, datamodule=dm)
Code Quality
1. Type Hints
from typing import Any, Dict, Tuple
import torch
from torch import Tensor
class MyModel(L.LightningModule):
    """Example of type-annotated LightningModule hooks."""

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Return the loss computed from the batch's two tensors."""
        inputs, targets = batch
        return self.compute_loss(inputs, targets)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Return a dict holding an Adam optimizer over all parameters."""
        return {"optimizer": torch.optim.Adam(self.parameters())}
2. Docstrings
class MyModel(L.LightningModule):
    """My awesome model for image classification.

    Args:
        num_classes: Number of output classes
        learning_rate: Learning rate for optimizer
        hidden_dim: Hidden dimension size
    """

    def __init__(
        self,
        num_classes: int,
        learning_rate: float,
        hidden_dim: int,
    ):
        super().__init__()
        self.save_hyperparameters()
3. Property Methods
class MyModel(L.LightningModule):
    """Exposes derived values as read-only properties."""

    @property
    def learning_rate(self) -> float:
        """Current learning rate."""
        hparams = self.hparams
        return hparams.learning_rate

    @property
    def num_parameters(self) -> int:
        """Total number of parameters."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
Common Pitfalls
1. Forgetting to Return Loss
Bad:
def training_step(self, batch, batch_idx):
    """Counter-example: the loss is logged but never returned."""
    self.log("train_loss", self.compute_loss(batch))
    # FORGOT TO RETURN LOSS!
Good:
def training_step(self, batch, batch_idx):
    """Compute the loss, log it as ``train_loss`` and return it."""
    train_loss = self.compute_loss(batch)
    self.log("train_loss", train_loss)
    # MUST return loss
    return train_loss
2. Not Syncing Metrics in DDP
Bad:
def validation_step(self, batch, batch_idx):
    """Counter-example: logs ``val_acc`` without ``sync_dist``."""
    # NOTE(review): acc is not computed in this snippet; presumably it
    # comes from code omitted here - confirm.
    # Wrong value with multi-GPU!
    self.log(
        "val_acc",
        acc,
    )
Good:
def validation_step(self, batch, batch_idx):
    """Log ``val_acc`` with ``sync_dist=True``."""
    # NOTE(review): acc is not computed in this snippet; presumably it
    # comes from code omitted here - confirm.
    # Correct aggregation
    self.log(
        "val_acc",
        acc,
        sync_dist=True,
    )
3. Manual Device Management
Bad:
def training_step(self, batch, batch_idx):
    """Counter-example: moves the batch tensors to CUDA by hand.

    Args:
        batch: An ``(x, y)`` pair.
        batch_idx: Index of the batch (unused here).
    """
    # Unpack first so the example shows the intended pitfall rather than
    # failing on an unassigned local.
    x, y = batch
    x = x.cuda()  # Don't do this
    y = y.cuda()
Good:
def training_step(self, batch, batch_idx):
    """Unpack the batch without any manual device transfer."""
    # Lightning handles device placement, so the tensors are
    # already on the correct device here.
    x, y = batch
4. Not Using self.log()
Bad:
def training_step(self, batch, batch_idx):
    """Counter-example: appends the loss to a hand-maintained list."""
    step_loss = self.compute_loss(batch)
    # Manual tracking
    self.training_losses.append(step_loss)
    return step_loss
Good:
def training_step(self, batch, batch_idx):
    """Compute the loss, record it through ``self.log`` and return it."""
    step_loss = self.compute_loss(batch)
    # Automatic logging
    self.log("train_loss", step_loss)
    return step_loss
5. Modifying Batch In-Place
Bad:
def training_step(self, batch, batch_idx):
    """Counter-example: overwrites the batch input in place."""
    x, y = batch
    augmented = self.augment(x)
    # In-place modification can cause issues
    x[:] = augmented
Good:
def training_step(self, batch, batch_idx):
    """Augment the input by rebinding ``x``; the batch is left untouched."""
    x, y = batch
    # Create new tensor
    x = self.augment(x)
Performance Tips
1. Use DataLoader Workers
def train_dataloader(self):
    """Return a DataLoader over ``self.train_dataset`` with worker options."""
    loader_options = dict(
        batch_size=32,
        num_workers=4,  # Use multiple workers
        pin_memory=True,  # Faster GPU transfer
        persistent_workers=True,  # Keep workers alive
    )
    return DataLoader(self.train_dataset, **loader_options)
2. Enable Benchmark Mode (if fixed input size)
# Enable benchmark mode (use when input sizes are fixed)
trainer = L.Trainer(
    benchmark=True,
)
3. Use Automatic Batch Size Finding
from lightning.pytorch.tuner import Tuner
trainer = L.Trainer()
tuner = Tuner(trainer)

# Find optimal batch size
tuner.scale_batch_size(
    model,
    datamodule=dm,
    mode="power",
)

# Then train
trainer.fit(model=model, datamodule=dm)
4. Optimize Data Loading
# Use faster image decoding
import torch
import torchvision.transforms as T
# Convert to a tensor, then normalize each of the three channels
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Use PIL-SIMD for faster image loading
# pip install pillow-simd