# Training Loop Architecture

## Overview

**Core Principle:** A properly structured training loop is the foundation of all successful deep learning projects. Success requires: (1) correct train/val/test data separation, (2) validation after EVERY epoch (not just once), (3) complete checkpoint state (model + optimizer + scheduler), (4) comprehensive logging/monitoring, and (5) graceful error handling. Poor loop structure causes: silent overfitting, broken resume functionality, undetectable training issues, and memory leaks.

Training loop failures manifest as: overfitting masked by good-looking metrics, crashes on resume, unexplained loss spikes, or out-of-memory errors. These stem from misunderstanding when validation runs, what state must be saved, or how to manage GPU memory. Systematic architecture beats trial-and-error fixes; the sketch below shows the per-epoch ordering the rest of this skill assumes.
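A minimal sketch of that ordering — `train_epoch`, `validate_epoch`, and `save_checkpoint` here are placeholders for the full implementations in the sections below:

```python
# Canonical per-epoch ordering (sketch; the helpers are defined in full below)
for epoch in range(start_epoch, num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)        # 1. train
    val_loss = validate_epoch(model, val_loader, criterion)                    # 2. validate EVERY epoch
    scheduler.step()                                                           # 3. step the scheduler
    logger.info(f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}")  # 4. log
    save_checkpoint(epoch, model, optimizer, scheduler, val_loss)              # 5. save COMPLETE state
```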
## When to Use

**Use this skill when:**
- Implementing a new training loop from scratch
- Training loop is crashing unexpectedly
- Can't resume training from checkpoint correctly
- Model overfits but validation metrics look good
- Out-of-memory errors during training
- Unsure about train/val/test data split
- Need to monitor training progress properly
- Implementing early stopping or checkpoint selection
- Training loops show loss spikes or divergence on resume
- Adding logging/monitoring to training

**Don't use when:**
- Debugging single backward pass (use gradient-management skill)
- Tuning learning rate (use learning-rate-scheduling skill)
- Fixing specific loss function (use loss-functions-and-objectives skill)
- Data loading issues (use data-augmentation-strategies skill)

**Symptoms triggering this skill:**
- "Training loss decreases but validation loss increases (overfitting)"
- "Training crashes when resuming from checkpoint"
- "Out of memory errors after epoch 20"
- "I validated on training data and didn't realize"
- "Can't detect overfitting because I don't validate"
- "Training loss spikes when resuming"
- "My checkpoint doesn't load correctly"

## Complete Training Loop Structure

### 1. The Standard Training Loop (The Reference)

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from pathlib import Path

# Setup logging (ALWAYS do this first)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class TrainingLoop:
    """Complete training loop with validation, checkpointing, and monitoring."""

    def __init__(self, model, optimizer, scheduler, criterion, device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device

        # Tracking metrics
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

    def train_epoch(self, train_loader):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass with gradient clipping (if needed)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Accumulate loss
            total_loss += loss.item()
            num_batches += 1

            # Log progress every 10 batches
            if batch_idx % 10 == 0:
                logger.debug(f"Batch {batch_idx}: loss={loss.item():.4f}")

        avg_loss = total_loss / num_batches
        return avg_loss

    def validate_epoch(self, val_loader):
        """Validate on validation set (AFTER each epoch, not during)."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():  # ✅ CRITICAL: No gradients during validation
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        return avg_loss

    def save_checkpoint(self, epoch, val_loss, checkpoint_dir='checkpoints'):
        """Save complete checkpoint (model + optimizer + scheduler)."""
        Path(checkpoint_dir).mkdir(exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }

        # Save last checkpoint
        torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt')

        # Save best checkpoint
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_best.pt')
            logger.info(f"New best validation loss: {val_loss:.4f}")

    def load_checkpoint(self, checkpoint_path):
        """Load checkpoint and resume training correctly."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # ✅ CRITICAL ORDER: Load model, optimizer, scheduler (in that order)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Restore metrics history
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.best_val_loss = min(self.val_losses) if self.val_losses else float('inf')

        epoch = checkpoint['epoch']
        logger.info(f"Loaded checkpoint from epoch {epoch}")
        # Return the NEXT epoch to run so the saved epoch isn't repeated
        return epoch + 1

    def train(self, train_loader, val_loader, num_epochs, checkpoint_dir='checkpoints'):
        """Full training loop with validation and checkpointing."""
        start_epoch = 0
        train_loss = val_loss = float('nan')  # defined even if interrupted mid-epoch

        # Try to resume from checkpoint if it exists
        checkpoint_path = f'{checkpoint_dir}/checkpoint_latest.pt'
        if Path(checkpoint_path).exists():
            start_epoch = self.load_checkpoint(checkpoint_path)
            logger.info(f"Resuming training from epoch {start_epoch}")

        for epoch in range(start_epoch, num_epochs):
            try:
                # Train for one epoch
                train_loss = self.train_epoch(train_loader)
                self.train_losses.append(train_loss)

                # ✅ CRITICAL: Validate after every epoch
                val_loss = self.validate_epoch(val_loader)
                self.val_losses.append(val_loss)

                # Step scheduler (after epoch)
                self.scheduler.step()

                # Log metrics
                logger.info(
                    f"Epoch {epoch}: train_loss={train_loss:.4f}, "
                    f"val_loss={val_loss:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}"
                )

                # Early stopping bookkeeping: check for improvement BEFORE
                # save_checkpoint updates self.best_val_loss
                improved = val_loss < self.best_val_loss

                # Checkpoint every epoch
                self.save_checkpoint(epoch, val_loss, checkpoint_dir)

                # Early stopping (optional)
                if improved:
                    self.epochs_without_improvement = 0
                else:
                    self.epochs_without_improvement += 1
                    if self.epochs_without_improvement >= 10:
                        logger.info(f"Early stopping at epoch {epoch}")
                        break

            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                self.save_checkpoint(epoch, val_loss, checkpoint_dir)
                break
            except RuntimeError as e:
                logger.error(f"Error in epoch {epoch}: {e}")
                raise

        logger.info("Training complete")
        return self.model
```
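A usage sketch — the model and hyperparameters are illustrative, and `train_loader`/`val_loader` are assumed to come from a proper split (section 2):

```python
# Illustrative wiring; loaders come from the split shown in section 2
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = nn.CrossEntropyLoss()

loop = TrainingLoop(model, optimizer, scheduler, criterion, device='cpu')
trained_model = loop.train(train_loader, val_loader, num_epochs=50)
```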
### 2. Data Split: Train/Val/Test Separation (CRITICAL)

```python
import random

from torch.utils.data import Subset


# ✅ CORRECT: Proper three-way split with NO data leakage
class DataSplitter:
    """Ensures clean train/val/test splits without data leakage."""

    @staticmethod
    def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                      random_state=42):
        """
        Split dataset into train/val/test.

        CRITICAL: Split indices first, then create loaders.
        This prevents any data leakage.
        """
        # Use a tolerance rather than exact float equality
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9

        n = len(dataset)
        indices = list(range(n))
        # Shuffle with a fixed seed: random but reproducible split
        random.Random(random_state).shuffle(indices)

        # First split: train vs (val + test)
        train_size = int(train_ratio * n)
        train_indices = indices[:train_size]
        remaining_indices = indices[train_size:]

        # Second split: val vs test
        remaining_size = len(remaining_indices)
        val_size = int(val_ratio / (val_ratio + test_ratio) * remaining_size)
        val_indices = remaining_indices[:val_size]
        test_indices = remaining_indices[val_size:]

        # Create subset datasets (same transforms, different data)
        train_dataset = Subset(dataset, train_indices)
        val_dataset = Subset(dataset, val_indices)
        test_dataset = Subset(dataset, test_indices)

        logger.info(
            f"Dataset split: train={len(train_dataset)}, "
            f"val={len(val_dataset)}, test={len(test_dataset)}"
        )

        return train_dataset, val_dataset, test_dataset


# Usage
train_dataset, val_dataset, test_dataset = DataSplitter.split_dataset(
    full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ✅ CRITICAL: Validate that splits are actually different
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

# ✅ CRITICAL: Never mix splits (don't re-shuffle or combine)
```
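To make "never mix splits" mechanically checkable, a small sanity check (relying on the `.indices` attribute of the `Subset` objects returned above):

```python
def assert_disjoint_splits(train_dataset, val_dataset, test_dataset, full_size):
    """Fail loudly if any sample index appears in more than one split."""
    train_idx = set(train_dataset.indices)
    val_idx = set(val_dataset.indices)
    test_idx = set(test_dataset.indices)

    assert train_idx.isdisjoint(val_idx), "Train/val overlap -- data leakage!"
    assert train_idx.isdisjoint(test_idx), "Train/test overlap -- data leakage!"
    assert val_idx.isdisjoint(test_idx), "Val/test overlap -- data leakage!"
    assert len(train_idx | val_idx | test_idx) == full_size, "Splits don't cover the dataset"


assert_disjoint_splits(train_dataset, val_dataset, test_dataset, len(full_dataset))
```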
### 3. Monitoring and Logging (Reproducibility)

```python
import json
from datetime import datetime


class TrainingMonitor:
    """Track all metrics for reproducibility and debugging."""

    def __init__(self, log_dir='logs'):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)

        # Metrics to track
        self.metrics = {
            'timestamp': datetime.now().isoformat(),
            'epochs': [],
            'train_losses': [],
            'val_losses': [],
            'learning_rates': [],
            'gradient_norms': [],
            'batch_times': [],
        }

    def log_epoch(self, epoch, train_loss, val_loss, lr, gradient_norm=None, batch_time=None):
        """Log metrics for one epoch."""
        self.metrics['epochs'].append(epoch)
        self.metrics['train_losses'].append(train_loss)
        self.metrics['val_losses'].append(val_loss)
        self.metrics['learning_rates'].append(lr)

        if gradient_norm is not None:
            self.metrics['gradient_norms'].append(gradient_norm)
        if batch_time is not None:
            self.metrics['batch_times'].append(batch_time)

    def save_metrics(self):
        """Save metrics to JSON for post-training analysis."""
        metrics_path = self.log_dir / f'metrics_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
        with open(metrics_path, 'w') as f:
            json.dump(self.metrics, f, indent=2)
        logger.info(f"Metrics saved to {metrics_path}")

    def plot_metrics(self):
        """Plot training curves."""
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Loss curves
        axes[0, 0].plot(self.metrics['epochs'], self.metrics['train_losses'], label='Train')
        axes[0, 0].plot(self.metrics['epochs'], self.metrics['val_losses'], label='Val')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].set_title('Training and Validation Loss')

        # Learning rate schedule
        axes[0, 1].plot(self.metrics['epochs'], self.metrics['learning_rates'])
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Learning Rate')
        axes[0, 1].set_title('Learning Rate Schedule')
        axes[0, 1].set_yscale('log')

        # Gradient norms (if available; plotted against their own index,
        # since they may not be logged every epoch)
        if self.metrics['gradient_norms']:
            axes[1, 0].plot(self.metrics['gradient_norms'])
            axes[1, 0].set_xlabel('Logged step')
            axes[1, 0].set_ylabel('Gradient Norm')
            axes[1, 0].set_title('Gradient Norms')

        # Batch times (if available)
        if self.metrics['batch_times']:
            axes[1, 1].plot(self.metrics['batch_times'])
            axes[1, 1].set_xlabel('Logged step')
            axes[1, 1].set_ylabel('Time (seconds)')
            axes[1, 1].set_title('Batch Processing Time')

        plt.tight_layout()
        plot_path = self.log_dir / f'training_curves_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'
        plt.savefig(plot_path)
        logger.info(f"Plot saved to {plot_path}")
```
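A usage sketch showing where the monitor hooks into an epoch loop (reusing the `TrainingLoop` helpers, optimizer, and scheduler from section 1):

```python
monitor = TrainingMonitor(log_dir='logs')

for epoch in range(num_epochs):
    train_loss = loop.train_epoch(train_loader)
    val_loss = loop.validate_epoch(val_loader)
    scheduler.step()
    monitor.log_epoch(epoch, train_loss, val_loss,
                      lr=optimizer.param_groups[0]['lr'])

monitor.save_metrics()  # JSON for post-hoc analysis
monitor.plot_metrics()  # PNG with the four panels above
```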
### 4. Checkpointing and Resuming (Complete State)

```python
class CheckpointManager:
    """Properly save and load ALL training state."""

    def __init__(self, checkpoint_dir='checkpoints'):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)

    def save_full_checkpoint(self, epoch, model, optimizer, scheduler, metrics, path_suffix=''):
        """Save COMPLETE state for resuming training."""
        checkpoint = {
            # Model and optimizer state
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Training metrics (for monitoring)
            'train_losses': metrics['train_losses'],
            'val_losses': metrics['val_losses'],
            'learning_rates': metrics['learning_rates'],
            # Timestamp for recovery
            'timestamp': datetime.now().isoformat(),
        }

        # Save as latest
        latest_path = self.checkpoint_dir / f'checkpoint_latest{path_suffix}.pt'
        torch.save(checkpoint, latest_path)

        # Save periodically (every 10 epochs)
        if epoch % 10 == 0:
            periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
            torch.save(checkpoint, periodic_path)

        logger.info(f"Checkpoint saved: {latest_path}")
        return latest_path

    def load_full_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
        """Load COMPLETE state correctly. Returns (next_epoch, metrics)."""
        if not Path(checkpoint_path).exists():
            logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return 0, None

        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # ✅ CRITICAL ORDER: Model first, then optimizer, then scheduler
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        epoch = checkpoint['epoch']
        metrics = {
            'train_losses': checkpoint.get('train_losses', []),
            'val_losses': checkpoint.get('val_losses', []),
            'learning_rates': checkpoint.get('learning_rates', []),
        }

        logger.info(
            f"Loaded checkpoint from epoch {epoch}, "
            f"saved at {checkpoint.get('timestamp', 'unknown')}"
        )
        # Resume at the epoch AFTER the one that was saved
        return epoch + 1, metrics

    def get_best_checkpoint(self):
        """Find checkpoint with best validation loss."""
        checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
        if not checkpoints:
            return None

        best_loss = float('inf')
        best_path = None
        for ckpt_path in checkpoints:
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            val_losses = checkpoint.get('val_losses', [])
            if val_losses and min(val_losses) < best_loss:
                best_loss = min(val_losses)
                best_path = ckpt_path

        return best_path
```
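One optional hardening not shown above: write the checkpoint to a temporary file and rename it into place, so an interrupt in the middle of `torch.save` cannot corrupt `checkpoint_latest.pt`. A sketch:

```python
import os


def atomic_save(checkpoint, path):
    """Save to a temp file, then rename over the target.

    os.replace is an atomic rename on the same filesystem, so the old
    checkpoint survives if the process dies mid-write.
    """
    tmp_path = str(path) + '.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)


# Inside save_full_checkpoint, replace torch.save(checkpoint, latest_path) with:
# atomic_save(checkpoint, latest_path)
```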
### 5. Memory Management (Prevent Leaks)

```python
class MemoryManager:
    """Prevent out-of-memory errors during long training."""

    def __init__(self, device='cuda'):
        self.device = device

    def clear_cache(self):
        """Clear GPU cache between epochs."""
        if self.device.startswith('cuda'):
            torch.cuda.empty_cache()
            # Optional: wait for outstanding GPU work before measuring memory
            torch.cuda.synchronize()

    def check_memory(self):
        """Log GPU memory usage."""
        if self.device.startswith('cuda'):
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f"GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")

    def training_loop_with_memory_management(self, model, train_loader, optimizer, criterion):
        """Training loop with proper memory management."""
        model.train()
        total_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Forward and backward
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # ✅ Clear temporary tensors (data and target go out of scope)
            # ✅ Don't hold onto loss or output after using them

            # Periodically check memory
            if batch_idx % 100 == 0:
                self.check_memory()

        # Clear cache between epochs
        self.clear_cache()
        return total_loss / len(train_loader)
```

## Error Handling and Recovery

```python
class RobustTrainingLoop:
    """Training loop with proper error handling.

    Assumes train_epoch/validate_epoch helpers like those in section 1,
    here taking the model/optimizer/criterion explicitly.
    """

    def train_with_error_handling(self, model, train_loader, val_loader, optimizer,
                                  scheduler, criterion, num_epochs, checkpoint_dir):
        """Training with error recovery."""
        checkpoint_manager = CheckpointManager(checkpoint_dir)
        memory_manager = MemoryManager()

        # Resume from last checkpoint if available
        start_epoch, metrics = checkpoint_manager.load_full_checkpoint(
            model, optimizer, scheduler, f'{checkpoint_dir}/checkpoint_latest.pt'
        )
        train_loss = val_loss = float('nan')  # defined even if the first epoch fails early

        for epoch in range(start_epoch, num_epochs):
            try:
                # Train
                train_loss = self.train_epoch(model, train_loader, optimizer, criterion)

                # Validate
                val_loss = self.validate_epoch(model, val_loader, criterion)

                # Update scheduler
                scheduler.step()

                # Log
                logger.info(
                    f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
                    f"lr={optimizer.param_groups[0]['lr']:.2e}"
                )

                # Checkpoint (save_full_checkpoint expects all three metric keys)
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss],
                     'learning_rates': [optimizer.param_groups[0]['lr']]}
                )

                # Memory management
                memory_manager.clear_cache()

            except KeyboardInterrupt:
                logger.warning("Training interrupted - checkpoint saved")
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss],
                     'learning_rates': [optimizer.param_groups[0]['lr']]}
                )
                break
            except RuntimeError as e:
                if 'out of memory' in str(e).lower():
                    logger.error("Out of memory error")
                    memory_manager.clear_cache()
                    # Try to continue (reduce batch size in real scenario)
                    raise
                else:
                    logger.error(f"Runtime error: {e}")
                    raise
            except Exception as e:
                logger.error(f"Unexpected error in epoch {epoch}: {e}")
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler,
                    {'train_losses': [train_loss], 'val_losses': [val_loss],
                     'learning_rates': [optimizer.param_groups[0]['lr']]}
                )
                raise

        return model
```
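The OOM branch above re-raises with a note to "reduce batch size in real scenario". A minimal sketch of the gentler batch-level variant — skip the offending batch instead of aborting the epoch (assumes the loop variables from `training_loop_with_memory_management`):

```python
for batch_idx, (data, target) in enumerate(train_loader):
    try:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
    except RuntimeError as e:
        if 'out of memory' not in str(e).lower():
            raise
        # Drop references to the failed batch, free cached blocks, move on
        optimizer.zero_grad(set_to_none=True)
        del data, target
        torch.cuda.empty_cache()
        logger.warning(f"OOM on batch {batch_idx}; skipped")
```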
## Common Pitfalls and How to Avoid Them

### Pitfall 1: Validating on Training Data

```python
# ❌ WRONG
val_loader = train_loader  # Same loader!

# ✅ CORRECT
train_dataset, val_dataset = split_dataset(full_dataset)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
```

### Pitfall 2: Missing Optimizer State in Checkpoint

```python
# ❌ WRONG
torch.save({'model': model.state_dict()}, 'ckpt.pt')

# ✅ CORRECT
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, 'ckpt.pt')
```

### Pitfall 3: Not Validating During Training

```python
# ❌ WRONG
for epoch in range(100):
    train_epoch()
final_val = evaluate()  # Only at the end!

# ✅ CORRECT
for epoch in range(100):
    train_epoch()
    validate_epoch()  # After every epoch
```

### Pitfall 4: Holding Onto Tensor References

```python
# ❌ WRONG
all_losses = []
for data, target in loader:
    loss = criterion(model(data), target)
    all_losses.append(loss)  # Memory leak!

# ✅ CORRECT
total_loss = 0.0
for data, target in loader:
    loss = criterion(model(data), target)
    total_loss += loss.item()  # Scalar value
```

### Pitfall 5: Forgetting torch.no_grad() in Validation

```python
# ❌ WRONG
model.eval()
for data, target in val_loader:
    output = model(data)  # Gradients still computed!
    loss = criterion(output, target)

# ✅ CORRECT
model.eval()
with torch.no_grad():
    for data, target in val_loader:
        output = model(data)  # No gradients
        loss = criterion(output, target)
```

### Pitfall 6: Resetting Scheduler on Resume

```python
# ❌ WRONG
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
scheduler = CosineAnnealingLR(optimizer, T_max=100)  # Fresh scheduler!
# Now at epoch 50, scheduler thinks it's epoch 0

# ✅ CORRECT
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])  # Resume scheduler state
```

### Pitfall 7: Not Handling Early Stopping Correctly

```python
# ❌ WRONG
best_loss = float('inf')
for epoch in range(100):
    val_loss = validate()
    if val_loss < best_loss:
        best_loss = val_loss
        # No checkpoint! Can't recover best model

# ✅ CORRECT
best_loss = float('inf')
patience = 10
patience_counter = 0
for epoch in range(100):
    val_loss = validate()
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        save_checkpoint(model, optimizer, scheduler, epoch)  # Save best
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break  # Stop early
```

### Pitfall 8: Mixing Train and Validation Mode

```python
# ❌ WRONG
for epoch in range(100):
    for data, target in train_loader:
        output = model(data)  # Is model in train or eval mode?
        loss = criterion(output, target)

# ✅ CORRECT
model.train()
for epoch in range(100):
    for data, target in train_loader:
        output = model(data)  # Definitely in train mode
        loss = criterion(output, target)

model.eval()
with torch.no_grad():
    for data, target in val_loader:
        output = model(data)  # Definitely in eval mode
```
### Pitfall 9: Loading Checkpoint on Wrong Device

```python
# ❌ WRONG
checkpoint = torch.load('ckpt.pt')  # Loads on GPU if saved on GPU
model.load_state_dict(checkpoint['model'])  # Might be on wrong device

# ✅ CORRECT
checkpoint = torch.load('ckpt.pt', map_location='cuda:0')  # Specify device
model.load_state_dict(checkpoint['model'])
model.to('cuda:0')  # Move to device
```

### Pitfall 10: Not Clearing GPU Cache

```python
# ❌ WRONG
for epoch in range(100):
    train_epoch()
    validate_epoch()
    # GPU cache growing every epoch

# ✅ CORRECT
for epoch in range(100):
    train_epoch()
    validate_epoch()
    torch.cuda.empty_cache()  # Clear cache
```

## Integration with Optimization Techniques

### Complete Training Loop with All Techniques

```python
class FullyOptimizedTrainingLoop:
    """Integrates: gradient clipping, mixed precision, learning rate scheduling."""

    def train_with_all_techniques(self, model, train_loader, val_loader, num_epochs,
                                  checkpoint_dir='checkpoints'):
        """Training with all optimization techniques integrated."""
        # Setup
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        criterion = nn.CrossEntropyLoss()

        # Mixed precision (assumes CUDA; see the CPU guard after this block)
        scaler = torch.cuda.amp.GradScaler()

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0

            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()

                # Mixed precision forward pass
                with torch.autocast('cuda'):
                    output = model(data)
                    loss = criterion(output, target)

                # Gradient scaling for mixed precision
                scaler.scale(loss).backward()

                # Gradient clipping (unscale first!)
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Optimizer step
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()

            train_loss = total_loss / len(train_loader)

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for data, target in val_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    val_loss += criterion(output, target).item()
            val_loss /= len(val_loader)

            # Scheduler step
            scheduler.step()

            logger.info(
                f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
                f"lr={optimizer.param_groups[0]['lr']:.2e}"
            )

        return model
```
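The AMP pieces above hard-code CUDA even though `device` may fall back to CPU. A minimal guard, reusing the names from the loop above and the `enabled` flag that both `GradScaler` and `autocast` accept (each becomes a no-op when disabled):

```python
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

with torch.autocast(device_type='cuda' if use_amp else 'cpu', enabled=use_amp):
    output = model(data)
    loss = criterion(output, target)

scaler.scale(loss).backward()  # returns loss unscaled when disabled
scaler.step(optimizer)         # plain optimizer.step() when disabled
scaler.update()                # no-op when disabled
```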
## Rationalization Table: When to Deviate from Standard

| Situation | Standard Practice | Deviation | Rationale |
|-----------|-------------------|-----------|-----------|
| Validate only at end | Validate every epoch | ✗ Never | Can't detect overfitting |
| Save only model | Save model + optimizer + scheduler | ✗ Never | Resume training breaks |
| Mixed train/val | Separate datasets completely | ✗ Never | Data leakage and false metrics |
| Constant batch size | Fix batch size for reproducibility | ✓ Sometimes | May need dynamic batching for memory |
| Single LR | Use scheduler | ✓ Sometimes | <10 epoch training or hyperparameter search |
| No early stopping | Implement early stopping | ✓ Sometimes | If training time unlimited |
| Log every batch | Log every 10-100 batches | ✓ Often | Reduces I/O overhead |
| GPU cache every epoch | Clear GPU cache periodically | ✓ Sometimes | Only if OOM issues |

## Red Flags: Immediate Warning Signs

1. **Training loss much lower than validation loss** (>2x) → Overfitting
2. **Loss spikes on resume** → Optimizer state not loaded
3. **GPU memory grows over time** → Memory leak, likely tensor accumulation
4. **Validation never runs** → Check if validation is in loop
5. **Best model not saved** → Check checkpoint logic
6. **Different results on resume** → Scheduler not loaded
7. **Early stopping not working** → Checkpoint not at best model
8. **OOM during training** → Clear GPU cache, check for accumulated tensors

## Testing Your Training Loop

```python
def test_training_loop():
    """Quick test to verify training loop is correct."""
    # Create dummy data
    X_train = torch.randn(100, 10)
    y_train = torch.randint(0, 2, (100,))
    X_val = torch.randn(20, 10)
    y_val = torch.randint(0, 2, (20,))

    train_loader = DataLoader(list(zip(X_train, y_train)), batch_size=16)
    val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=16)

    # Simple model
    model = nn.Sequential(
        nn.Linear(10, 64),
        nn.ReLU(),
        nn.Linear(64, 2)
    )

    # Training (CPU is enough for this dummy data, and the test then
    # also runs on machines without a GPU)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    loop = TrainingLoop(model, optimizer, scheduler, criterion, device='cpu')

    # Should complete without errors
    loop.train(train_loader, val_loader, num_epochs=5, checkpoint_dir='test_ckpts')

    # Check outputs
    assert len(loop.train_losses) == 5
    assert len(loop.val_losses) == 5
    assert all(isinstance(l, float) for l in loop.train_losses)
    print("✓ Training loop test passed")


if __name__ == '__main__':
    test_training_loop()
```
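A complementary sketch that exercises the resume path explicitly: train briefly, load the checkpoint into fresh objects, and verify that scheduler state and metric history round-trip (using the same loader for train and val is fine for a smoke test):

```python
def test_checkpoint_roundtrip(tmp_dir='test_ckpts_resume'):
    """Verify that save/load restores scheduler state and metric history."""
    model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    X = torch.randn(32, 10)
    y = torch.randint(0, 2, (32,))
    loader = DataLoader(list(zip(X, y)), batch_size=16)

    loop = TrainingLoop(model, optimizer, scheduler, criterion, device='cpu')
    loop.train(loader, loader, num_epochs=3, checkpoint_dir=tmp_dir)

    # Fresh copies, then load the checkpoint into them
    model2 = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
    scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=5, gamma=0.1)
    loop2 = TrainingLoop(model2, optimizer2, scheduler2, criterion, device='cpu')

    next_epoch = loop2.load_checkpoint(f'{tmp_dir}/checkpoint_latest.pt')
    assert next_epoch == 3                                 # resumes AFTER the saved epoch
    assert scheduler2.last_epoch == scheduler.last_epoch   # scheduler state restored
    assert loop2.val_losses == loop.val_losses             # metric history restored
    print("✓ Checkpoint round-trip test passed")
```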