# LightningDataModule - Comprehensive Guide

## Overview

A LightningDataModule is a reusable, shareable class that encapsulates all data processing steps in PyTorch Lightning. It solves the problem of scattered data preparation logic by standardizing how datasets are managed and shared across projects.

## Core Problem It Solves

In traditional PyTorch workflows, data handling is fragmented across multiple files, making it difficult to answer questions like:

- "What splits did you use?"
- "What transforms were applied?"
- "How was the data prepared?"

DataModules centralize this information for reproducibility and reusability.
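
For example, because a DataModule owns all of its data logic, the same model and Trainer code can run on a different dataset just by swapping modules. A minimal sketch, assuming hypothetical `MNISTDataModule`/`CIFARDataModule` classes and an `ImageClassifier` LightningModule (none of these are defined in this guide):

```python
import lightning as L

# Hypothetical imports for illustration only
from my_project.data import MNISTDataModule, CIFARDataModule
from my_project.models import ImageClassifier

model = ImageClassifier()
trainer = L.Trainer(max_epochs=10)

# Swapping datasets requires no changes to the model or training loop
trainer.fit(model, datamodule=MNISTDataModule())
```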

## Five Processing Steps

A DataModule organizes data handling into five phases, which map onto its hooks as sketched after this list:

1. **Download/tokenize/process** - Initial data acquisition
2. **Clean and save** - Persist processed data to disk
3. **Load into Dataset** - Create PyTorch Dataset objects
4. **Apply transforms** - Data augmentation, normalization, etc.
5. **Wrap in DataLoader** - Configure batching and loading
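
A minimal skeleton showing where each phase lives (method bodies elided; the mapping, not the implementation, is the point):

```python
import lightning as L

class SkeletonDataModule(L.LightningDataModule):
    def prepare_data(self):
        # Phases 1-2: download/tokenize/process and persist to disk (runs once)
        ...

    def setup(self, stage):
        # Phases 3-4: build Dataset objects and attach transforms (runs on every process)
        ...

    def train_dataloader(self):
        # Phase 5: wrap datasets in DataLoaders (likewise val/test/predict_dataloader)
        ...
```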

## Main Methods

### `prepare_data()`

Downloads and processes data. Runs only once, on a single process, rather than on every distributed worker.

**Use for:**
- Downloading datasets
- Tokenizing text
- Saving processed data to disk

**Important:** Do not set state here (e.g., `self.x = y`). State set in `prepare_data()` is not transferred to other processes.

**Example:**

```python
def prepare_data(self):
    # Download data (runs once)
    download_dataset("http://example.com/data.zip", "data/")

    # Tokenize and save (runs once)
    tokenize_and_save("data/raw/", "data/processed/")
```

### `setup(stage)`

Creates datasets and applies transforms. Runs on every process in distributed training.

**Parameters:**
- `stage` - One of `'fit'`, `'validate'`, `'test'`, or `'predict'`

**Use for:**
- Creating train/val/test splits
- Building Dataset objects
- Applying transforms
- Setting state (e.g., `self.train_dataset = ...`)

**Example:**

```python
def setup(self, stage):
    if stage == 'fit':
        full_dataset = MyDataset("data/processed/")
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [0.8, 0.2]
        )

    if stage == 'test':
        self.test_dataset = MyDataset("data/processed/test/")

    if stage == 'predict':
        self.predict_dataset = MyDataset("data/processed/predict/")
```

### `train_dataloader()`

Returns the training DataLoader.

**Example:**

```python
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=True
    )
```

### `val_dataloader()`

Returns the validation DataLoader(s).

**Example:**

```python
def val_dataloader(self):
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=True
    )
```

### `test_dataloader()`

Returns the test DataLoader(s).

**Example:**

```python
def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )
```

### `predict_dataloader()`

Returns the prediction DataLoader(s).

**Example:**

```python
def predict_dataloader(self):
    return DataLoader(
        self.predict_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )
```

## Complete Example

```python
import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset, random_split

class MyDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        # Load your data here
        return torch.randn(1000, 3, 224, 224)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Transforms
        self.train_transform = self._get_train_transforms()
        self.test_transform = self._get_test_transforms()

    def _get_train_transforms(self):
        # Define training transforms
        return lambda x: x  # Placeholder

    def _get_test_transforms(self):
        # Define test/val transforms
        return lambda x: x  # Placeholder

    def prepare_data(self):
        # Download data (runs once on a single process)
        # download_data(self.data_dir)
        pass

    def setup(self, stage=None):
        # Create datasets (runs on every process)
        if stage == 'fit' or stage is None:
            full_dataset = MyDataset(
                self.data_dir,
                transform=self.train_transform
            )
            train_size = int(0.8 * len(full_dataset))
            val_size = len(full_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_dataset, [train_size, val_size]
            )

        if stage == 'test' or stage is None:
            self.test_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

        if stage == 'predict':
            self.predict_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

    def predict_dataloader(self):
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )
```

## Usage

```python
# Create DataModule
dm = MyDataModule(data_dir="./data", batch_size=64, num_workers=8)

# Use with Trainer
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# Test
trainer.test(model, datamodule=dm)

# Predict
predictions = trainer.predict(model, datamodule=dm)

# Or use standalone in plain PyTorch
dm.prepare_data()
dm.setup(stage='fit')
train_loader = dm.train_dataloader()

for batch in train_loader:
    # Your training code
    pass
```

## Additional Hooks

### `transfer_batch_to_device(batch, device, dataloader_idx)`

Custom logic for moving batches to devices.

**Example:**

```python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # Custom transfer logic
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    return super().transfer_batch_to_device(batch, device, dataloader_idx)
```

### `on_before_batch_transfer(batch, dataloader_idx)`

Augment or modify a batch before it is transferred to the device (runs on CPU).

**Example:**

```python
def on_before_batch_transfer(self, batch, dataloader_idx):
    # Apply CPU-based augmentations
    batch['image'] = apply_augmentation(batch['image'])
    return batch
```

### `on_after_batch_transfer(batch, dataloader_idx)`

Augment or modify a batch after it is transferred to the device (runs on GPU).

**Example:**

```python
def on_after_batch_transfer(self, batch, dataloader_idx):
    # Apply GPU-based augmentations
    batch['image'] = gpu_augmentation(batch['image'])
    return batch
```

### `state_dict()` / `load_state_dict(state_dict)`

Save and restore DataModule state for checkpointing. Lightning calls these hooks automatically when writing and restoring Trainer checkpoints.

**Example:**

```python
def state_dict(self):
    return {"current_fold": self.current_fold}

def load_state_dict(self, state_dict):
    self.current_fold = state_dict["current_fold"]
```

### `teardown(stage)`

Cleanup operations after training/testing/prediction.

**Example:**

```python
def teardown(self, stage):
    # Clean up resources
    if stage == 'fit':
        self.train_dataset = None
        self.val_dataset = None
```

## Advanced Patterns

### Multiple Validation/Test DataLoaders

Return a list or dictionary of DataLoaders:

```python
def val_dataloader(self):
    return [
        DataLoader(self.val_dataset_1, batch_size=32),
        DataLoader(self.val_dataset_2, batch_size=32)
    ]

# Or with names (for logging)
def val_dataloader(self):
    return {
        "val_easy": DataLoader(self.val_easy, batch_size=32),
        "val_hard": DataLoader(self.val_hard, batch_size=32)
    }

# In the LightningModule
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    if dataloader_idx == 0:
        # Handle val_dataset_1
        pass
    else:
        # Handle val_dataset_2
        pass
```

### Cross-Validation

```python
from torch.utils.data import Subset

class CrossValidationDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_folds=5):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_folds = num_folds
        self.current_fold = 0

    def setup(self, stage=None):
        full_dataset = MyDataset(self.data_dir)
        fold_size = len(full_dataset) // self.num_folds

        # Create fold indices
        indices = list(range(len(full_dataset)))
        val_start = self.current_fold * fold_size
        val_end = val_start + fold_size

        val_indices = indices[val_start:val_end]
        train_indices = indices[:val_start] + indices[val_end:]

        self.train_dataset = Subset(full_dataset, train_indices)
        self.val_dataset = Subset(full_dataset, val_indices)

    def set_fold(self, fold):
        self.current_fold = fold

    def state_dict(self):
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        self.current_fold = state_dict["current_fold"]

# Usage
dm = CrossValidationDataModule("./data", batch_size=32, num_folds=5)

for fold in range(5):
    dm.set_fold(fold)
    model = MyModel()  # re-initialize the model for each fold (MyModel is your LightningModule)
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)
```

### Hyperparameter Saving

```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, num_workers=4):
        super().__init__()
        # Save hyperparameters
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Access via self.hparams
        print(f"Batch size: {self.hparams.batch_size}")
```

## Best Practices

### 1. Separate prepare_data and setup

- `prepare_data()` - Downloads/processes (single process, no state)
- `setup()` - Creates datasets (every process, set state); see the sketch below
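
A minimal sketch of this split (assuming a hypothetical `download_and_extract` helper and a split-aware `MyDataset`, neither defined above):

```python
def prepare_data(self):
    # Runs once on a single process: fetch and persist data; set no attributes here.
    download_and_extract("http://example.com/data.zip", self.data_dir)  # hypothetical helper

def setup(self, stage):
    # Runs on every process: build datasets and assign state.
    if stage == 'fit':
        self.train_dataset = MyDataset(self.data_dir, split='train')  # hypothetical split kwarg
        self.val_dataset = MyDataset(self.data_dir, split='val')
```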

### 2. Use the stage Parameter

Check the stage in `setup()` to avoid unnecessary work:

```python
def setup(self, stage):
    if stage == 'fit':
        # Only load train/val data when fitting
        self.train_dataset = ...
        self.val_dataset = ...
    elif stage == 'test':
        # Only load test data when testing
        self.test_dataset = ...
```

### 3. Pin Memory for GPU Training

Enable `pin_memory=True` in DataLoaders for faster host-to-GPU transfer:

```python
def train_dataloader(self):
    return DataLoader(..., pin_memory=True)
```

### 4. Use Persistent Workers

Prevent worker restarts between epochs:

```python
def train_dataloader(self):
    return DataLoader(
        ...,
        num_workers=4,
        persistent_workers=True
    )
```

### 5. Avoid Shuffle in Validation/Test

Never shuffle validation or test data:

```python
def val_dataloader(self):
    return DataLoader(..., shuffle=False)  # Never True
```

### 6. Make DataModules Reusable

Accept configuration parameters in `__init__`:

```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers, augment=True):
        super().__init__()
        self.save_hyperparameters()
```

### 7. Document Data Structure

Add docstrings explaining data format and expectations:

```python
class MyDataModule(L.LightningDataModule):
    """
    DataModule for XYZ dataset.

    Data format: (image, label) tuples
    - image: torch.Tensor of shape (C, H, W)
    - label: int in range [0, num_classes)

    Args:
        data_dir: Path to data directory
        batch_size: Batch size for dataloaders
        num_workers: Number of data loading workers
    """
```

## Common Pitfalls

### 1. Setting State in prepare_data

**Wrong:**
```python
def prepare_data(self):
    self.dataset = load_data()  # State is not transferred to other processes!
```

**Correct:**
```python
def prepare_data(self):
    download_data()  # Only download, no state

def setup(self, stage):
    self.dataset = load_data()  # Set state here
```

### 2. Not Using the stage Parameter

**Inefficient:**
```python
def setup(self, stage):
    self.train_dataset = load_train()
    self.val_dataset = load_val()
    self.test_dataset = load_test()  # Loads even when just fitting
```

**Efficient:**
```python
def setup(self, stage):
    if stage == 'fit':
        self.train_dataset = load_train()
        self.val_dataset = load_val()
    elif stage == 'test':
        self.test_dataset = load_test()
```

### 3. Forgetting to Return DataLoaders

**Wrong:**
```python
def train_dataloader(self):
    DataLoader(self.train_dataset, ...)  # Forgot the return!
```

**Correct:**
```python
def train_dataloader(self):
    return DataLoader(self.train_dataset, ...)
```

## Integration with Trainer

```python
# Initialize DataModule
dm = MyDataModule(data_dir="./data", batch_size=64)

# All data loading is handled by the DataModule
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# The DataModule handles validation too
trainer.validate(model, datamodule=dm)

# And testing
trainer.test(model, datamodule=dm)

# And prediction
predictions = trainer.predict(model, datamodule=dm)
```