""" Template for creating a PyTorch Lightning DataModule. This template provides a complete boilerplate for building a LightningDataModule with all essential methods and best practices for data handling. """ import lightning as L from torch.utils.data import Dataset, DataLoader, random_split import torch class CustomDataset(Dataset): """ Custom Dataset implementation. Replace this with your actual dataset implementation. """ def __init__(self, data_path, transform=None): """ Initialize the dataset. Args: data_path: Path to data directory transform: Optional transforms to apply """ self.data_path = data_path self.transform = transform # Load your data here # self.data = load_data(data_path) # self.labels = load_labels(data_path) # Placeholder data self.data = torch.randn(1000, 3, 224, 224) self.labels = torch.randint(0, 10, (1000,)) def __len__(self): """Return the size of the dataset.""" return len(self.data) def __getitem__(self, idx): """ Get a single item from the dataset. Args: idx: Index of the item Returns: Tuple of (data, label) """ sample = self.data[idx] label = self.labels[idx] if self.transform: sample = self.transform(sample) return sample, label class TemplateDataModule(L.LightningDataModule): """ Template LightningDataModule for data handling. This class encapsulates all data processing steps: 1. Download/prepare data (prepare_data) 2. Create datasets (setup) 3. Create dataloaders (train/val/test/predict_dataloader) Args: data_dir: Directory containing the data batch_size: Batch size for dataloaders num_workers: Number of workers for data loading train_val_split: Train/validation split ratio pin_memory: Whether to pin memory for faster GPU transfer """ def __init__( self, data_dir: str = "./data", batch_size: int = 32, num_workers: int = 4, train_val_split: float = 0.8, pin_memory: bool = True, ): super().__init__() # Save hyperparameters self.save_hyperparameters() # Initialize as None (will be set in setup) self.train_dataset = None self.val_dataset = None self.test_dataset = None self.predict_dataset = None def prepare_data(self): """ Download and prepare data. This method is called only once and on a single process. Do not set state here (e.g., self.x = y) because it's not transferred to other processes. Use this for: - Downloading datasets - Tokenizing text - Saving processed data to disk """ # Example: Download data if not exists # if not os.path.exists(self.hparams.data_dir): # download_dataset(self.hparams.data_dir) # Example: Process and save data # process_and_save(self.hparams.data_dir) pass def setup(self, stage: str = None): """ Create datasets for each stage. This method is called on every process in distributed training. Set state here (e.g., self.train_dataset = ...). 

class TemplateDataModule(L.LightningDataModule):
    """
    Template LightningDataModule for data handling.

    This class encapsulates all data processing steps:
    1. Download/prepare data (prepare_data)
    2. Create datasets (setup)
    3. Create dataloaders (train/val/test/predict_dataloader)

    Args:
        data_dir: Directory containing the data
        batch_size: Batch size for dataloaders
        num_workers: Number of workers for data loading
        train_val_split: Fraction of the data used for training
        pin_memory: Whether to pin memory for faster GPU transfer
    """

    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: float = 0.8,
        pin_memory: bool = True,
    ):
        super().__init__()

        # Save hyperparameters (accessible via self.hparams)
        self.save_hyperparameters()

        # Initialize as None (set in setup)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        """
        Download and prepare data.

        This method is called only once, on a single process.
        Do not set state here (e.g., self.x = y) because it is not
        transferred to other processes.

        Use this for:
        - Downloading datasets
        - Tokenizing text
        - Saving processed data to disk
        """
        # Example: Download data if it does not exist
        # if not os.path.exists(self.hparams.data_dir):
        #     download_dataset(self.hparams.data_dir)

        # Example: Process and save data
        # process_and_save(self.hparams.data_dir)
        pass

    def setup(self, stage: Optional[str] = None):
        """
        Create datasets for each stage.

        This method is called on every process in distributed training.
        Set state here (e.g., self.train_dataset = ...).

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        # Define transforms
        train_transform = self._get_train_transforms()
        test_transform = self._get_test_transforms()

        # Setup for training and validation
        if stage == "fit" or stage is None:
            # Load full dataset
            full_dataset = CustomDataset(
                self.hparams.data_dir, transform=train_transform
            )

            # Split into train and validation (fixed seed for reproducibility)
            train_size = int(self.hparams.train_val_split * len(full_dataset))
            val_size = len(full_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42),
            )
            # Note: random_split doesn't support per-split transforms, so the
            # validation split currently shares the training transforms; see
            # the TransformWrapper sketch above for one way to give each
            # split its own transform.

        # Setup for testing
        if stage == "test" or stage is None:
            self.test_dataset = CustomDataset(
                self.hparams.data_dir, transform=test_transform
            )

        # Setup for prediction
        if stage == "predict":
            self.predict_dataset = CustomDataset(
                self.hparams.data_dir, transform=test_transform
            )

    def _get_train_transforms(self):
        """
        Define training transforms/augmentations.

        Returns:
            Training transforms
        """
        # Example with torchvision:
        # from torchvision import transforms
        # return transforms.Compose([
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomRotation(10),
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                          std=[0.229, 0.224, 0.225]),
        # ])
        return None

    def _get_test_transforms(self):
        """
        Define test/validation transforms (no augmentation).

        Returns:
            Test/validation transforms
        """
        # Example with torchvision:
        # from torchvision import transforms
        # return transforms.Compose([
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                          std=[0.229, 0.224, 0.225]),
        # ])
        return None

    def train_dataloader(self):
        """
        Create the training dataloader.

        Returns:
            Training DataLoader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.num_workers > 0,
            drop_last=True,  # Drop the last incomplete batch
        )

    def val_dataloader(self):
        """
        Create the validation dataloader.

        Returns:
            Validation DataLoader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.num_workers > 0,
        )

    def test_dataloader(self):
        """
        Create the test dataloader.

        Returns:
            Test DataLoader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )

    def predict_dataloader(self):
        """
        Create the prediction dataloader.

        Returns:
            Prediction DataLoader
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )

    # Optional: State management for checkpointing
    def state_dict(self):
        """
        Save DataModule state for checkpointing.

        Returns:
            State dictionary
        """
        return {"train_val_split": self.hparams.train_val_split}

    def load_state_dict(self, state_dict):
        """
        Restore DataModule state from a checkpoint.

        Args:
            state_dict: State dictionary
        """
        self.hparams.train_val_split = state_dict["train_val_split"]
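
    # Optional: device-side batch hooks. A minimal sketch, not part of the
    # original template: on_after_batch_transfer is a Lightning DataModule
    # hook that runs after a batch has been moved to the device, which makes
    # it a convenient place for GPU-side augmentation. The commented noise
    # injection below is purely illustrative.
    def on_after_batch_transfer(self, batch, dataloader_idx):
        """
        Adjust a batch after it has been transferred to the device.

        Args:
            batch: The batch as produced by the dataloader
            dataloader_idx: Index of the dataloader that produced the batch

        Returns:
            The (possibly modified) batch
        """
        # x, y = batch
        # if self.trainer.training:
        #     x = x + 0.01 * torch.randn_like(x)  # example GPU augmentation
        # return x, y
        return batch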

    # Optional: Teardown for cleanup
    def teardown(self, stage: Optional[str] = None):
        """
        Clean up after training/testing/prediction.

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        # Release dataset references
        if stage == "fit":
            self.train_dataset = None
            self.val_dataset = None
        elif stage == "test":
            self.test_dataset = None
        elif stage == "predict":
            self.predict_dataset = None


# Example usage
if __name__ == "__main__":
    # Create the DataModule
    dm = TemplateDataModule(
        data_dir="./data",
        batch_size=64,
        num_workers=8,
        train_val_split=0.8,
    )

    # Setup for training
    dm.prepare_data()
    dm.setup(stage="fit")

    # Get dataloaders
    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()

    print(f"Train dataset size: {len(dm.train_dataset)}")
    print(f"Validation dataset size: {len(dm.val_dataset)}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Example: Use with a Trainer
    # from template_lightning_module import TemplateLightningModule
    # model = TemplateLightningModule()
    # trainer = L.Trainer(max_epochs=10)
    # trainer.fit(model, datamodule=dm)
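
    # A quick sanity check, not in the original template: pull a single batch
    # from the training loader and print its shapes (these reflect the
    # placeholder data in CustomDataset above).
    x, y = next(iter(train_loader))
    print(f"Batch data shape: {tuple(x.shape)}")   # e.g., (64, 3, 224, 224)
    print(f"Batch label shape: {tuple(y.shape)}")  # e.g., (64,)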