LightningDataModule - Comprehensive Guide

Overview

A LightningDataModule is a reusable, shareable class that encapsulates all data processing steps in PyTorch Lightning. It solves the problem of scattered data preparation logic by standardizing how datasets are managed and shared across projects.

Core Problem It Solves

In traditional PyTorch workflows, data handling is fragmented across multiple files, making it difficult to answer questions like:

  • "What splits did you use?"
  • "What transforms were applied?"
  • "How was the data prepared?"

DataModules centralize this information for reproducibility and reusability.

Five Processing Steps

A DataModule organizes data handling into five phases (sketched as a class skeleton after this list):

  1. Download/tokenize/process - Initial data acquisition
  2. Clean and save - Persist processed data to disk
  3. Load into Dataset - Create PyTorch Dataset objects
  4. Apply transforms - Data augmentation, normalization, etc.
  5. Wrap in DataLoader - Configure batching and loading
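
As a rough sketch of how these phases map onto DataModule methods (the random TensorDataset below is only a stand-in for real data, not part of an actual pipeline):

import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset

class SketchDataModule(L.LightningDataModule):
    def prepare_data(self):
        # Steps 1-2: download/tokenize/process and persist to disk (runs once)
        pass

    def setup(self, stage):
        # Steps 3-4: load into Dataset objects and apply transforms (runs on every process)
        self.train_dataset = TensorDataset(
            torch.randn(100, 4), torch.randint(0, 2, (100,))
        )

    def train_dataloader(self):
        # Step 5: wrap in a DataLoader with batching configuration
        return DataLoader(self.train_dataset, batch_size=16, shuffle=True)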

Main Methods

prepare_data()

Downloads and processes data. Runs only once, on a single process, so the work is not duplicated across distributed workers.

Use for:

  • Downloading datasets
  • Tokenizing text
  • Saving processed data to disk

Important: Do not set state here (e.g., self.x = y). State is not transferred to other processes.

Example:

def prepare_data(self):
    # Download data (runs once)
    download_dataset("http://example.com/data.zip", "data/")

    # Tokenize and save (runs once)
    tokenize_and_save("data/raw/", "data/processed/")

setup(stage)

Creates datasets and applies transforms. Runs on every process in distributed training.

Parameters:

  • stage - 'fit', 'validate', 'test', or 'predict'

Use for:

  • Creating train/val/test splits
  • Building Dataset objects
  • Applying transforms
  • Setting state (self.train_dataset = ...)

Example:

def setup(self, stage):
    if stage == 'fit':
        full_dataset = MyDataset("data/processed/")
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [0.8, 0.2]
        )

    if stage == 'test':
        self.test_dataset = MyDataset("data/processed/test/")

    if stage == 'predict':
        self.predict_dataset = MyDataset("data/processed/predict/")

train_dataloader()

Returns the training DataLoader.

Example:

def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=True
    )

val_dataloader()

Returns the validation DataLoader(s).

Example:

def val_dataloader(self):
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=True
    )

test_dataloader()

Returns the test DataLoader(s).

Example:

def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )

predict_dataloader()

Returns the prediction DataLoader(s).

Example:

def predict_dataloader(self):
    return DataLoader(
        self.predict_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )

Complete Example

import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split
import torch

class MyDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        # Load your data here
        return torch.randn(1000, 3, 224, 224)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Transforms
        self.train_transform = self._get_train_transforms()
        self.test_transform = self._get_test_transforms()

    def _get_train_transforms(self):
        # Define training transforms
        return lambda x: x  # Placeholder

    def _get_test_transforms(self):
        # Define test/val transforms
        return lambda x: x  # Placeholder

    def prepare_data(self):
        # Download data (runs once on single process)
        # download_data(self.data_dir)
        pass

    def setup(self, stage=None):
        # Create datasets (runs on every process)
        if stage == 'fit' or stage is None:
            full_dataset = MyDataset(
                self.data_dir,
                transform=self.train_transform
            )
            train_size = int(0.8 * len(full_dataset))
            val_size = len(full_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_dataset, [train_size, val_size]
            )

        if stage == 'test' or stage is None:
            self.test_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

        if stage == 'predict':
            self.predict_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

    def predict_dataloader(self):
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

Usage

# Create DataModule
dm = MyDataModule(data_dir="./data", batch_size=64, num_workers=8)

# Use with Trainer
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# Test
trainer.test(model, datamodule=dm)

# Predict
predictions = trainer.predict(model, datamodule=dm)

# Or use standalone in PyTorch
dm.prepare_data()
dm.setup(stage='fit')
train_loader = dm.train_dataloader()

for batch in train_loader:
    # Your training code
    pass

Additional Hooks

transfer_batch_to_device(batch, device, dataloader_idx)

Custom logic for moving batches to devices.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # Custom transfer logic
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    return super().transfer_batch_to_device(batch, device, dataloader_idx)

on_before_batch_transfer(batch, dataloader_idx)

Augment or modify the batch before it is transferred to the device (runs on the CPU).

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    # Apply CPU-based augmentations
    batch['image'] = apply_augmentation(batch['image'])
    return batch

on_after_batch_transfer(batch, dataloader_idx)

Augment or modify the batch after it has been transferred to the device (runs on the accelerator, e.g. the GPU).

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    # Apply GPU-based augmentations
    batch['image'] = gpu_augmentation(batch['image'])
    return batch

state_dict() / load_state_dict(state_dict)

Save and restore DataModule state for checkpointing.

Example:

def state_dict(self):
    return {"current_fold": self.current_fold}

def load_state_dict(self, state_dict):
    self.current_fold = state_dict["current_fold"]

teardown(stage)

Cleanup operations after training/testing/prediction.

Example:

def teardown(self, stage):
    # Clean up resources
    if stage == 'fit':
        self.train_dataset = None
        self.val_dataset = None

Advanced Patterns

Multiple Validation/Test DataLoaders

Return a list or dictionary of DataLoaders:

def val_dataloader(self):
    return [
        DataLoader(self.val_dataset_1, batch_size=32),
        DataLoader(self.val_dataset_2, batch_size=32)
    ]

# Or with names (for logging)
def val_dataloader(self):
    return {
        "val_easy": DataLoader(self.val_easy, batch_size=32),
        "val_hard": DataLoader(self.val_hard, batch_size=32)
    }

# In LightningModule
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    if dataloader_idx == 0:
        # Handle val_dataset_1
        pass
    else:
        # Handle val_dataset_2
        pass

Cross-Validation

from torch.utils.data import Subset

class CrossValidationDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_folds=5):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_folds = num_folds
        self.current_fold = 0

    def setup(self, stage=None):
        full_dataset = MyDataset(self.data_dir)
        fold_size = len(full_dataset) // self.num_folds

        # Create fold indices
        indices = list(range(len(full_dataset)))
        val_start = self.current_fold * fold_size
        val_end = val_start + fold_size

        val_indices = indices[val_start:val_end]
        train_indices = indices[:val_start] + indices[val_end:]

        self.train_dataset = Subset(full_dataset, train_indices)
        self.val_dataset = Subset(full_dataset, val_indices)

    def set_fold(self, fold):
        self.current_fold = fold

    def state_dict(self):
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        self.current_fold = state_dict["current_fold"]

    # train_dataloader() / val_dataloader() are defined as in MyDataModule above

# Usage: train one model per fold
dm = CrossValidationDataModule("./data", batch_size=32, num_folds=5)

for fold in range(5):
    dm.set_fold(fold)
    # Re-create the model and trainer for each fold so weights do not carry over
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)

Hyperparameter Saving

class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, num_workers=4):
        super().__init__()
        # Save hyperparameters
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Access via self.hparams
        print(f"Batch size: {self.hparams.batch_size}")

Best Practices

1. Separate prepare_data and setup

  • prepare_data() - Downloads and processes data (single process, no state)
  • setup() - Creates datasets and sets state (runs on every process); see the sketch after this list
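
A minimal sketch of this split, assuming torchvision is installed (MNISTDataModule and the 55000/5000 split are illustrative, not part of this guide's running example):

import lightning as L
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # Download only; no state is assigned here
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage):
        # Runs on every process, so it is safe to assign datasets here
        if stage == 'fit':
            full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
            self.train_dataset, self.val_dataset = random_split(full, [55000, 5000])
        if stage == 'test':
            self.test_dataset = MNIST(
                self.data_dir, train=False, transform=transforms.ToTensor()
            )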

2. Use stage Parameter

Check the stage in setup() to avoid unnecessary work:

def setup(self, stage):
    if stage == 'fit':
        # Only load train/val data when fitting
        self.train_dataset = ...
        self.val_dataset = ...
    elif stage == 'test':
        # Only load test data when testing
        self.test_dataset = ...

3. Pin Memory for GPU Training

Enable pin_memory=True in DataLoaders for faster GPU transfer:

def train_dataloader(self):
    return DataLoader(..., pin_memory=True)

4. Use Persistent Workers

Prevent worker restarts between epochs:

def train_dataloader(self):
    return DataLoader(
        ...,
        num_workers=4,
        persistent_workers=True
    )

5. Avoid Shuffle in Validation/Test

Never shuffle validation or test data:

def val_dataloader(self):
    return DataLoader(..., shuffle=False)  # Never True

6. Make DataModules Reusable

Accept configuration parameters in __init__:

class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers, augment=True):
        super().__init__()
        self.save_hyperparameters()

7. Document Data Structure

Add docstrings explaining data format and expectations:

class MyDataModule(L.LightningDataModule):
    """
    DataModule for XYZ dataset.

    Data format: (image, label) tuples
    - image: torch.Tensor of shape (C, H, W)
    - label: int in range [0, num_classes)

    Args:
        data_dir: Path to data directory
        batch_size: Batch size for dataloaders
        num_workers: Number of data loading workers
    """

Common Pitfalls

1. Setting State in prepare_data

Wrong:

def prepare_data(self):
    self.dataset = load_data()  # State not transferred to other processes!

Correct:

def prepare_data(self):
    download_data()  # Only download, no state

def setup(self, stage):
    self.dataset = load_data()  # Set state here

2. Not Using stage Parameter

Inefficient:

def setup(self, stage):
    self.train_dataset = load_train()
    self.val_dataset = load_val()
    self.test_dataset = load_test()  # Loads even when just fitting

Efficient:

def setup(self, stage):
    if stage == 'fit':
        self.train_dataset = load_train()
        self.val_dataset = load_val()
    elif stage == 'test':
        self.test_dataset = load_test()

3. Forgetting to Return DataLoaders

Wrong:

def train_dataloader(self):
    DataLoader(self.train_dataset, ...)  # Forgot return!

Correct:

def train_dataloader(self):
    return DataLoader(self.train_dataset, ...)

Integration with Trainer

# Initialize DataModule
dm = MyDataModule(data_dir="./data", batch_size=64)

# All data loading is handled by DataModule
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# DataModule handles validation too
trainer.validate(model, datamodule=dm)

# And testing
trainer.test(model, datamodule=dm)

# And prediction
predictions = trainer.predict(model, datamodule=dm)