""" Template for creating a PyTorch Lightning Module. This template provides a complete boilerplate for building a LightningModule with all essential methods and best practices. """ import lightning as L import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau class TemplateLightningModule(L.LightningModule): """ Template LightningModule for building deep learning models. Args: learning_rate: Learning rate for optimizer hidden_dim: Hidden dimension size dropout: Dropout probability """ def __init__( self, learning_rate: float = 0.001, hidden_dim: int = 256, dropout: float = 0.1, ): super().__init__() # Save hyperparameters (accessible via self.hparams) self.save_hyperparameters() # Define your model architecture self.model = nn.Sequential( nn.Linear(784, self.hparams.hidden_dim), nn.ReLU(), nn.Dropout(self.hparams.dropout), nn.Linear(self.hparams.hidden_dim, 10), ) # Optional: Define metrics # from torchmetrics import Accuracy # self.train_accuracy = Accuracy(task="multiclass", num_classes=10) # self.val_accuracy = Accuracy(task="multiclass", num_classes=10) def forward(self, x): """ Forward pass of the model. Args: x: Input tensor Returns: Model output """ return self.model(x) def training_step(self, batch, batch_idx): """ Training step (called for each training batch). Args: batch: Current batch of data batch_idx: Index of the current batch Returns: Loss tensor """ x, y = batch # Forward pass logits = self(x) loss = F.cross_entropy(logits, y) # Calculate accuracy (optional) preds = torch.argmax(logits, dim=1) acc = (preds == y).float().mean() # Log metrics self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) self.log("train/acc", acc, on_step=True, on_epoch=True) self.log("learning_rate", self.optimizers().param_groups[0]["lr"]) return loss def validation_step(self, batch, batch_idx): """ Validation step (called for each validation batch). Args: batch: Current batch of data batch_idx: Index of the current batch """ x, y = batch # Forward pass logits = self(x) loss = F.cross_entropy(logits, y) # Calculate accuracy preds = torch.argmax(logits, dim=1) acc = (preds == y).float().mean() # Log metrics (automatically aggregated across batches) self.log("val/loss", loss, on_epoch=True, prog_bar=True, sync_dist=True) self.log("val/acc", acc, on_epoch=True, prog_bar=True, sync_dist=True) def test_step(self, batch, batch_idx): """ Test step (called for each test batch). Args: batch: Current batch of data batch_idx: Index of the current batch """ x, y = batch # Forward pass logits = self(x) loss = F.cross_entropy(logits, y) # Calculate accuracy preds = torch.argmax(logits, dim=1) acc = (preds == y).float().mean() # Log metrics self.log("test/loss", loss, on_epoch=True) self.log("test/acc", acc, on_epoch=True) def predict_step(self, batch, batch_idx, dataloader_idx=0): """ Prediction step (called for each prediction batch). Args: batch: Current batch of data batch_idx: Index of the current batch dataloader_idx: Index of the dataloader (if multiple) Returns: Predictions """ x, y = batch logits = self(x) preds = torch.argmax(logits, dim=1) return preds def configure_optimizers(self): """ Configure optimizers and learning rate schedulers. Returns: Optimizer and scheduler configuration """ # Define optimizer optimizer = Adam( self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-5, ) # Define scheduler scheduler = ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=5, verbose=True, ) # Return configuration return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "monitor": "val/loss", "interval": "epoch", "frequency": 1, }, } # Optional: Add custom methods for model-specific logic def on_train_epoch_end(self): """Called at the end of each training epoch.""" # Example: Log custom metrics pass def on_validation_epoch_end(self): """Called at the end of each validation epoch.""" # Example: Compute epoch-level metrics pass # Example usage if __name__ == "__main__": # Create model model = TemplateLightningModule( learning_rate=0.001, hidden_dim=256, dropout=0.1, ) # Create trainer trainer = L.Trainer( max_epochs=10, accelerator="auto", devices="auto", logger=True, ) # Train (you need to provide train_dataloader and val_dataloader) # trainer.fit(model, train_dataloader, val_dataloader) print(f"Model created with {model.num_parameters:,} parameters") print(f"Hyperparameters: {model.hparams}")