# Training Loop Architecture

## Overview

**Core Principle:** A properly structured training loop is the foundation of every successful deep learning project. Success requires: (1) correct train/val/test data separation, (2) validation after EVERY epoch (not just once at the end), (3) complete checkpoint state (model + optimizer + scheduler), (4) comprehensive logging/monitoring, and (5) graceful error handling. Poor loop structure causes silent overfitting, broken resume functionality, undetectable training issues, and memory leaks.

Training loop failures manifest as overfitting hidden behind good-looking training metrics, crashes on resume, unexplained loss spikes, or out-of-memory errors. These stem from misunderstanding when validation runs, what state must be saved, or how to manage GPU memory. Systematic architecture beats trial-and-error fixes.
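
Before the full reference implementation below, here is a minimal sketch of how the five elements fit together. Names such as `model`, `train_loader`, `val_loader`, `optimizer`, `scheduler`, `criterion`, and `logger` are assumed placeholders, and device transfers are omitted for brevity:

```python
# Minimal sketch of the five elements (placeholder names; see the full
# reference implementation below for checkpointing and error handling).
for epoch in range(num_epochs):
    model.train()
    for data, target in train_loader:          # (1) training data only
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():                      # (2) validate every epoch
        val_loss = sum(criterion(model(x), y).item() for x, y in val_loader) / len(val_loader)

    scheduler.step()
    torch.save({'epoch': epoch,                # (3) complete checkpoint state
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()},
               'checkpoint_latest.pt')
    logger.info(f"epoch={epoch} val_loss={val_loss:.4f}")   # (4) logging
    # (5) error handling is omitted here; see the reference loop below
```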

## When to Use

**Use this skill when:**

- Implementing a new training loop from scratch
- Training loop is crashing unexpectedly
- Can't resume training from checkpoint correctly
- Model overfits but validation metrics look good
- Out-of-memory errors during training
- Unsure about train/val/test data split
- Need to monitor training progress properly
- Implementing early stopping or checkpoint selection
- Training loops show loss spikes or divergence on resume
- Adding logging/monitoring to training

**Don't use when:**

- Debugging single backward pass (use gradient-management skill)
- Tuning learning rate (use learning-rate-scheduling skill)
- Fixing specific loss function (use loss-functions-and-objectives skill)
- Data loading issues (use data-augmentation-strategies skill)

**Symptoms triggering this skill:**

- "Training loss decreases but validation loss increases (overfitting)"
- "Training crashes when resuming from checkpoint"
- "Out of memory errors after epoch 20"
- "I validated on training data and didn't realize"
- "Can't detect overfitting because I don't validate"
- "Training loss spikes when resuming"
- "My checkpoint doesn't load correctly"

## Complete Training Loop Structure

### 1. The Standard Training Loop (The Reference)

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from pathlib import Path

# Setup logging (ALWAYS do this first)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class TrainingLoop:
    """Complete training loop with validation, checkpointing, and monitoring."""

    def __init__(self, model, optimizer, scheduler, criterion, device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device

        # Tracking metrics
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

    def train_epoch(self, train_loader):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass with gradient clipping (if needed)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Accumulate loss
            total_loss += loss.item()
            num_batches += 1

            # Log progress every 10 batches
            if batch_idx % 10 == 0:
                logger.debug(f"Batch {batch_idx}: loss={loss.item():.4f}")

        avg_loss = total_loss / num_batches
        return avg_loss

    def validate_epoch(self, val_loader):
        """Validate on validation set (AFTER each epoch, not during)."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():  # ✅ CRITICAL: No gradients during validation
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        return avg_loss

    def save_checkpoint(self, epoch, val_loss, checkpoint_dir='checkpoints'):
        """Save complete checkpoint (model + optimizer + scheduler)."""
        Path(checkpoint_dir).mkdir(exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }

        # Save last checkpoint
        torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt')

        # Save best checkpoint
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_best.pt')
            logger.info(f"New best validation loss: {val_loss:.4f}")

    def load_checkpoint(self, checkpoint_path):
        """Load checkpoint and resume training correctly."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # ✅ CRITICAL ORDER: Load model, optimizer, scheduler (in that order)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Restore metrics history
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.best_val_loss = min(self.val_losses) if self.val_losses else float('inf')

        epoch = checkpoint['epoch']
        logger.info(f"Loaded checkpoint from epoch {epoch}")
        return epoch

    def train(self, train_loader, val_loader, num_epochs, checkpoint_dir='checkpoints'):
        """Full training loop with validation and checkpointing."""
        start_epoch = 0

        # Try to resume from checkpoint if it exists
        checkpoint_path = f'{checkpoint_dir}/checkpoint_latest.pt'
        if Path(checkpoint_path).exists():
            # The saved epoch is the last COMPLETED epoch, so resume at the next one
            start_epoch = self.load_checkpoint(checkpoint_path) + 1
            logger.info(f"Resuming training from epoch {start_epoch}")

        for epoch in range(start_epoch, num_epochs):
            # Keep a defined val_loss so an interrupt mid-epoch can still checkpoint
            val_loss = self.val_losses[-1] if self.val_losses else float('inf')
            try:
                # Train for one epoch
                train_loss = self.train_epoch(train_loader)
                self.train_losses.append(train_loss)

                # ✅ CRITICAL: Validate after every epoch
                val_loss = self.validate_epoch(val_loader)
                self.val_losses.append(val_loss)

                # Step scheduler (after epoch)
                self.scheduler.step()

                # Log metrics
                logger.info(
                    f"Epoch {epoch}: train_loss={train_loss:.4f}, "
                    f"val_loss={val_loss:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}"
                )

                # Early stopping bookkeeping BEFORE save_checkpoint updates best_val_loss
                if val_loss < self.best_val_loss:
                    self.epochs_without_improvement = 0
                else:
                    self.epochs_without_improvement += 1

                # Checkpoint every epoch (also tracks and saves the best model)
                self.save_checkpoint(epoch, val_loss, checkpoint_dir)

                # Early stopping (optional)
                if self.epochs_without_improvement >= 10:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                self.save_checkpoint(epoch, val_loss, checkpoint_dir)
                break
            except RuntimeError as e:
                logger.error(f"Error in epoch {epoch}: {e}")
                raise

        logger.info("Training complete")
        return self.model
```
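
A minimal usage sketch follows. The model, learning rate, and epoch count here are illustrative assumptions, and `train_loader`/`val_loader` are assumed to exist (see the data split section below):

```python
# Hypothetical wiring example; replace the model and loaders with your own.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2)).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = nn.CrossEntropyLoss()

loop = TrainingLoop(model, optimizer, scheduler, criterion, device='cuda')
loop.train(train_loader, val_loader, num_epochs=50)  # resumes automatically if a checkpoint exists
```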

### 2. Data Split: Train/Val/Test Separation (CRITICAL)

```python
import random

from torch.utils.data import DataLoader, Subset

# ✅ CORRECT: Proper three-way split with NO data leakage
class DataSplitter:
    """Ensures clean train/val/test splits without data leakage."""

    @staticmethod
    def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
        """
        Split dataset into train/val/test.

        CRITICAL: Split indices first, then create loaders.
        This prevents any data leakage.
        """
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

        n = len(dataset)
        indices = list(range(n))
        # Shuffle with a fixed seed so the split is random but reproducible
        random.Random(random_state).shuffle(indices)

        # First split: train vs (val + test)
        train_size = int(train_ratio * n)
        train_indices = indices[:train_size]
        remaining_indices = indices[train_size:]

        # Second split: val vs test
        remaining_size = len(remaining_indices)
        val_size = int(val_ratio / (val_ratio + test_ratio) * remaining_size)
        val_indices = remaining_indices[:val_size]
        test_indices = remaining_indices[val_size:]

        # Create subset datasets (same transforms, different data)
        train_dataset = Subset(dataset, train_indices)
        val_dataset = Subset(dataset, val_indices)
        test_dataset = Subset(dataset, test_indices)

        logger.info(
            f"Dataset split: train={len(train_dataset)}, "
            f"val={len(val_dataset)}, test={len(test_dataset)}"
        )

        return train_dataset, val_dataset, test_dataset

# Usage (full_dataset is your torch.utils.data.Dataset instance)
train_dataset, val_dataset, test_dataset = DataSplitter.split_dataset(
    full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ✅ CRITICAL: Validate that splits are actually different
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

# ✅ CRITICAL: Never mix splits (don't re-shuffle or combine)
```
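
If no custom splitting logic is needed, the same three-way split can also be done with `torch.utils.data.random_split`. A minimal sketch, assuming a `full_dataset` as above and that `torch` is already imported:

```python
from torch.utils.data import random_split

# Seeded generator keeps the split reproducible across runs
n = len(full_dataset)
train_size = int(0.7 * n)
val_size = int(0.15 * n)
test_size = n - train_size - val_size   # remainder goes to test

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)
```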

### 3. Monitoring and Logging (Reproducibility)

```python
import json
from datetime import datetime

class TrainingMonitor:
    """Track all metrics for reproducibility and debugging."""

    def __init__(self, log_dir='logs'):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)

        # Metrics to track
        self.metrics = {
            'timestamp': datetime.now().isoformat(),
            'epochs': [],
            'train_losses': [],
            'val_losses': [],
            'learning_rates': [],
            'gradient_norms': [],
            'batch_times': [],
        }

    def log_epoch(self, epoch, train_loss, val_loss, lr, gradient_norm=None, batch_time=None):
        """Log metrics for one epoch."""
        self.metrics['epochs'].append(epoch)
        self.metrics['train_losses'].append(train_loss)
        self.metrics['val_losses'].append(val_loss)
        self.metrics['learning_rates'].append(lr)
        if gradient_norm is not None:
            self.metrics['gradient_norms'].append(gradient_norm)
        if batch_time is not None:
            self.metrics['batch_times'].append(batch_time)

    def save_metrics(self):
        """Save metrics to JSON for post-training analysis."""
        metrics_path = self.log_dir / f'metrics_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
        with open(metrics_path, 'w') as f:
            json.dump(self.metrics, f, indent=2)
        logger.info(f"Metrics saved to {metrics_path}")

    def plot_metrics(self):
        """Plot training curves."""
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Loss curves
        axes[0, 0].plot(self.metrics['epochs'], self.metrics['train_losses'], label='Train')
        axes[0, 0].plot(self.metrics['epochs'], self.metrics['val_losses'], label='Val')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].set_title('Training and Validation Loss')

        # Learning rate schedule
        axes[0, 1].plot(self.metrics['epochs'], self.metrics['learning_rates'])
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Learning Rate')
        axes[0, 1].set_title('Learning Rate Schedule')
        axes[0, 1].set_yscale('log')

        # Gradient norms (optional, so they may be shorter than epochs; plot by index)
        if self.metrics['gradient_norms']:
            axes[1, 0].plot(self.metrics['gradient_norms'])
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Gradient Norm')
            axes[1, 0].set_title('Gradient Norms')

        # Batch times (optional; plot by index for the same reason)
        if self.metrics['batch_times']:
            axes[1, 1].plot(self.metrics['batch_times'])
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Time (seconds)')
            axes[1, 1].set_title('Batch Processing Time')

        plt.tight_layout()
        plot_path = self.log_dir / f'training_curves_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'
        plt.savefig(plot_path)
        plt.close(fig)  # free the figure so repeated calls don't accumulate memory
        logger.info(f"Plot saved to {plot_path}")
```
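
A sketch of how the monitor slots into an epoch loop. It assumes `loop` is a `TrainingLoop` instance from section 1 and that `num_epochs`, `train_loader`, and `val_loader` are already defined:

```python
monitor = TrainingMonitor(log_dir='logs')

for epoch in range(num_epochs):
    train_loss = loop.train_epoch(train_loader)   # reuse the reference loop's methods
    val_loss = loop.validate_epoch(val_loader)
    loop.scheduler.step()

    monitor.log_epoch(epoch, train_loss, val_loss,
                      lr=loop.optimizer.param_groups[0]['lr'])

monitor.save_metrics()
monitor.plot_metrics()
```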

### 4. Checkpointing and Resuming (Complete State)

```python
class CheckpointManager:
    """Properly save and load ALL training state."""

    def __init__(self, checkpoint_dir='checkpoints'):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)

    def save_full_checkpoint(self, epoch, model, optimizer, scheduler, metrics, path_suffix=''):
        """Save COMPLETE state for resuming training."""
        checkpoint = {
            # Model and optimizer state
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),

            # Training metrics (for monitoring); .get() tolerates partial metrics dicts
            'train_losses': metrics.get('train_losses', []),
            'val_losses': metrics.get('val_losses', []),
            'learning_rates': metrics.get('learning_rates', []),

            # Timestamp for recovery
            'timestamp': datetime.now().isoformat(),
        }

        # Save as latest
        latest_path = self.checkpoint_dir / f'checkpoint_latest{path_suffix}.pt'
        torch.save(checkpoint, latest_path)

        # Save periodically (every 10 epochs)
        if epoch % 10 == 0:
            periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
            torch.save(checkpoint, periodic_path)

        logger.info(f"Checkpoint saved: {latest_path}")
        return latest_path

    def load_full_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
        """Load COMPLETE state correctly."""
        if not Path(checkpoint_path).exists():
            logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return 0, None

        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # ✅ CRITICAL ORDER: Model first, then optimizer, then scheduler
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        epoch = checkpoint['epoch']
        metrics = {
            'train_losses': checkpoint.get('train_losses', []),
            'val_losses': checkpoint.get('val_losses', []),
            'learning_rates': checkpoint.get('learning_rates', []),
        }

        logger.info(
            f"Loaded checkpoint from epoch {epoch}, "
            f"saved at {checkpoint.get('timestamp', 'unknown')}"
        )
        return epoch, metrics

    def get_best_checkpoint(self):
        """Find checkpoint with best validation loss."""
        checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
        if not checkpoints:
            return None

        best_loss = float('inf')
        best_path = None

        for ckpt_path in checkpoints:
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            val_losses = checkpoint.get('val_losses', [])
            if val_losses and min(val_losses) < best_loss:
                best_loss = min(val_losses)
                best_path = ckpt_path

        return best_path
```
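
A typical use of `get_best_checkpoint` is restoring the best model for the final evaluation on the held-out test set. This sketch assumes `model`, `criterion`, `device`, and `test_loader` from the earlier sections:

```python
manager = CheckpointManager('checkpoints')

# Restore the best periodic checkpoint before touching the test set
best_path = manager.get_best_checkpoint()
if best_path is not None:
    checkpoint = torch.load(best_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    test_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            test_loss += criterion(model(data), target).item()
    logger.info(f"Test loss (best checkpoint): {test_loss / len(test_loader):.4f}")
```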

### 5. Memory Management (Prevent Leaks)

```python
class MemoryManager:
    """Prevent out-of-memory errors during long training."""

    def __init__(self, device='cuda'):
        self.device = device

    def clear_cache(self):
        """Clear GPU cache between epochs."""
        if self.device.startswith('cuda'):
            torch.cuda.empty_cache()
            # Wait for outstanding kernels so subsequent memory stats are accurate
            torch.cuda.synchronize()

    def check_memory(self):
        """Log GPU memory usage."""
        if self.device.startswith('cuda'):
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f"GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")

    def training_loop_with_memory_management(self, model, train_loader, optimizer, criterion):
        """Training loop with proper memory management."""
        model.train()
        total_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Forward and backward
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # ✅ Temporary tensors (data, target, output, loss) go out of scope each iteration
            # ✅ Don't hold onto loss or output after using them

            # Periodically check memory
            if batch_idx % 100 == 0:
                self.check_memory()

        # Clear cache between epochs
        self.clear_cache()

        return total_loss / len(train_loader)
```
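
To catch slow leaks (memory that grows epoch over epoch), per-epoch peak-memory statistics can also be tracked; a small sketch using PyTorch's peak-memory counters:

```python
def log_peak_memory(epoch):
    """Log and reset the per-epoch peak GPU memory; a steadily growing peak suggests a leak."""
    if torch.cuda.is_available():
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        logger.info(f"Epoch {epoch}: peak GPU memory {peak_gb:.2f}GB")
        torch.cuda.reset_peak_memory_stats()
```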

## Error Handling and Recovery

```python
class RobustTrainingLoop:
    """Training loop with proper error handling.

    train_epoch() / validate_epoch() are assumed to mirror the reference
    TrainingLoop methods above, but taking model/optimizer/criterion explicitly.
    """

    def train_with_error_handling(self, model, train_loader, val_loader, optimizer,
                                  scheduler, criterion, num_epochs, checkpoint_dir):
        """Training with error recovery."""
        checkpoint_manager = CheckpointManager(checkpoint_dir)
        memory_manager = MemoryManager()

        # Resume from last checkpoint if available
        start_epoch, metrics = checkpoint_manager.load_full_checkpoint(
            model, optimizer, scheduler, f'{checkpoint_dir}/checkpoint_latest.pt'
        )

        # Defined up front so an interrupt during the very first epoch can still checkpoint
        train_loss, val_loss = float('inf'), float('inf')

        for epoch in range(start_epoch, num_epochs):
            try:
                # Train
                train_loss = self.train_epoch(model, train_loader, optimizer, criterion)

                # Validate
                val_loss = self.validate_epoch(model, val_loader, criterion)

                # Update scheduler
                scheduler.step()

                # Log
                logger.info(
                    f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
                    f"lr={optimizer.param_groups[0]['lr']:.2e}"
                )

                # Checkpoint
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss]}
                )

                # Memory management
                memory_manager.clear_cache()

            except KeyboardInterrupt:
                logger.warning("Training interrupted - saving checkpoint before exit")
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss]}
                )
                break

            except RuntimeError as e:
                if 'out of memory' in str(e).lower():
                    logger.error("Out of memory error")
                    memory_manager.clear_cache()
                    # Re-raise here; in practice, reduce the batch size and restart
                    # from the latest checkpoint (see the sketch below)
                    raise
                else:
                    logger.error(f"Runtime error: {e}")
                    raise

            except Exception as e:
                logger.error(f"Unexpected error in epoch {epoch}: {e}")
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss]}
                )
                raise

        return model
```
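
One way to act on the "reduce batch size" advice above is to retry the whole run with a smaller batch size whenever an OOM error surfaces. This is only a sketch: `run_training(batch_size)` is a placeholder callable that rebuilds the DataLoaders at the given batch size and calls the training loop, which resumes from the latest checkpoint it already saves:

```python
def train_with_batch_size_backoff(run_training, initial_batch_size=64, min_batch_size=8):
    """Retry training with a halved batch size after CUDA OOM errors."""
    batch_size = initial_batch_size
    while batch_size >= min_batch_size:
        try:
            return run_training(batch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower():
                raise
            torch.cuda.empty_cache()
            batch_size //= 2
            logger.warning(f"OOM encountered; retrying with batch_size={batch_size}")
    raise RuntimeError("Could not find a batch size that fits in GPU memory")
```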

## Common Pitfalls and How to Avoid Them

### Pitfall 1: Validating on Training Data

```python
# ❌ WRONG
val_loader = train_loader  # Same loader!

# ✅ CORRECT
train_dataset, val_dataset = split_dataset(full_dataset)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
```

### Pitfall 2: Missing Optimizer State in Checkpoint

```python
# ❌ WRONG
torch.save({'model': model.state_dict()}, 'ckpt.pt')

# ✅ CORRECT
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, 'ckpt.pt')
```

### Pitfall 3: Not Validating During Training

```python
# ❌ WRONG
for epoch in range(100):
    train_epoch()
final_val = evaluate()  # Only at the end!

# ✅ CORRECT
for epoch in range(100):
    train_epoch()
    validate_epoch()  # After every epoch
```

### Pitfall 4: Holding Onto Tensor References

```python
# ❌ WRONG
all_losses = []
for data, target in loader:
    loss = criterion(model(data), target)
    all_losses.append(loss)  # Keeps the whole graph alive -> memory leak!

# ✅ CORRECT
total_loss = 0.0
for data, target in loader:
    loss = criterion(model(data), target)
    total_loss += loss.item()  # Detached scalar value
```

### Pitfall 5: Forgetting torch.no_grad() in Validation

```python
# ❌ WRONG
model.eval()
for data, target in val_loader:
    output = model(data)  # Gradients still tracked!
    loss = criterion(output, target)

# ✅ CORRECT
model.eval()
with torch.no_grad():
    for data, target in val_loader:
        output = model(data)  # No gradients
        loss = criterion(output, target)
```

### Pitfall 6: Resetting Scheduler on Resume

```python
# ❌ WRONG
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
scheduler = CosineAnnealingLR(optimizer, T_max=100)  # Fresh scheduler!
# Now at epoch 50, the scheduler thinks it's epoch 0

# ✅ CORRECT
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])  # Resume scheduler state
```

### Pitfall 7: Not Handling Early Stopping Correctly

```python
# ❌ WRONG
best_loss = float('inf')
for epoch in range(100):
    val_loss = validate()
    if val_loss < best_loss:
        best_loss = val_loss
        # No checkpoint! Can't recover the best model

# ✅ CORRECT
best_loss = float('inf')
patience = 10
patience_counter = 0
for epoch in range(100):
    val_loss = validate()
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        save_checkpoint(model, optimizer, scheduler, epoch)  # Save best
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break  # Stop early
```

### Pitfall 8: Mixing Train and Validation Mode

```python
# ❌ WRONG
for epoch in range(100):
    for data, target in train_loader:
        output = model(data)  # Is the model in train or eval mode?
        loss = criterion(output, target)

# ✅ CORRECT
model.train()
for epoch in range(100):
    for data, target in train_loader:
        output = model(data)  # Definitely in train mode
        loss = criterion(output, target)

model.eval()
with torch.no_grad():
    for data, target in val_loader:
        output = model(data)  # Definitely in eval mode
```

### Pitfall 9: Loading Checkpoint on Wrong Device

```python
# ❌ WRONG
checkpoint = torch.load('ckpt.pt')  # Tensors land on the GPU they were saved from
model.load_state_dict(checkpoint['model'])  # Might be the wrong device

# ✅ CORRECT
checkpoint = torch.load('ckpt.pt', map_location='cuda:0')  # Specify device
model.load_state_dict(checkpoint['model'])
model.to('cuda:0')  # Move to device
```

### Pitfall 10: Not Clearing GPU Cache

```python
# ❌ WRONG
for epoch in range(100):
    train_epoch()
    validate_epoch()
# Cached memory grows every epoch

# ✅ CORRECT
for epoch in range(100):
    train_epoch()
    validate_epoch()
    torch.cuda.empty_cache()  # Release cached blocks back to the driver
```

## Integration with Optimization Techniques

### Complete Training Loop with All Techniques

```python
class FullyOptimizedTrainingLoop:
    """Integrates: gradient clipping, mixed precision, learning rate scheduling."""

    def train_with_all_techniques(self, model, train_loader, val_loader,
                                  num_epochs, checkpoint_dir='checkpoints'):
        """Training with all optimization techniques integrated."""

        # Setup
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        criterion = nn.CrossEntropyLoss()

        # Mixed precision (if using AMP)
        scaler = torch.cuda.amp.GradScaler()

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0

            for data, target in train_loader:
                data, target = data.to(device), target.to(device)

                optimizer.zero_grad()

                # Mixed precision forward pass
                with torch.autocast('cuda'):
                    output = model(data)
                    loss = criterion(output, target)

                # Gradient scaling for mixed precision
                scaler.scale(loss).backward()

                # Gradient clipping (unscale first!)
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Optimizer step
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()

            train_loss = total_loss / len(train_loader)

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for data, target in val_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    val_loss += criterion(output, target).item()

            val_loss /= len(val_loader)

            # Scheduler step
            scheduler.step()

            logger.info(
                f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
                f"lr={optimizer.param_groups[0]['lr']:.2e}"
            )

        return model
```
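
When mixed precision is in play, the `GradScaler` carries state too; if you add checkpointing to this loop (as in section 4), its state should be saved and restored alongside the model, optimizer, and scheduler. A small sketch extending the checkpoint dict from section 4 (variable names assumed from the loop above):

```python
# Sketch: include the AMP scaler state in the checkpoint when using mixed precision
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'scaler_state_dict': scaler.state_dict(),   # AMP loss-scale state
}
torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt')

# ...and on resume, after the other components:
# scaler.load_state_dict(checkpoint['scaler_state_dict'])
```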

## Rationalization Table: When to Deviate from Standard

| Situation | Standard Practice | Deviation OK? | Rationale |
|-----------|-------------------|---------------|-----------|
| Validate only at end | Validate every epoch | ✗ Never | Can't detect overfitting |
| Save only model | Save model + optimizer + scheduler | ✗ Never | Resume training breaks |
| Mixed train/val | Separate datasets completely | ✗ Never | Data leakage and false metrics |
| Constant batch size | Fix batch size for reproducibility | ✓ Sometimes | May need dynamic batching for memory |
| Single LR | Use scheduler | ✓ Sometimes | <10 epoch training or hyperparameter search |
| No early stopping | Implement early stopping | ✓ Sometimes | If training time is unlimited |
| Log every batch | Log every 10-100 batches | ✓ Often | Reduces I/O overhead |
| Clear GPU cache every epoch | Clear GPU cache periodically | ✓ Sometimes | Only if OOM issues |

## Red Flags: Immediate Warning Signs

1. **Training loss much lower than validation loss** (>2x gap) → Overfitting (a quick sanity check is sketched below)
2. **Loss spikes on resume** → Optimizer state not loaded
3. **GPU memory grows over time** → Memory leak, likely tensor accumulation
4. **Validation never runs** → Check that validation is inside the epoch loop
5. **Best model not saved** → Check checkpoint logic
6. **Different results on resume** → Scheduler not loaded
7. **Early stopping not working** → Checkpoint not taken at the best model
8. **OOM during training** → Clear GPU cache, check for accumulated tensors

## Testing Your Training Loop

```python
def test_training_loop():
    """Quick test to verify the training loop is correct."""

    # Create dummy data
    X_train = torch.randn(100, 10)
    y_train = torch.randint(0, 2, (100,))
    X_val = torch.randn(20, 10)
    y_val = torch.randint(0, 2, (20,))

    train_loader = DataLoader(
        list(zip(X_train, y_train)), batch_size=16
    )
    val_loader = DataLoader(
        list(zip(X_val, y_val)), batch_size=16
    )

    # Simple model
    model = nn.Sequential(
        nn.Linear(10, 64),
        nn.ReLU(),
        nn.Linear(64, 2)
    )

    # Run on GPU if available, otherwise CPU, so the test works anywhere
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Training
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    loop = TrainingLoop(model, optimizer, scheduler, criterion, device=device)

    # Should complete without errors
    loop.train(train_loader, val_loader, num_epochs=5, checkpoint_dir='test_ckpts')

    # Check outputs
    assert len(loop.train_losses) == 5
    assert len(loop.val_losses) == 5
    assert all(isinstance(l, float) for l in loop.train_losses)

    print("✓ Training loop test passed")


if __name__ == '__main__':
    test_training_loop()
```