Distributed Training - Comprehensive Guide

Overview

PyTorch Lightning provides several strategies for training large models efficiently across multiple GPUs, nodes, and machines. Choose the right strategy based on model size and hardware configuration.

Strategy Selection Guide

When to Use Each Strategy

Regular Training (Single Device)

  • Model size: Any size that fits in single GPU memory
  • Use case: Prototyping, small models, debugging

DDP (Distributed Data Parallel)

  • Model size: <500M parameters (e.g., ResNet-50 at ~25M parameters)
  • When: Weights, activations, optimizer states, and gradients all fit in GPU memory
  • Goal: Scale batch size and speed across multiple GPUs
  • Best for: Most standard deep learning models

FSDP (Fully Sharded Data Parallel)

  • Model size: 500M+ parameters (e.g., large transformers like BERT-Large, GPT)
  • When: Model doesn't fit in single GPU memory
  • Recommended for: Users new to model parallelism or migrating from DDP
  • Features: Activation checkpointing, CPU parameter offloading

DeepSpeed

  • Model size: 500M+ parameters
  • When: Need cutting-edge features or already familiar with DeepSpeed
  • Features: CPU/disk parameter offloading, distributed checkpoints, fine-grained control
  • Trade-off: More complex configuration
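As a rough heuristic, the thresholds above can be encoded in a small helper. This is a hypothetical sketch (choose_strategy is not a Lightning API, and MyModel is a placeholder); real decisions should also weigh activation memory and hardware:

import lightning as L

def choose_strategy(num_params: int) -> str:
    # Hypothetical helper encoding the rules of thumb above
    if num_params < 500_000_000:
        return "ddp"            # everything fits per GPU
    if num_params < 10_000_000_000:
        return "fsdp"           # shard states across GPUs
    return "deepspeed_stage_3"  # full ZeRO-3 sharding

model = MyModel()  # assumed LightningModule defined elsewhere
num_params = sum(p.numel() for p in model.parameters())
trainer = L.Trainer(strategy=choose_strategy(num_params),
                    accelerator="gpu", devices=4)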

DDP (Distributed Data Parallel)

Basic Usage

import lightning as L

# Single GPU
trainer = L.Trainer(accelerator="gpu", devices=1)

# Multi-GPU on single node (automatic DDP)
trainer = L.Trainer(accelerator="gpu", devices=4)

# Explicit DDP strategy
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)

Multi-Node DDP

# On each node, run:
trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,  # GPUs per node
    num_nodes=4  # Total nodes
)
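With devices=4 and num_nodes=4, Lightning launches 4 * 4 = 16 processes, one per GPU. Inside a LightningModule the process's position is exposed as attributes, which is handy for sanity-checking the launch (a minimal sketch):

class MyModel(L.LightningModule):
    def on_train_start(self):
        # Expect world_size == devices * num_nodes == 16 in the setup above
        print(f"global rank {self.global_rank} of {self.trainer.world_size}")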

DDP Configuration

from lightning.pytorch.strategies import DDPStrategy

trainer = L.Trainer(
    strategy=DDPStrategy(
        process_group_backend="nccl",  # "nccl" for GPU, "gloo" for CPU
        find_unused_parameters=False,   # Set True if model has unused parameters
        gradient_as_bucket_view=True    # More memory efficient
    ),
    accelerator="gpu",
    devices=4
)

DDP Spawn

Use ddp_spawn when the default subprocess-based ddp launch causes issues; processes are created with multiprocessing spawn instead (slower but sometimes more compatible):

trainer = L.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)

Best Practices for DDP

  1. Batch size: the DataLoader batch_size is per GPU, so the effective global batch size is batch_size * num_gpus

    # With 4 GPUs, effective batch size = batch_size * 4
    dm = MyDataModule(batch_size=32)  # 32 * 4 = 128 effective batch size
    
  2. Learning rate: Often scaled with batch size

    # Linear scaling rule
    base_lr = 0.001
    num_gpus = 4
    lr = base_lr * num_gpus
    
  3. Synchronization: Use sync_dist=True for metrics

    self.log("val_loss", loss, sync_dist=True)
    
  4. Rank-specific operations: Use decorators for main process only

    from lightning.pytorch.utilities import rank_zero_only
    
    @rank_zero_only
    def save_results(self):
        # Only runs on main process (rank 0)
        torch.save(self.results, "results.pt")
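
Putting the four practices together, a minimal sketch (MyDataModule and the network body are assumed placeholders):

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities import rank_zero_only

class MyModel(L.LightningModule):
    def __init__(self, base_lr=1e-3, num_gpus=4):
        super().__init__()
        self.save_hyperparameters()
        self.net = torch.nn.Linear(32, 1)  # placeholder network

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.net(x), y)
        self.log("val_loss", loss, sync_dist=True)  # practice 3: sync metrics

    def configure_optimizers(self):
        # practice 2: linear LR scaling with the number of GPUs
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.base_lr * self.hparams.num_gpus)

    @rank_zero_only
    def save_results(self):
        # practice 4: file I/O on the main process only
        torch.save({"done": True}, "results.pt")

# practice 1: per-GPU batch of 32 -> effective batch of 128 on 4 GPUs
dm = MyDataModule(batch_size=32)
trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4)
trainer.fit(MyModel(), datamodule=dm)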
    

FSDP (Fully Sharded Data Parallel)

Basic Usage

trainer = L.Trainer(
    strategy="fsdp",
    accelerator="gpu",
    devices=4
)

FSDP Configuration

from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn

trainer = L.Trainer(
    strategy=FSDPStrategy(
        # Sharding strategy
        sharding_strategy="FULL_SHARD",  # or "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"

        # Activation checkpointing (save memory)
        activation_checkpointing_policy={nn.TransformerEncoderLayer},

        # CPU offloading (save GPU memory, slower)
        cpu_offload=False,

        # Wrap policy (auto-wrap layers); None uses FSDP's default
        auto_wrap_policy=None
    ),
    accelerator="gpu",
    devices=8,
    precision="bf16-mixed"  # mixed precision is set via the Trainer's precision flag
)
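For finer control than the Trainer's precision flag, FSDPStrategy also accepts torch's MixedPrecision policy; a sketch assuming all three dtypes should be bfloat16:

import torch
import lightning as L
from torch.distributed.fsdp import MixedPrecision
from lightning.pytorch.strategies import FSDPStrategy

policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype of sharded parameters
    reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction
    buffer_dtype=torch.bfloat16,  # dtype of buffers (e.g. norm statistics)
)
trainer = L.Trainer(strategy=FSDPStrategy(mixed_precision=policy),
                    accelerator="gpu", devices=8)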

Sharding Strategies

FULL_SHARD (default)

  • Shards optimizer states, gradients, and parameters
  • Maximum memory savings
  • More communication overhead

SHARD_GRAD_OP

  • Shards optimizer states and gradients only
  • Parameters kept on all devices
  • Less memory savings but faster

NO_SHARD

  • No sharding (equivalent to DDP)
  • For comparison or when sharding not needed

HYBRID_SHARD

  • Combines FULL_SHARD within nodes and NO_SHARD across nodes
  • Good for multi-node setups
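The strategy can be passed as the string names above or, assuming your Lightning version forwards it, as torch's ShardingStrategy enum; for example, to trade memory for speed with SHARD_GRAD_OP:

import lightning as L
from torch.distributed.fsdp import ShardingStrategy
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP),
    accelerator="gpu",
    devices=4
)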

Activation Checkpointing

Trade computation for memory:

from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn

# Checkpoint specific layer types
trainer = L.Trainer(
    strategy=FSDPStrategy(
        activation_checkpointing_policy={
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer
        }
    )
)

CPU Offloading

Offload parameters to CPU when not in use:

trainer = L.Trainer(
    strategy=FSDPStrategy(
        cpu_offload=True  # Slower but saves GPU memory
    ),
    accelerator="gpu",
    devices=4
)
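For explicitness, torch's CPUOffload object can be passed instead of a bool (a sketch, assuming your Lightning version forwards the object to FSDP unchanged):

import lightning as L
from torch.distributed.fsdp import CPUOffload
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    strategy=FSDPStrategy(cpu_offload=CPUOffload(offload_params=True)),
    accelerator="gpu",
    devices=4
)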

FSDP with Large Models

from lightning.pytorch.strategies import FSDPStrategy
import torch.nn as nn

class LargeTransformer(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.transformer = None  # created lazily in configure_model

    def configure_model(self):
        # Lightning 2.x hook (formerly configure_sharded_model): creating the
        # layers here lets FSDP shard them during instantiation instead of
        # materializing the full model on every rank first
        if self.transformer is not None:
            return  # the hook may be called more than once
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=4096, nhead=32),
            num_layers=48
        )

# Train
trainer = L.Trainer(
    strategy=FSDPStrategy(
        activation_checkpointing_policy={nn.TransformerEncoderLayer},
        cpu_offload=False,
        sharding_strategy="FULL_SHARD"
    ),
    accelerator="gpu",
    devices=8,
    precision="bf16-mixed",
    max_epochs=10
)

model = LargeTransformer()
trainer.fit(model, datamodule=dm)

DeepSpeed

Installation

pip install deepspeed

Basic Usage

trainer = L.Trainer(
    strategy="deepspeed_stage_2",  # or "deepspeed_stage_3"
    accelerator="gpu",
    devices=4,
    precision="16-mixed"
)

DeepSpeed Stages

Stage 1: Optimizer State Sharding

  • Shards optimizer states
  • Moderate memory savings
trainer = L.Trainer(strategy="deepspeed_stage_1")

Stage 2: Optimizer + Gradient Sharding

  • Shards optimizer states and gradients
  • Good memory savings
trainer = L.Trainer(strategy="deepspeed_stage_2")

Stage 3: Full Model Sharding (ZeRO-3)

  • Shards optimizer states, gradients, and model parameters
  • Maximum memory savings
  • Can train very large models
trainer = L.Trainer(strategy="deepspeed_stage_3")

Stages 2 and 3 with Offloading

  • Offload optimizer states and/or parameters to CPU or NVMe
trainer = L.Trainer(strategy="deepspeed_stage_2_offload")
trainer = L.Trainer(strategy="deepspeed_stage_3_offload")

DeepSpeed Configuration File

For fine-grained control:

from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed config as a Python dict (can also be saved to ds_config.json)
config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "gradient_clipping": 1.0,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

trainer = L.Trainer(
    strategy=DeepSpeedStrategy(config=config),
    accelerator="gpu",
    devices=8,
    precision="16-mixed"
)
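The same config can also be written to disk and passed as a path, which keeps it reusable with the standalone deepspeed launcher:

import json
from lightning.pytorch.strategies import DeepSpeedStrategy

with open("ds_config.json", "w") as f:
    json.dump(config, f)

trainer = L.Trainer(strategy=DeepSpeedStrategy(config="ds_config.json"),
                    accelerator="gpu", devices=8, precision="16-mixed")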

DeepSpeed Best Practices

  1. Use Stage 2 for models <10B parameters
  2. Use Stage 3 for models >10B parameters
  3. Enable offloading if GPU memory is insufficient
  4. Tune reduce_bucket_size for communication efficiency

Comparison Table

Feature             DDP                FSDP             DeepSpeed
Model Size          <500M params       500M+ params     500M+ params
Memory Efficiency   Low                High             Very High
Speed               Fastest            Fast             Fast
Setup Complexity    Simple             Medium           Complex
Offloading          No                 CPU              CPU + Disk
Best For            Standard models    Large models     Very large models
Configuration       Minimal            Moderate         Extensive

Mixed Precision Training

Use mixed precision to speed up training and save memory:

# FP16 mixed precision
trainer = L.Trainer(precision="16-mixed")

# BFloat16 mixed precision (A100, H100)
trainer = L.Trainer(precision="bf16-mixed")

# Full precision (default)
trainer = L.Trainer(precision="32-true")

# Double precision
trainer = L.Trainer(precision="64-true")

Mixed Precision with Different Strategies

# DDP + FP16
trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    precision="16-mixed"
)

# FSDP + BFloat16
trainer = L.Trainer(
    strategy="fsdp",
    accelerator="gpu",
    devices=8,
    precision="bf16-mixed"
)

# DeepSpeed + FP16
trainer = L.Trainer(
    strategy="deepspeed_stage_2",
    accelerator="gpu",
    devices=4,
    precision="16-mixed"
)

Multi-Node Training

SLURM

#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=4   # one task per GPU so srun launches every rank
#SBATCH --gpus-per-node=4
#SBATCH --time=24:00:00

srun python train.py

# train.py
trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    num_nodes=4
)

Manual Multi-Node Setup

Node 0 (master):

python train.py --num_nodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=12345

Node 1:

python train.py --num_nodes=2 --node_rank=1 --master_addr=192.168.1.1 --master_port=12345

# train.py
import argparse
import os

import lightning as L

parser = argparse.ArgumentParser()
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--node_rank", type=int, default=0)
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=int, default=12345)
args = parser.parse_args()

# Lightning reads the rendezvous info from these environment variables
os.environ["MASTER_ADDR"] = args.master_addr
os.environ["MASTER_PORT"] = str(args.master_port)
os.environ["NODE_RANK"] = str(args.node_rank)

trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    num_nodes=args.num_nodes
)

Common Patterns

Gradient Accumulation with DDP

# Simulate larger batch size
trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    accumulate_grad_batches=4  # Effective batch size = batch_size * devices * 4
)

Model Checkpoint with Distributed Training

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=3,
    mode="min"
)

trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    callbacks=[checkpoint_callback]
)

Reproducibility in Distributed Training

import lightning as L

L.seed_everything(42, workers=True)

trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    deterministic=True
)

Troubleshooting

NCCL Timeout

Increase the collective-op timeout for slow networks; DDPStrategy's timeout is forwarded to torch.distributed.init_process_group:

from datetime import timedelta

from lightning.pytorch.strategies import DDPStrategy

trainer = L.Trainer(
    strategy=DDPStrategy(timeout=timedelta(hours=1)),
    accelerator="gpu",
    devices=4
)

CUDA Out of Memory

Solutions:

  1. Enable gradient checkpointing
  2. Reduce batch size
  3. Use FSDP or DeepSpeed
  4. Enable CPU offloading
  5. Use mixed precision

# Option 1: Gradient checkpointing (gradient_checkpointing_enable is the
# Hugging Face Transformers API; plain PyTorch modules can use torch.utils.checkpoint)
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyTransformer()
        self.model.gradient_checkpointing_enable()

# Option 2: Smaller batch size
dm = MyDataModule(batch_size=16)  # Reduce from 32

# Option 3: FSDP with CPU offloading plus mixed precision
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    strategy=FSDPStrategy(cpu_offload=True),
    precision="bf16-mixed"
)

# Option 4: Gradient accumulation (smaller micro-batches, same effective batch)
trainer = L.Trainer(accumulate_grad_batches=4)

Distributed Sampler Issues

Lightning handles DistributedSampler automatically:

# Don't do this: Lightning adds a DistributedSampler automatically
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset)

# Just pass shuffle; Lightning swaps in a DistributedSampler under DDP
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
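If you genuinely need a custom sampler, you can tell Lightning not to replace it:

trainer = L.Trainer(
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    use_distributed_sampler=False  # keep the sampler you built yourself
)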

Communication Overhead

Avoid unnecessary overhead by keeping find_unused_parameters=False when every parameter receives a gradient (enabling it forces an extra graph traversal each step):

trainer = L.Trainer(
    strategy=DDPStrategy(find_unused_parameters=False),
    accelerator="gpu",
    devices=4
)

Best Practices

1. Start with Single GPU

Test your code on single GPU before scaling:

# Debug on single GPU
trainer = L.Trainer(accelerator="gpu", devices=1, fast_dev_run=True)

# Then scale to multiple GPUs
trainer = L.Trainer(accelerator="gpu", devices=4, strategy="ddp")

2. Use Appropriate Strategy

  • <500M params: Use DDP
  • 500M-10B params: Use FSDP
  • >10B params: Use DeepSpeed Stage 3

3. Enable Mixed Precision

Always use mixed precision for modern GPUs:

trainer = L.Trainer(precision="bf16-mixed")  # A100, H100
trainer = L.Trainer(precision="16-mixed")    # V100, T4

4. Scale Hyperparameters

Adjust learning rate and batch size when scaling:

# Linear scaling rule
lr = base_lr * num_gpus

5. Sync Metrics

Always sync metrics in distributed training:

self.log("val_loss", loss, sync_dist=True)

6. Use Rank-Zero Operations

File I/O and expensive operations on main process only:

from lightning.pytorch.utilities import rank_zero_only

@rank_zero_only
def save_predictions(self):
    torch.save(self.predictions, "predictions.pt")

7. Checkpoint Regularly

Save checkpoints to resume from failures:

checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    save_last=True,  # Always save last for resuming
    every_n_epochs=5
)
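
Resuming then only needs the checkpoint path; Lightning restores the model, optimizer, and loop state. Assuming save_last=True as above, ckpt_path="last" picks up the most recent checkpoint:

trainer = L.Trainer(strategy="ddp", accelerator="gpu", devices=4,
                    callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm, ckpt_path="last")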