"""
|
|
Quick Trainer Setup Examples for PyTorch Lightning.
|
|
|
|
This script provides ready-to-use Trainer configurations for common use cases.
|
|
Copy and modify these configurations for your specific needs.
|
|
"""
|
|
|
|
import lightning as L
|
|
from lightning.pytorch.callbacks import (
|
|
ModelCheckpoint,
|
|
EarlyStopping,
|
|
LearningRateMonitor,
|
|
DeviceStatsMonitor,
|
|
RichProgressBar,
|
|
)
|
|
from lightning.pytorch import loggers as pl_loggers
|
|
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
|
|
|
|
|
|
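# -----------------------------------------------------------------------------
# Optional demo model (illustrative addition, not part of the original configs).
# A tiny LightningModule that logs "val_loss", so the checkpoint and
# early-stopping callbacks below have a metric to monitor. The random-tensor
# data is purely a placeholder; swap in your own LightningModule and data.
# -----------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader, TensorDataset


class DemoRegressor(L.LightningModule):
    """Minimal model used only to smoke-test the trainer configurations."""

    def __init__(self, in_features=8, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(in_features, 1)

    def _loss(self, batch):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self._loss(batch))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


def demo_dataloaders(num_samples=256, in_features=8, batch_size=32):
    """Random-tensor train/val loaders for quick smoke tests (placeholder data)."""
    x = torch.randn(num_samples, in_features)
    y = torch.randn(num_samples, 1)
    dataset = TensorDataset(x, y)
    return (
        DataLoader(dataset, batch_size=batch_size, shuffle=True),
        DataLoader(dataset, batch_size=batch_size),
    )


# Example usage (sketch):
#     trainer = basic_trainer()
#     train_dl, val_dl = demo_dataloaders()
#     trainer.fit(DemoRegressor(), train_dl, val_dl)

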
# =============================================================================
# 1. BASIC TRAINING (Single GPU/CPU)
# =============================================================================

def basic_trainer():
    """
    Simple trainer for quick prototyping.
    Use for: Small models, debugging, single GPU training
    """
    trainer = L.Trainer(
        max_epochs=10,
        accelerator="auto",        # Automatically select GPU/CPU
        devices="auto",            # Use all available devices
        enable_progress_bar=True,
        logger=True,
    )
    return trainer


# =============================================================================
# 2. DEBUGGING CONFIGURATION
# =============================================================================

def debug_trainer():
    """
    Trainer for debugging with fast dev run and anomaly detection.
    Use for: Finding bugs, testing code quickly
    """
    trainer = L.Trainer(
        fast_dev_run=True,       # Run 1 batch through train/val/test
        accelerator="cpu",       # Use CPU for easier debugging
        detect_anomaly=True,     # Detect NaN/Inf in gradients
        log_every_n_steps=1,     # Log every step
        enable_progress_bar=True,
    )
    return trainer


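# Variant (added example, not in the original set): `fast_dev_run` also accepts
# an int to run that many batches instead of a single one, which gives a
# slightly longer sanity check while still finishing quickly.
def debug_trainer_n_batches(num_batches=5):
    """Debug trainer that runs `num_batches` batches of train/val/test."""
    trainer = L.Trainer(
        fast_dev_run=num_batches,    # Run N batches instead of one
        accelerator="cpu",
        detect_anomaly=True,
        log_every_n_steps=1,
    )
    return trainer

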
# =============================================================================
# 3. PRODUCTION TRAINING (Single GPU)
# =============================================================================

def production_single_gpu_trainer(
    max_epochs=100,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Production-ready trainer for single GPU with checkpointing and logging.
    Use for: Final training runs on single GPU
    """
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
        verbose=True,
    )

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=10,
        mode="min",
        verbose=True,
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    tb_logger = pl_loggers.TensorBoardLogger(
        save_dir=log_dir,
        name="my_model",
    )

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        precision="16-mixed",      # Mixed precision for speed
        callbacks=[
            checkpoint_callback,
            early_stop_callback,
            lr_monitor,
        ],
        logger=tb_logger,
        log_every_n_steps=50,
        gradient_clip_val=1.0,     # Clip gradients
        enable_progress_bar=True,
    )

    return trainer


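# Resuming (added example): to continue an interrupted run, pass the checkpoint
# path to `Trainer.fit`. `save_last=True` above writes `last.ckpt` into the
# checkpoint directory. `model` and `datamodule` stand in for your own
# LightningModule / LightningDataModule.
def resume_production_run(model, datamodule, ckpt_path="checkpoints/last.ckpt"):
    """Continue a previous production run from its last checkpoint."""
    trainer = production_single_gpu_trainer()
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    return trainer

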
# =============================================================================
# 4. MULTI-GPU TRAINING (DDP)
# =============================================================================

def multi_gpu_ddp_trainer(
    max_epochs=100,
    num_gpus=4,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Multi-GPU training with Distributed Data Parallel.
    Use for: Models <500M parameters, standard deep learning models
    """
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    )

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=10,
        mode="min",
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    wandb_logger = pl_loggers.WandbLogger(
        project="my-project",
        save_dir=log_dir,
    )

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=DDPStrategy(
            find_unused_parameters=False,
            gradient_as_bucket_view=True,
        ),
        precision="16-mixed",
        callbacks=[
            checkpoint_callback,
            early_stop_callback,
            lr_monitor,
        ],
        logger=wandb_logger,
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        sync_batchnorm=True,       # Sync batch norm across GPUs
    )

    return trainer


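# Scaling note (added sketch): with DDP, each process gets its own DataLoader,
# so the effective batch size is per-GPU batch size x num_gpus (x num_nodes).
# A common heuristic, shown here under assumed default values, is to scale the
# learning rate linearly with that effective batch size.
def scaled_lr(base_lr=1e-3, per_gpu_batch_size=32, num_gpus=4, base_batch_size=32):
    """Linear LR scaling rule for the effective DDP batch size."""
    effective_batch_size = per_gpu_batch_size * num_gpus
    return base_lr * effective_batch_size / base_batch_size

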
# =============================================================================
# 5. LARGE MODEL TRAINING (FSDP)
# =============================================================================

def large_model_fsdp_trainer(
    max_epochs=100,
    num_gpus=8,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Training for large models (500M+ parameters) with FSDP.
    Use for: Large transformers, models that don't fit on a single GPU
    """
    import torch.nn as nn

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    wandb_logger = pl_loggers.WandbLogger(
        project="large-model",
        save_dir=log_dir,
    )

    # Trainer with FSDP
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=FSDPStrategy(
            sharding_strategy="FULL_SHARD",
            activation_checkpointing_policy={
                nn.TransformerEncoderLayer,
                nn.TransformerDecoderLayer,
            },
            cpu_offload=False,       # Set True if GPU memory insufficient
        ),
        precision="bf16-mixed",      # BFloat16 for A100/H100
        callbacks=[
            checkpoint_callback,
            lr_monitor,
        ],
        logger=wandb_logger,
        log_every_n_steps=10,
        gradient_clip_val=1.0,
        accumulate_grad_batches=4,   # Gradient accumulation
    )

    return trainer


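# Wrapping note (added sketch): if your model does not use the standard
# torch.nn transformer layers listed above, FSDP needs to be told how to shard
# it. One option, shown here as an assumption about your setup rather than part
# of the original config, is PyTorch's size-based auto-wrap policy.
def fsdp_size_based_strategy(min_num_params=100_000_000):
    """FSDP strategy variant that wraps any submodule above a parameter count."""
    import functools
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    return FSDPStrategy(
        sharding_strategy="FULL_SHARD",
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=min_num_params
        ),
    )

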
# =============================================================================
# 6. VERY LARGE MODEL TRAINING (DeepSpeed)
# =============================================================================

def deepspeed_trainer(
    max_epochs=100,
    num_gpus=8,
    stage=3,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Training for very large models with DeepSpeed.
    Use for: Models >10B parameters, maximum memory efficiency
    """
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{step:06d}",
        save_top_k=-1,               # No metric is monitored, so -1 keeps every step checkpoint
        save_last=True,
        every_n_train_steps=1000,    # Save every N steps
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    wandb_logger = pl_loggers.WandbLogger(
        project="very-large-model",
        save_dir=log_dir,
    )

    # Select DeepSpeed stage
    strategy = f"deepspeed_stage_{stage}"

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=strategy,
        precision="16-mixed",
        callbacks=[
            checkpoint_callback,
            lr_monitor,
        ],
        logger=wandb_logger,
        log_every_n_steps=10,
        gradient_clip_val=1.0,
        accumulate_grad_batches=4,
    )

    return trainer


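# Checkpoint note (added sketch): with ZeRO stage 3, each checkpoint is saved as
# a sharded directory rather than a single file. The helper below assumes
# Lightning's DeepSpeed conversion utility is available at the import path
# shown; paths are placeholders for your own checkpoint locations.
def consolidate_deepspeed_checkpoint(sharded_ckpt_dir, output_file):
    """Convert a sharded ZeRO checkpoint directory into a single .ckpt file."""
    from lightning.pytorch.utilities.deepspeed import (
        convert_zero_checkpoint_to_fp32_state_dict,
    )

    convert_zero_checkpoint_to_fp32_state_dict(sharded_ckpt_dir, output_file)

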
# =============================================================================
# 7. HYPERPARAMETER TUNING
# =============================================================================

def hyperparameter_tuning_trainer(max_epochs=50):
    """
    Lightweight trainer for hyperparameter search.
    Use for: Quick experiments, hyperparameter tuning
    """
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        enable_checkpointing=False,  # Don't save checkpoints
        logger=False,                # Disable logging
        enable_progress_bar=False,
        limit_train_batches=0.5,     # Use 50% of training data
        limit_val_batches=0.5,       # Use 50% of validation data
    )
    return trainer


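# Search-loop sketch (added example): one way to plug this trainer into an
# Optuna study. `model_cls` and `datamodule` are placeholders for your own
# LightningModule class (accepting an `lr` kwarg and logging "val_loss") and
# LightningDataModule; Optuna is an extra dependency assumed to be installed.
def optuna_objective_example(trial, model_cls, datamodule):
    """Optuna objective built on the lightweight tuning trainer above."""
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    trainer = hyperparameter_tuning_trainer(max_epochs=10)
    trainer.fit(model_cls(lr=lr), datamodule=datamodule)
    return trainer.callback_metrics["val_loss"].item()


# Usage (requires `pip install optuna`):
#     study = optuna.create_study(direction="minimize")
#     study.optimize(lambda t: optuna_objective_example(t, MyModel, dm), n_trials=20)

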
# =============================================================================
# 8. OVERFITTING TEST
# =============================================================================

def overfit_test_trainer(num_batches=10):
    """
    Trainer for overfitting on a small subset to verify model capacity.
    Use for: Testing if model can learn, debugging
    """
    trainer = L.Trainer(
        max_epochs=100,
        accelerator="auto",
        devices=1,
        overfit_batches=num_batches,  # Overfit on N batches
        log_every_n_steps=1,
        enable_progress_bar=True,
    )
    return trainer


# =============================================================================
# 9. TIME-LIMITED TRAINING (SLURM)
# =============================================================================

def time_limited_trainer(
    max_time_hours=23.5,
    max_epochs=1000,
    checkpoint_dir="checkpoints"
):
    """
    Training with a time limit for SLURM clusters.
    Use for: Cluster jobs with time limits
    """
    from datetime import timedelta

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        save_top_k=-1,       # No metric is monitored, so -1 keeps every checkpoint
        save_last=True,      # Important for resuming
        every_n_epochs=5,
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        max_time=timedelta(hours=max_time_hours),
        accelerator="gpu",
        devices="auto",
        callbacks=[checkpoint_callback],
        log_every_n_steps=50,
    )

    return trainer


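# SLURM requeue note (added sketch): Lightning detects SLURM automatically, and
# its SLURM plugin can requeue the job when the scheduler signals the time
# limit, so training resumes from the saved checkpoint on the next allocation.
# The variant below only makes that behaviour explicit; it is an assumption
# about a SLURM setup launched with `srun`, not part of the original configs.
def slurm_auto_requeue_trainer(max_time_hours=23.5, checkpoint_dir="checkpoints"):
    """Time-limited trainer with explicit SLURM auto-requeue."""
    from datetime import timedelta
    from lightning.pytorch.plugins.environments import SLURMEnvironment

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        save_last=True,
    )
    trainer = L.Trainer(
        max_time=timedelta(hours=max_time_hours),
        accelerator="gpu",
        devices="auto",
        plugins=[SLURMEnvironment(auto_requeue=True)],
        callbacks=[checkpoint_callback],
    )
    return trainer

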
# =============================================================================
# 10. REPRODUCIBLE RESEARCH
# =============================================================================

def reproducible_trainer(seed=42, max_epochs=100):
    """
    Fully reproducible trainer for research papers.
    Use for: Publications, reproducible results
    """
    # Set seed
    L.seed_everything(seed, workers=True)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    )

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        precision="32-true",     # Full precision for reproducibility
        deterministic=True,      # Use deterministic algorithms
        benchmark=False,         # Disable cudnn benchmarking
        callbacks=[checkpoint_callback],
        log_every_n_steps=50,
    )

    return trainer


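# Determinism note (added): `deterministic=True` enables
# torch.use_deterministic_algorithms, and on CUDA some kernels additionally
# require the CUBLAS_WORKSPACE_CONFIG environment variable to be set before the
# first CUDA operation. The helper below is a small convenience sketch.
def enable_cuda_determinism():
    """Set the cuBLAS workspace config required by deterministic CUDA kernels."""
    import os

    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

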
# =============================================================================
# USAGE EXAMPLES
# =============================================================================

if __name__ == "__main__":
    print("PyTorch Lightning Trainer Configurations\n")

    # Example 1: Basic training
    print("1. Basic Trainer:")
    trainer = basic_trainer()
    print(f" - Max epochs: {trainer.max_epochs}")
    print(f" - Accelerator: {trainer.accelerator}")
    print()

    # Example 2: Debug training
    print("2. Debug Trainer:")
    trainer = debug_trainer()
    print(f" - Fast dev run: {trainer.fast_dev_run}")
    print(f" - Detect anomaly: {trainer.detect_anomaly}")
    print()

    # Example 3: Production single GPU
    print("3. Production Single GPU Trainer:")
    trainer = production_single_gpu_trainer(max_epochs=100)
    print(f" - Max epochs: {trainer.max_epochs}")
    print(f" - Precision: {trainer.precision}")
    print(f" - Callbacks: {len(trainer.callbacks)}")
    print()

    # Example 4: Multi-GPU DDP
    print("4. Multi-GPU DDP Trainer:")
    trainer = multi_gpu_ddp_trainer(num_gpus=4)
    print(f" - Strategy: {trainer.strategy}")
    print(f" - Devices: {trainer.num_devices}")
    print()

    # Example 5: FSDP for large models
    print("5. FSDP Trainer for Large Models:")
    trainer = large_model_fsdp_trainer(num_gpus=8)
    print(f" - Strategy: {trainer.strategy}")
    print(f" - Precision: {trainer.precision}")
    print()

    print("\nTo use these configurations:")
    print("1. Import the desired function")
    print("2. Create trainer: trainer = production_single_gpu_trainer()")
    print("3. Train model: trainer.fit(model, datamodule=dm)")