# Data Augmentation Strategies
## Overview
Data augmentation artificially increases training data diversity by applying transformations that preserve labels. This is one of the most cost-effective ways to improve model robustness and reduce overfitting, but it requires domain knowledge and careful strength tuning.
**Core Principle**: Augmentation is NOT a universal technique. The right augmentations depend on your domain, task, data distribution, and model capacity. Wrong augmentations can hurt more than help.
**Critical Rule**: Augment ONLY training data. Validation and test data must remain unaugmented to provide accurate performance estimates.
**Why Augmentation Matters**:
- Creates label-preserving variations, teaching invariance
- Reduces overfitting by preventing memorization
- Improves robustness to distribution shift
- Essentially "free" data—no labeling cost
- Can outperform adding more labeled data in some domains
## When to Use This Skill
Load this skill when:
- Training on limited dataset (< 10,000 examples) and seeing overfitting
- Addressing distribution shift or robustness concerns
- Selecting augmentations for vision, NLP, audio, or tabular tasks
- Designing augmentation pipelines and strength tuning
- Troubleshooting training issues (accuracy drop with augmentation)
- Implementing test-time augmentation (TTA) or augmentation policies
- Choosing between weak augmentation (100% prob) vs strong (lower prob)
**Don't use for**: General training debugging (see using-training-optimization), optimization algorithm selection (see optimization-algorithms), or regularization without domain context (augmentation is domain-specific)
## Part 1: Augmentation Decision Framework
### The Core Question: "When should I augment?"
**WRONG ANSWER**: "Use augmentation for all datasets."
**RIGHT APPROACH**: Use this decision framework.
### Clarifying Questions
1. **"How much training data do you have?"**
- < 1,000 examples → Strong augmentation needed
- 1,000-10,000 examples → Medium augmentation
- 10,000-100,000 examples → Light augmentation often sufficient
- > 100,000 examples → Augmentation helps but not critical
- Rule: Smaller dataset = more aggressive augmentation
2. **"What's your train/validation accuracy gap?"**
- Train 90%, val 70% (20% gap) → Overfitting, augmentation will help
- Train 85%, val 83% (2% gap) → Well-regularized, augmentation optional
- Train 60%, val 58% (2% gap) → Underfitting, augmentation won't help (need more capacity)
- Rule: Large gap indicates augmentation will help
3. **"How much distribution shift is expected at test time?"**
- Same domain, clean images → Light augmentation (rotation ±15°, crop 90%, brightness ±10%)
- Real-world conditions → Medium augmentation (rotation ±30°, crop 75%, brightness ±20%)
- Extreme conditions (weather, blur) → Strong augmentation + robust architectures
- Rule: Augment for expected shift, not beyond
4. **"What's your domain?"**
- Vision → Rich augmentation toolkit available
- NLP → Limited augmentations (preserve syntax/semantics)
- Audio → Time/frequency domain transforms
- Tabular → SMOTE, feature dropout, noise injection
- Rule: Domain determines augmentation types
5. **"Do you have compute budget for increased training time?"**
- Yes → Stronger augmentation possible
- No → Lighter augmentation to save training time
- Rule: Online augmentation adds ~10-20% training time
### Decision Tree
```
START: Should I augment?
├─ Is your training data < 10,000 examples?
│ ├─ YES → Augmentation will likely help. Go to Part 2 (domain selection).
│ │
│ └─ NO → Check train/validation gap...
├─ Is your train-validation accuracy gap > 10%?
│ ├─ YES → Augmentation will likely help. Go to Part 2.
│ │
│ └─ NO → Continue...
├─ Are you in a domain where distribution shift is expected?
│ │ (medical imaging varies by scanner, autonomous driving weather varies,
│ │ satellite imagery has seasonal changes, etc.)
│ ├─ YES → Augmentation will help. Go to Part 2.
│ │
│ └─ NO → Continue...
├─ Do you have compute budget for 10-20% extra training time?
│ ├─ YES, but data is ample → Optional: light augmentation helps margins
│ │ May improve generalization even with large data.
│ │
│ └─ NO → Skip augmentation or use very light augmentation.
└─ DEFAULT: Apply light-to-medium augmentation for target domain.
Start with conservative parameters.
Measure impact before increasing strength.
```
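If you want this framework as a quick programmatic check, the sketch below encodes the rules above as a simple heuristic. It is a minimal illustration (the function name and thresholds merely mirror the clarifying questions), not a substitute for measuring impact.
```python
def recommend_augmentation_strength(n_train: int, train_acc: float,
                                    val_acc: float, expects_shift: bool) -> str:
    """Rough heuristic encoding the decision framework above."""
    gap = train_acc - val_acc
    if train_acc < 0.65 and gap < 0.05:
        return "none"    # underfitting: add capacity, augmentation won't help
    if n_train < 1_000:
        return "strong"  # tiny dataset: aggressive augmentation
    if n_train < 10_000 or gap > 0.10:
        return "medium"  # small dataset or clear overfitting
    if expects_shift:
        return "medium"  # augment toward the expected shift, not beyond
    return "light"       # ample data: light augmentation still helps margins

# Example: 5,000 images, train 90% / val 72% -> "medium"
print(recommend_augmentation_strength(5_000, 0.90, 0.72, expects_shift=False))
```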
## Part 2: Domain-Specific Augmentation Catalogs
### Vision Augmentations (Image Classification, Detection, Segmentation)
**Key Principle**: Preserve semantic content while varying appearance and geometry.
#### Geometric Transforms (Preserve Class)
**Rotation**:
```python
from torchvision import transforms
transform = transforms.RandomRotation(degrees=15)
# ±15° for most tasks (natural objects rotate ±15°)
# ±30° for synthetic/manufactured objects
# ±45° for symmetric objects (digits, logos)
# Avoid: ±180° (completely unrecognizable)
```
**When to use**: All vision tasks. Rotation-invariance is common.
**Strength tuning**:
- Light: ±5° to ±15° (most conservative)
- Medium: ±15° to ±30°
- Strong: ±30° to ±45° (only for symmetric classes)
- Never: ±180° (makes label ambiguous)
**Domain exceptions**:
- Medical imaging: ±10° maximum (anatomy is not rotation-invariant)
- Satellite: ±5° maximum (geographic north is meaningful)
- Handwriting: ±15° okay (natural variation)
- OCR: ±10° maximum (upside-down is different class)
**Crop (Random Crop + Resize)**:
```python
transform = transforms.RandomResizedCrop(224, scale=(0.8, 1.0))
# Crops 80-100% of original, resizes to 224x224
# Teaches invariance to framing and zoom
```
**When to use**: Classification, detection (with care), segmentation.
**Strength tuning**:
- Light: scale=(0.9, 1.0) - crop 90-100%
- Medium: scale=(0.8, 1.0) - crop 80-100%
- Strong: scale=(0.5, 1.0) - crop 50-100% (can lose important features)
**Domain considerations**:
- Detection: Minimum scale should keep objects ≥50px
- Segmentation: Crops must preserve mask validity
- Medical: Center-biased crops (avoid cutting off pathology)
**Horizontal Flip**:
```python
transform = transforms.RandomHorizontalFlip(p=0.5)
# Mirrors image left-right
```
**When to use**: Most vision tasks WHERE LEFT-RIGHT SYMMETRY IS NATURAL.
**CRITICAL EXCEPTION**:
- ❌ Medical imaging (L/R markers mean something)
- ❌ Text/documents (flipped text is unreadable)
- ❌ Objects with semantic left/right (cars facing direction)
- ❌ Faces (though some datasets use it)
**Safe domains**:
- ✅ Natural scene classification
- ✅ Animal classification (except directional animals)
- ✅ Generic object detection (not vehicles)
**Vertical Flip** (Use Rarely):
```python
transform = transforms.RandomVerticalFlip(p=0.5)
```
**VERY LIMITED USE**: Most natural objects are not up-down symmetric.
- ❌ Most natural images (horizon has direction)
- ❌ Medical imaging (anatomical direction matters)
- ✅ Texture classification (some textures rotationally symmetric)
**Perspective Transform (Affine)**:
```python
transform = transforms.RandomAffine(
    degrees=0,
    translate=(0.1, 0.1),  # ±10% translation
    scale=(0.9, 1.1),      # ±10% scaling
    shear=(-15, 15)        # ±15° shear
)
```
**When to use**: Scene understanding, 3D object detection, autonomous driving.
**Caution**: Shear and extreme perspective can make images unrecognizable. Use conservatively.
#### Color and Brightness Transforms (Appearance Variance)
**Color Jitter**:
```python
transform = transforms.ColorJitter(
    brightness=0.2,  # ±20% brightness
    contrast=0.2,    # ±20% contrast
    saturation=0.2,  # ±20% saturation
    hue=0.1          # ±10% hue shift
)
```
**When to use**: All vision tasks (teaches color-invariance).
**Strength tuning**:
- Light: brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05
- Medium: brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
- Strong: brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3
**Domain exceptions**:
- Medical imaging: brightness/contrast only (color is artificial)
- Satellite: All channels safe (handles weather/season)
- Thermal imaging: Only brightness meaningful
**Gaussian Blur**:
```python
from torchvision import transforms
transform = transforms.GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 2.0))
```
**When to use**: Makes model robust to soft focus, mimics unfocused camera.
**Strength tuning**:
- Light: sigma=(0.1, 0.5)
- Medium: sigma=(0.1, 1.0)
- Strong: sigma=(0.5, 2.0)
**Domain consideration**: Don't blur medical/satellite (loses diagnostic/geographic detail).
**Grayscale**:
```python
transform = transforms.RandomGrayscale(p=0.2)  # 20% probability
```
**When to use**: When color information is redundant or unreliable.
**Domain exceptions**:
- Medical imaging: Apply selectively (preserve when color is diagnostic)
- Satellite: Don't apply (multi-spectral bands are essential)
- Natural scene: Safe to apply
#### Mixing Augmentations (Mixup, Cutmix, Cutout)
**Mixup**: Linear interpolation of images and labels
```python
import numpy as np
import torch

def mixup(x, y, alpha=1.0):
    """Mixup augmentation: blend two images and labels."""
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    lam = np.random.beta(alpha, alpha)  # Sample mixing ratio
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Use with soft labels during training:
# loss = lam * loss_fn(pred, y_a) + (1-lam) * loss_fn(pred, y_b)
```
**When to use**: All image classification tasks.
**Strength tuning**:
- Light: alpha=0.2 (lam concentrates near 0 or 1, so outputs stay close to one original)
- Medium: alpha=1.0 (lam uniform on [0, 1])
- Strong: alpha=2.0 (lam concentrates near 0.5, heavy blending)
**Effectiveness**: One of the best modern augmentations, ~1-2% accuracy improvement typical.
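For reference, here is a minimal sketch of a training step that consumes the `mixup()` helper above; `model`, `optimizer`, and the use of cross-entropy are placeholders for your own setup.
```python
import torch.nn.functional as F

def train_step(model, optimizer, x, y, alpha=1.0):
    """One training step with mixup (hypothetical harness)."""
    mixed_x, y_a, y_b, lam = mixup(x, y, alpha)
    pred = model(mixed_x)
    # Soft-label loss: weight each original label by its mixing coefficient
    loss = lam * F.cross_entropy(pred, y_a) + (1 - lam) * F.cross_entropy(pred, y_b)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```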
**Cutmix**: Replace rectangular region with another image
```python
import numpy as np
import torch

def cutmix(x, y, alpha=1.0):
    """CutMix augmentation: replace rectangular patch."""
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    lam = np.random.beta(alpha, alpha)
    height, width = x.size(2), x.size(3)
    # Sample patch coordinates
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h = int(height * cut_ratio)
    cut_w = int(width * cut_ratio)
    cx = np.random.randint(0, width)
    cy = np.random.randint(0, height)
    bbx1 = np.clip(cx - cut_w // 2, 0, width)
    bby1 = np.clip(cy - cut_h // 2, 0, height)
    bbx2 = np.clip(cx + cut_w // 2, 0, width)
    bby2 = np.clip(cy + cut_h // 2, 0, height)
    # Paste the shuffled batch's patch into the original images
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    # Adjust lambda to the exact patch area
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1)) / (height * width)
    return x, y, y[index], lam
```
**When to use**: Image classification (especially effective).
**Advantage over Mixup**: Preserves spatial structure better, more realistic.
**Typical improvement**: 1-3% accuracy increase.
**Cutout**: Remove rectangular patch (fill with zero/mean)
```python
import numpy as np

def cutout(x, patch_size=32, p=0.5):
    """Cutout: remove rectangular region."""
    if np.random.rand() > p:
        return x
    batch_size, _, height, width = x.size()
    for i in range(batch_size):
        cx = np.random.randint(0, width)
        cy = np.random.randint(0, height)
        x1 = np.clip(cx - patch_size // 2, 0, width)
        y1 = np.clip(cy - patch_size // 2, 0, height)
        x2 = np.clip(cx + patch_size // 2, 0, width)
        y2 = np.clip(cy + patch_size // 2, 0, height)
        x[i, :, y1:y2, x1:x2] = 0
    return x
```
**When to use**: Regularization effect, teaches local invariance.
**Typical improvement**: 0.5-1% accuracy increase.
#### AutoAugment and Learned Policies
**RandAugment**: Random selection from augmentation space
```python
from torchvision.transforms import RandAugment
transform = RandAugment(num_ops=2, magnitude=9)
# Applies 2 random augmentations drawn from a 14-operation space
# Magnitude 0-30 controls strength
```
**When to use**: When unsure about augmentation selection.
**Advantage**: Removes manual hyperparameter tuning.
**Typical improvement**: 1-2% accuracy compared to manual selection.
**AutoAugment**: Data-dependent learned policy
```python
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
transform = AutoAugment(AutoAugmentPolicy.IMAGENET)
# Predefined policy for ImageNet-like tasks
# Policies: IMAGENET, CIFAR10, SVHN
```
**Pre-trained policies**:
- IMAGENET: General-purpose, vision tasks
- CIFAR10: Smaller images (32x32), high regularization
- SVHN: Street view house numbers
**Typical improvement**: 0.5-1% accuracy.
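These policies drop into an ordinary torchvision pipeline before tensor conversion. A minimal sketch, assuming ImageNet-style inputs and standard ImageNet normalization constants:
```python
from torchvision import transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    AutoAugment(AutoAugmentPolicy.IMAGENET),  # learned policy, applied to PIL images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```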
### NLP Augmentations (Text Classification, QA, Generation)
**Key Principle**: Preserve meaning while varying surface form. Syntax and semantics must be preserved.
#### Rule-Based Augmentations
**Back-Translation**:
```python
from transformers import MarianMTModel, MarianTokenizer

def back_translate(text: str, src_lang='en', inter_lang='fr') -> str:
    """Translate to an intermediate language and back to create a paraphrase."""
    # English -> French -> English
    # Example: "The cat sat on mat" -> "Le chat s'assit sur le tapis" -> "The cat sat on the mat"
    # Translate en->fr
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{inter_lang}"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model.generate(**inputs)
    intermediate = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Translate fr->en
    model_name_back = f"Helsinki-NLP/opus-mt-{inter_lang}-{src_lang}"
    tokenizer_back = MarianTokenizer.from_pretrained(model_name_back)
    model_back = MarianMTModel.from_pretrained(model_name_back)
    inputs_back = tokenizer_back(intermediate, return_tensors="pt")
    outputs_back = model_back.generate(**inputs_back)
    result = tokenizer_back.batch_decode(outputs_back, skip_special_tokens=True)[0]
    return result
```
**When to use**: Text classification, sentiment analysis, intent detection.
**Strength tuning**:
- Use 1-2 intermediate languages
- Probability 0.3-0.5 (paraphrases, not all data)
**Advantage**: Creates natural paraphrases.
**Disadvantage**: Slow (requires neural translation model).
**Synonym Replacement (EDA)**:
```python
import random

import nltk
from nltk.corpus import wordnet
# Requires a one-time download: nltk.download('wordnet')

def synonym_replacement(text: str, n=2):
    """Replace n random words with synonyms."""
    words = text.split()
    new_words = words.copy()
    random_word_list = list(set([word for word in words if word.isalnum()]))
    random.shuffle(random_word_list)
    num_replaced = 0
    for random_word in random_word_list:
        synonyms = get_synonyms(random_word)
        if len(synonyms) > 0:
            synonym = random.choice(synonyms)
            new_words = [synonym if word == random_word else word for word in new_words]
            num_replaced += 1
        if num_replaced >= n:
            break
    return ' '.join(new_words)

def get_synonyms(word):
    """Find synonyms using WordNet."""
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonyms.add(lemma.name())
    return list(synonyms - {word})
```
**When to use**: Text classification, low-resource languages.
**Strength tuning**:
- n=1-3 synonyms per sentence
- Probability 0.5 (replace in half of training data)
**Typical improvement**: 1-2% for small datasets.
**Random Insertion**:
```python
def random_insertion(text: str, n=2):
    """Insert n random synonyms of random words."""
    words = text.split()
    new_words = words.copy()
    for _ in range(n):
        add_word(new_words)
    return ' '.join(new_words)

def add_word(new_words):
    synonyms = []
    counter = 0
    while len(synonyms) < 1:
        if counter >= 10:
            return
        random_word = new_words[random.randint(0, len(new_words)-1)]
        synonyms = get_synonyms(random_word)
        counter += 1
    random_synonym = synonyms[random.randint(0, len(synonyms)-1)]
    random_idx = random.randint(0, len(new_words)-1)
    new_words.insert(random_idx, random_synonym)
```
**When to use**: Text classification, paraphrase detection.
**Random Swap**:
```python
def random_swap(text: str, n=2):
    """Randomly swap positions of n word pairs."""
    words = text.split()
    new_words = words.copy()
    for _ in range(n):
        new_words = swap_word(new_words)
    return ' '.join(new_words)

def swap_word(new_words):
    random_idx_1 = random.randint(0, len(new_words)-1)
    random_idx_2 = random_idx_1
    counter = 0
    while random_idx_2 == random_idx_1:
        random_idx_2 = random.randint(0, len(new_words)-1)
        counter += 1
        if counter > 3:
            return new_words
    new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
    return new_words
```
**When to use**: Robustness to word order variations.
**Random Deletion**:
```python
def random_deletion(text: str, p=0.2):
    """Randomly delete words with probability p."""
    if len(text.split()) == 1:
        return text
    words = text.split()
    new_words = [word for word in words if random.uniform(0, 1) > p]
    if len(new_words) == 0:
        return random.choice(words)
    return ' '.join(new_words)
```
**When to use**: Robustness to missing/incomplete input.
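In the EDA recipe these four operations are typically applied together, with one chosen at random per augmented copy. A minimal combined sketch using the helpers above (the `eda_augment` wrapper name is illustrative):
```python
import random

def eda_augment(text: str) -> str:
    """Apply one randomly chosen EDA operation to a sentence (hypothetical wrapper)."""
    operation = random.choice([
        lambda t: synonym_replacement(t, n=2),
        lambda t: random_insertion(t, n=2),
        lambda t: random_swap(t, n=2),
        lambda t: random_deletion(t, p=0.1),
    ])
    return operation(text)

# Usage: augment ~50% of training sentences, keep the rest unchanged
sentences = ["the movie was surprisingly good"]
augmented = [eda_augment(s) if random.random() < 0.5 else s for s in sentences]
```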
#### Sentence-Level Augmentations
**Paraphrase Generation**:
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def paraphrase(text: str):
    """Generate paraphrase using pretrained model."""
    model_name = "Vamsi/T5_Paraphrase_Paws"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    input_ids = tokenizer.encode(text, return_tensors="pt")
    outputs = model.generate(input_ids)
    paraphrase = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return paraphrase
```
**When to use**: Text classification with limited data.
**Advantage**: High-quality semantic paraphrases.
**Disadvantage**: Model-dependent, can be slow.
### Audio Augmentations (Speech Recognition, Music)
**Key Principle**: Preserve content while varying acoustic conditions.
**Pitch Shift**:
```python
import librosa
import numpy as np
def pitch_shift(waveform: np.ndarray, sr: int, steps: int):
    """Shift pitch without changing speed."""
    # Shift by ±2-4 semitones typical
    return librosa.effects.pitch_shift(waveform, sr=sr, n_steps=steps)

# Usage:
audio, sr = librosa.load('audio.wav')
augmented = pitch_shift(audio, sr, steps=np.random.randint(-4, 5))
```
**When to use**: Speech recognition (speaker variation).
**Strength tuning**:
- Light: ±2 semitones
- Medium: ±4 semitones
- Strong: ±8 semitones (avoid, changes phone identity)
**Time Stretching**:
```python
def time_stretch(waveform: np.ndarray, rate: float):
    """Speed up/slow down without changing pitch."""
    return librosa.effects.time_stretch(waveform, rate=rate)

# Usage:
augmented = time_stretch(audio, rate=np.random.uniform(0.9, 1.1))  # ±10% speed
```
**When to use**: Speech recognition (speech rate variation).
**Strength tuning**:
- Light: 0.95-1.05 (±5% speed)
- Medium: 0.9-1.1 (±10% speed)
- Strong: 0.8-1.2 (±20% speed, too aggressive)
**Background Noise Injection**:
```python
def add_background_noise(waveform: np.ndarray, noise: np.ndarray, snr_db: float):
    """Add noise at specified SNR (signal-to-noise ratio)."""
    signal_power = np.mean(waveform ** 2)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise_scaled = noise * np.sqrt(noise_power / np.mean(noise ** 2))
    # Mix only first len(waveform) samples of noise
    augmented = waveform + noise_scaled[:len(waveform)]
    return np.clip(augmented, -1, 1)  # Prevent clipping

# Usage:
noise, _ = librosa.load('background_noise.wav', sr=sr)
augmented = add_background_noise(audio, noise, snr_db=np.random.uniform(15, 30))
```
**When to use**: Speech recognition, robustness to noisy environments.
**Strength tuning**:
- Light: SNR 30-40 dB (minimal noise)
- Medium: SNR 20-30 dB (moderate noise)
- Strong: SNR 10-20 dB (very noisy, challenging)
**SpecAugment**: Augmentation in spectrogram space
```python
def spec_augment(mel_spec: np.ndarray, freq_mask_width: int, time_mask_width: int):
    """Apply frequency and time masking to mel-spectrogram."""
    freq_axis_size = mel_spec.shape[0]
    time_axis_size = mel_spec.shape[1]
    # Frequency masking
    f0 = np.random.randint(0, freq_axis_size - freq_mask_width)
    mel_spec[f0:f0+freq_mask_width, :] = 0
    # Time masking
    t0 = np.random.randint(0, time_axis_size - time_mask_width)
    mel_spec[:, t0:t0+time_mask_width] = 0
    return mel_spec

# Usage:
mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr)
augmented = spec_augment(mel_spec, freq_mask_width=30, time_mask_width=40)
```
**When to use**: Speech recognition (standard for ASR).
### Tabular Augmentations (Regression, Classification on Structured Data)
**Key Principle**: Preserve relationships between features while adding noise/variation.
**SMOTE (Synthetic Minority Over-sampling)**:
```python
from imblearn.over_sampling import SMOTE
# Balance imbalanced classification
X_train = your_features # shape: (n_samples, n_features)
y_train = your_labels
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
# Now X_resampled has balanced classes with synthetic minority examples
```
**When to use**: Imbalanced classification (rare class oversampling).
**Advantage**: Addresses class imbalance by creating synthetic examples.
**Feature-wise Noise Injection**:
```python
def add_noise_to_features(X: np.ndarray, noise_std: float):
    """Add Gaussian noise scaled to a fraction of each feature's std."""
    feature_stds = np.std(X, axis=0)
    # Sample standard-normal noise, then scale per feature
    noise = np.random.normal(0, 1, X.shape)
    return X + noise * feature_stds * noise_std
```
**When to use**: Robustness to measurement noise.
**Strength tuning**:
- Light: noise_std=0.01 (1% of feature std)
- Medium: noise_std=0.05 (5% of feature std)
- Strong: noise_std=0.1 (10% of feature std)
**Feature Dropout**:
```python
def feature_dropout(X: np.ndarray, p: float):
    """Randomly set features to zero."""
    mask = np.random.binomial(1, 1-p, X.shape)
    return X * mask
```
**When to use**: Robustness to missing/unavailable features.
**Strength tuning**:
- p=0.1 (drop 10% of features)
- p=0.2 (drop 20%)
- Avoid p>0.3 (too much information loss)
**Mixup for Tabular Data**:
```python
def mixup_tabular(X: np.ndarray, y: np.ndarray, alpha: float = 1.0):
    """Apply mixup to tabular features."""
    batch_size = X.shape[0]
    index = np.random.permutation(batch_size)
    lam = np.random.beta(alpha, alpha)
    X_mixed = lam * X + (1 - lam) * X[index]
    y_a, y_b = y, y[index]
    return X_mixed, y_a, y_b, lam
```
**When to use**: Regression and classification on tabular data.
## Part 3: Augmentation Strength Tuning
### Conservative vs Aggressive Augmentation
**Principle**: Start conservative, increase gradually. Test impact.
#### Weak Augmentation (100% probability)
Apply light augmentation to ALL training data, EVERY epoch.
```python
weak_augmentation = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
])
```
**Typical improvement**: +1-2% accuracy.
**Pros**:
- Applied consistently to every sample (no extra probability knob to tune)
- Easier to reproduce
- Less prone to catastrophic augmentation
**Cons**:
- Every image is transformed every epoch (no unaugmented examples)
- Less diversity per image
#### Strong Augmentation (Lower Probability)
Apply strong augmentations with 30-50% probability.
```python
import numpy as np

strong_augmentation = transforms.Compose([
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), shear=(-15, 15)),
    transforms.RandomPerspective(distortion_scale=0.3),
])

class StrongAugmentationWrapper:
    def __init__(self, transform, p=0.3):
        self.transform = transform
        self.p = p

    def __call__(self, x):
        if np.random.rand() < self.p:
            return self.transform(x)
        return x

aug_wrapper = StrongAugmentationWrapper(strong_augmentation, p=0.3)
```
**Typical improvement**: +2-3% accuracy.
**Pros**:
- More diversity
- Better robustness to extreme conditions
**Cons**:
- Risk of too-aggressive augmentation
- Requires careful strength tuning
### Finding Optimal Strength
**Algorithm**:
1. Start with weak augmentation (parameters at ~50% of the expected range)
2. Train to completion, measure final validation accuracy (this is your baseline)
3. Increase strength by ~25% and retrain
4. Compare final validation accuracies
5. If accuracy improved, increase further; if it hurt, decrease
6. Stop when accuracy plateaus or decreases
**Example**:
```python
# Start: rotation ±10°, brightness ±0.1
# After test 1: accuracy improves, try rotation ±15°, brightness ±0.15
# After test 2: accuracy improves, try rotation ±20°, brightness ±0.2
# After test 3: accuracy decreases, revert to rotation ±15°, brightness ±0.15
```
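A minimal sweep harness for this procedure, assuming hypothetical `train()` and `evaluate()` helpers like the ones used in Part 6:
```python
from torchvision import transforms

# Hypothetical sweep over augmentation strengths; train() and evaluate()
# are placeholders for your own training and evaluation loops.
strengths = [
    {"rotation": 10, "brightness": 0.10},
    {"rotation": 15, "brightness": 0.15},
    {"rotation": 20, "brightness": 0.20},
]
best_acc, best_cfg = 0.0, None
for cfg in strengths:
    aug = transforms.Compose([
        transforms.RandomRotation(degrees=cfg["rotation"]),
        transforms.ColorJitter(brightness=cfg["brightness"]),
    ])
    model = train(aug, epochs=10)
    acc = evaluate(model, val_loader)
    if acc > best_acc:
        best_acc, best_cfg = acc, cfg
    else:
        break  # accuracy stopped improving: keep the previous strength
print(f"Best config: {best_cfg} ({best_acc:.3f})")
```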
## Part 4: Test-Time Augmentation (TTA)
**Definition**: Apply augmentation at inference time, average predictions.
```python
import torch
from torchvision import transforms

# Light TTA transform (an assumption; flips and small rotations are typical)
tta_augmentation = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
])

def predict_with_tta(model, image, num_augmentations=8):
    """Make predictions with test-time augmentation."""
    predictions = []
    for _ in range(num_augmentations):
        # Apply light augmentation
        augmented = tta_augmentation(image)
        with torch.no_grad():
            pred = model(augmented.unsqueeze(0))
        predictions.append(pred.softmax(dim=1))
    # Average predictions
    final_pred = torch.stack(predictions).mean(dim=0)
    return final_pred
```
**When to use**:
- Final evaluation (test set submission)
- Robustness testing
- Post-training calibration
**Don't use for**:
- Validation (metrics must reflect single-pass performance)
- Production inference (too slow, accuracy not worth inference latency)
**Typical improvement**: +0.5-1% accuracy.
**Computational cost**: 8-10x slower inference.
## Part 5: Common Pitfalls and Rationalization
### Pitfall 1: Augmenting Validation/Test Data
**Symptom**: Validation accuracy inflated, test performance poor.
**User Says**: "More diversity helps, so augment everywhere"
**Why It Fails**: Validation measures true performance on ORIGINAL data, not augmented.
**Fix**:
```python
# WRONG:
val_transform = transforms.Compose([
    transforms.RandomRotation(20),
    transforms.ToTensor(),
])

# RIGHT:
val_transform = transforms.Compose([
    transforms.ToTensor(),
])
```
### Pitfall 2: Over-Augmentation (Unrecognizable Images)
**Symptom**: Training loss doesn't decrease, accuracy worse with augmentation.
**User Says**: "More augmentation = more robustness"
**Why It Fails**: If image unrecognizable, model cannot learn the class.
**Fix**: Start conservative. Test incrementally.
### Pitfall 3: Wrong Domain Augmentations
**Symptom**: Accuracy drops with augmentation.
**User Says**: "These augmentations work for images, why not text?"
**Why It Fails**: Flipped text is unreadable. Domain-specific invariances differ.
**Fix**: Use augmentations designed for your domain.
### Pitfall 4: Augmentation Inconsistency Across Train/Val
**Symptom**: Model overfits, ignores augmentation benefit.
**User Says**: "I normalize images, so different augmentation pipelines okay"
**Why It Fails**: Train augmentation must be intentional; val must not have it.
**Fix**: Explicitly separate training and validation transforms.
### Pitfall 5: Ignoring Label Semantics
**Symptom**: Model predicts wrong class after augmentation.
**User Says**: "The label is preserved, so any transformation okay"
**Why It Fails**: Extreme transformations obscure discriminative features.
**Example**: Medical image rotated 180° may have artifacts that change diagnosis.
**Fix**: Consider label semantics, not just label preservation.
### Pitfall 6: No Augmentation on Small Dataset
**Symptom**: Severe overfitting, poor generalization.
**User Says**: "My data is unique, standard augmentations won't help"
**Why It Fails**: Overfitting still happens, augmentation reduces it.
**Fix**: Use domain-appropriate augmentations even on small datasets.
### Pitfall 7: Augmentation Not Reproducible
**Symptom**: Different training runs give different results.
**User Says**: "Random augmentation is fine, natural variation"
**Why It Fails**: Makes debugging impossible, non-reproducible research.
**Fix**: Set random seeds for reproducible augmentation.
```python
import random
import numpy as np
import torch
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
```
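Seeding the main process is not enough when augmentation runs inside `DataLoader` workers, since each worker process has its own RNG state. A minimal sketch of the standard PyTorch worker-seeding recipe (`train_dataset` is assumed to exist):
```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # Derive each worker's NumPy/random state from PyTorch's per-worker seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          num_workers=4, worker_init_fn=seed_worker, generator=g)
```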
### Pitfall 8: Using One Augmentation Policy for All Tasks
**Symptom**: Augmentation works for classification, hurts for detection.
**User Says**: "Augmentation is general, works everywhere"
**Why It Fails**: Detection needs different augmentations (preserve boxes).
**Fix**: Domain AND task-specific augmentation selection.
### Pitfall 9: Augmentation Overhead Too High
**Symptom**: Training 2x slower, minimal accuracy improvement.
**User Says**: "Augmentation is worth the overhead"
**Why It Fails**: Sometimes it is, sometimes not. Measure impact.
**Fix**: Profile training time. Balance overhead vs accuracy gain.
### Pitfall 10: Mixing Incompatible Augmentations
**Symptom**: Unexpected behavior, degraded performance.
**User Says**: "Combining augmentations = better diversity"
**Why It Fails**: Some augmentations conflict or overlap.
**Example**: CutMix + random crop can create strange patches.
**Fix**: Design augmentation pipelines carefully, test combinations.
## Part 6: Augmentation Policy Design
### Step-by-Step Augmentation Design
**Step 1: Identify invariances in your domain**
What transformations preserve the class label?
- Vision: Rotation ±15° (natural), flip (depends), color jitter (yes)
- Text: Synonym replacement (yes), flip sentence (no)
- Audio: Pitch shift ±4 semitones (yes), time stretch ±20% (yes)
- Tabular: Feature noise (yes), feature permutation (no)
**Step 2: Select weak augmentations**
Choose conservative parameters.
```python
weak_aug = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1),
])
```
**Step 3: Measure impact**
Train with/without augmentation, compare validation accuracy.
```python
# Without augmentation
model_no_aug = train(no_aug_transforms, epochs=10)
val_acc_no_aug = evaluate(model_no_aug, val_loader)
# With weak augmentation
model_weak_aug = train(weak_aug, epochs=10)
val_acc_weak_aug = evaluate(model_weak_aug, val_loader)
print(f"Without augmentation: {val_acc_no_aug}")
print(f"With weak augmentation: {val_acc_weak_aug}")
```
**Step 4: Increase gradually if beneficial**
If augmentation helped, increase strength 25%.
```python
medium_aug = transforms.Compose([
    transforms.RandomRotation(degrees=20),    # ±20° vs ±15°
    transforms.ColorJitter(brightness=0.15),  # 0.15 vs 0.1
])
model_medium = train(medium_aug, epochs=10)
val_acc_medium = evaluate(model_medium, val_loader)
```
**Step 5: Stop when improvement plateaus**
When accuracy no longer improves, use previous best parameters.
### Augmentation for Different Dataset Sizes
**< 1,000 examples**: Heavy augmentation needed
```python
heavy_aug = transforms.Compose([
    transforms.RandomRotation(degrees=30),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomAffine(degrees=0, shear=15),
    transforms.RandomHorizontalFlip(p=0.5),
])
```
**1,000-10,000 examples**: Medium augmentation
```python
medium_aug = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomHorizontalFlip(p=0.5),
])
```
**10,000-100,000 examples**: Light augmentation
```python
light_aug = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1),
    transforms.RandomHorizontalFlip(p=0.3),
])
```
**> 100,000 examples**: Minimal augmentation (optional)
```python
minimal_aug = transforms.Compose([
    transforms.ColorJitter(brightness=0.05),
])
```
## Part 7: Augmentation Composition Strategies
### Sequential vs Compound Augmentation
**Sequential** (Apply transforms in sequence, each has independent probability):
```python
# Sequential: each transform independent
sequential = transforms.Compose([
    transforms.RandomRotation(degrees=15),   # 100% probability
    transforms.ColorJitter(brightness=0.2),  # 100% probability
    transforms.RandomHorizontalFlip(p=0.5),  # 50% probability
])
# Result: Always rotate and color jitter, sometimes flip
# Most common approach
```
**Compound** (Random selection of augmentation combinations):
```python
# Compound: choose one from alternatives
def compound_augmentation(image):
    choice = np.random.choice(['light', 'medium', 'heavy'])
    if choice == 'light':
        return light_aug(image)
    elif choice == 'medium':
        return medium_aug(image)
    else:
        return heavy_aug(image)
```
**When to use compound**:
- When augmentations conflict
- When you want balanced diversity
- When computational resources are limited
### Augmentation Order Matters
Some augmentations should be applied in specific order:
**Optimal order**:
1. Geometric transforms first (rotation, shear, perspective)
2. Cropping (RandomResizedCrop)
3. Flipping (horizontal, vertical)
4. Color/intensity transforms (brightness, contrast, hue)
5. Final normalization
```python
optimal_order = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, shear=10),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```
**Why**: Geometric first (operate on pixel coordinates), then color (invariant to coordinate changes).
### Probability-Based Augmentation Control
**Weak augmentation** (apply to all data):
```python
# Weak: always apply
weak = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
])

# Apply to every training image
for epoch in range(epochs):
    for images, labels in train_loader:
        images = weak(images)
        # ... train
```
**Strong augmentation with probability**:
```python
import numpy as np

class ProbabilisticAugmentation:
    def __init__(self, transform, p: float):
        self.transform = transform
        self.p = p

    def __call__(self, x):
        if np.random.rand() < self.p:
            return self.transform(x)
        return x

# Use strong augmentation with 30% probability
strong = transforms.Compose([
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.4),
])
probabilistic = ProbabilisticAugmentation(strong, p=0.3)
# Each image: 70% unaugmented (training signal), 30% strongly augmented
```
## Part 8: Augmentation for Specific Tasks
### Augmentation for Object Detection
**Challenge**: Must preserve bounding boxes after augmentation.
**Strategy**: Use augmentations that preserve geometry or can remap boxes.
```python
from albumentations import (
    BboxParams, ColorJitter, Compose, HorizontalFlip, Rotate
)

# Albumentations handles box remapping automatically
detection_augmentation = Compose([
    HorizontalFlip(p=0.5),
    Rotate(limit=15, p=0.5),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.5),
], bbox_params=BboxParams(format='pascal_voc', label_fields=['labels']))

# Usage:
image, boxes, labels = detection_sample
augmented = detection_augmentation(
    image=image,
    bboxes=boxes,
    labels=labels
)
```
**Safe augmentations**:
- ✅ Horizontal flip (adjust box x-coordinates)
- ✅ Crop (clip boxes to cropped region)
- ✅ Rotate ±15° (remaps box corners)
- ✅ Color jitter (no box changes)
**Avoid**:
- ❌ Vertical flip (semantic meaning changes for many objects)
- ❌ Perspective distortion (complex box remapping)
- ❌ Large rotation (hard to remap boxes)
### Augmentation for Semantic Segmentation
**Challenge**: Masks must be transformed identically to images.
**Strategy**: Apply same transform to both image and mask.
```python
from albumentations import (
    ColorJitter, Compose, HorizontalFlip, RandomCrop, Rotate
)

# Albumentations transforms the mask together with the image automatically
segmentation_augmentation = Compose([
    HorizontalFlip(p=0.5),
    Rotate(limit=15, p=0.5),
    RandomCrop(height=256, width=256),
    ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
])

# Usage:
image, mask = segmentation_sample
augmented = segmentation_augmentation(image=image, mask=mask)
image_aug, mask_aug = augmented['image'], augmented['mask']
```
**Key requirement**: Image and mask transformed identically.
### Augmentation for Fine-Grained Classification
**Challenges**: Small objects, subtle differences between classes.
**Strategy**: Use conservative geometric transforms, aggressive color/texture.
```python
# Fine-grained: preserve structure, vary appearance
fine_grained = transforms.Compose([
    transforms.RandomRotation(degrees=5),                  # Conservative rotation
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),   # Minimal crop
    transforms.ColorJitter(brightness=0.3, contrast=0.3),  # Aggressive color
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
])
```
**Avoid**:
- Large crops (lose discriminative details)
- Extreme rotations (change object orientation)
- Perspective distortion (distorts fine structures)
### Augmentation for Medical Imaging
**Critical requirements**: Domain-specific, label-preserving, anatomically valid.
```python
# Medical imaging augmentation (conservative)
medical_aug = transforms.Compose([
    transforms.RandomRotation(degrees=10),  # Max ±10°
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    # Avoid: vertical flip (anatomical direction), excessive crop
])
# Never apply:
# - Vertical flip (anatomy has direction)
# - Random crops cutting off pathology
# - Extreme color transforms (diagnostic colors matter)
# - Perspective distortion (can distort anatomy)
```
**Domain-specific augmentations for medical**:
- ✅ Elastic deformation (models anatomical variation; see the sketch after this list)
- ✅ Rotation ±10° (patient positioning variation)
- ✅ Small brightness/contrast (scanner variation)
- ✅ Gaussian blur (image quality variation)
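Elastic deformation is the one item above without a torchvision one-liner in older releases (recent versions ship `transforms.ElasticTransform`). A minimal NumPy/SciPy sketch of the classic smoothed-displacement-field approach for single-channel 2D images; the `alpha` and `sigma` defaults here are illustrative starting points, not validated clinical settings:
```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

def elastic_deform(image: np.ndarray, alpha: float = 34.0, sigma: float = 4.0):
    """Elastic deformation via a Gaussian-smoothed random displacement field."""
    h, w = image.shape[:2]
    # Random displacement fields, smoothed so nearby pixels move together
    dx = gaussian_filter(np.random.rand(h, w) * 2 - 1, sigma) * alpha
    dy = gaussian_filter(np.random.rand(h, w) * 2 - 1, sigma) * alpha
    y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = np.array([y + dy, x + dx])
    # Resample the image at the displaced coordinates
    return map_coordinates(image, coords, order=1, mode='reflect')
```
For multi-channel images, apply the same displacement field to each channel so the deformation stays consistent.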
### Augmentation for Time Series / Sequences
**For 1D sequences** (signal processing, ECG, EEG):
```python
import numpy as np
from scipy.interpolate import interp1d

def jitter(x: np.ndarray, std: float = 0.01):
    """Add small random noise to sequence."""
    return x + np.random.normal(0, std, x.shape)

def scaling(x: np.ndarray, scale: float = 0.1):
    """Scale magnitude of sequence."""
    return x * np.random.uniform(1 - scale, 1 + scale)

def rotation(x: np.ndarray):
    """Rotate in 2D space (for bivariate sequences of shape (time, 2))."""
    theta = np.random.uniform(-np.pi/4, np.pi/4)
    rotation_matrix = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])
    return x @ rotation_matrix.T

def magnitude_warping(x: np.ndarray, sigma: float = 0.2):
    """Apply smooth scaling variations along the time axis."""
    knots = np.linspace(0, len(x), 5)
    values = np.random.normal(1, sigma, len(knots))
    smooth_scale = interp1d(knots, values, kind='cubic')(np.arange(len(x)))
    return x * smooth_scale[:, np.newaxis]

def window_slicing(x: np.ndarray, window_ratio: float = 0.1):
    """Slice a random window, then stretch it back to the original length."""
    window_size = int(len(x) * window_ratio)
    start = np.random.randint(0, len(x) - window_size)
    x_sliced = x[start:start + window_size]
    # Interpolate back to original length
    f = interp1d(np.arange(len(x_sliced)), x_sliced, axis=0, kind='linear',
                 fill_value='extrapolate')
    return f(np.linspace(0, len(x_sliced)-1, len(x)))
```
## Part 9: Augmentation Red Flags and Troubleshooting
### Red Flags: When Augmentation Is Hurting
1. **Validation accuracy DECREASES with augmentation**
- Likely: Too aggressive augmentation
- Solution: Reduce augmentation strength by 50%, retrain
2. **Training loss doesn't decrease**
- Likely: Images too distorted to learn
- Solution: Visualize augmented images, check if recognizable (see the sketch after this list)
3. **Test accuracy much worse than validation**
- Likely: Validation data accidentally augmented
- Solution: Check transform pipelines, ensure validation/test unaugmented
4. **High variance in results across runs**
- Likely: Augmentation randomness not seeded
- Solution: Set random seeds for reproducibility
5. **Specific class performance drops with augmentation**
- Likely: Augmentation inappropriate for that class
- Solution: Design class-specific augmentation (or disable for that class)
6. **Memory usage doubled**
- Likely: Applying augmentation twice (in data loader and training)
- Solution: Remove duplicate augmentation pipeline
7. **Model never converges to baseline**
- Likely: Augmentation too strong, label semantics lost
- Solution: Use weak augmentation first, increase gradually
8. **Overfitting still severe despite augmentation**
- Likely: Augmentation too weak or wrong type
- Solution: Increase strength, try different augmentations, use regularization too
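Red flag 2 is best diagnosed visually. A minimal spot-check sketch, assuming a torchvision-style `train_dataset` whose transform includes the random augmentations and returns unnormalized tensors in [0, 1]:
```python
import matplotlib.pyplot as plt

# Spot-check: do augmented images still look like their class?
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for ax in axes.flat:
    # Same sample index, fresh random augmentation on each access
    image, label = train_dataset[0]
    ax.imshow(image.permute(1, 2, 0).clamp(0, 1))  # CHW -> HWC for plotting
    ax.set_title(f"label={label}")
    ax.axis('off')
plt.tight_layout()
plt.show()
```
If normalization is part of the pipeline, invert it before plotting; otherwise the images will look like noise even when the augmentation is fine.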
### Troubleshooting Checklist
Before concluding augmentation doesn't help:
- [ ] Validation transform pipeline has NO augmentations
- [ ] Training transform pipeline has only desired augmentations
- [ ] Random seed set for reproducibility
- [ ] Augmented images are visually recognizable (not noise)
- [ ] Augmentation applied consistently across epochs
- [ ] Baseline training tested (no augmentation) for comparison
- [ ] Accuracy impact measured on same hardware/compute
- [ ] Computational cost justified by accuracy improvement
## Part 10: Rationalization Table (What Users Say vs Reality)
| User Statement | Reality | Evidence | Fix |
|----------------|---------|----------|-----|
| "Augmentation is overhead, skip it" | Augmentation prevents overfitting on small data | +5-10% accuracy on <5K examples | Enable augmentation, measure impact |
| "Use augmentation on validation too" | Validation measures true performance on original data | Metrics misleading if augmented | Remove augmentation from val transforms |
| "More augmentation always better" | Extreme augmentation creates label noise | Accuracy drops with too-aggressive transforms | Start conservative, increase gradually |
| "Same augmentation for all domains" | Each domain has different invariances | Text upside-down ≠ same class | Use domain-specific augmentations |
| "Augmentation takes too long" | ~10-20% training overhead, usually worth it | Depends on accuracy gain vs compute cost | Profile: measure accuracy/time tradeoff |
| "Flip works for everything" | Vertical flip changes anatomy/semantics | Medical imaging, some objects not symmetric | Know when flip is appropriate |
| "Random augmentation same as fixed" | Randomness prevents memorization, fixed is repetitive | Stochastic variation teaches invariance | Use random, not fixed transforms |
| "My data is too unique for standard augmentations" | Even unique data benefits from domain-appropriate augmentation | Overfitting still happens with small unique datasets | Adapt augmentations to your domain |
| "Augmentation is regularization" | Augmentation and regularization different; both help together | Dropout+BatchNorm+Augmentation > any single one | Use augmentation AND regularization |
| "TTA means augment validation" | TTA is optional post-training, not validation practice | TTA averaged over multiple forward passes | Use TTA only at final inference |
## Summary: Quick Reference
| Domain | Light Augmentations | Medium Augmentations | Strong Augmentations |
|--------|-------------------|----------------------|----------------------|
| Vision | ±10° rotation, ±10% brightness, 0.5 H-flip | ±20° rotation, ±20% brightness, CutMix | ±45° rotation, ±30% jitter, strong perspective |
| NLP | Synonym replacement (1 word) | Back-translation, EDA | Multiple paraphrases, sentence reordering |
| Audio | Pitch ±2 semitones, noise SNR 30dB | Pitch ±4, noise SNR 20dB | Pitch ±8, noise SNR 10dB |
| Tabular | Feature noise 1%, SMOTE | Feature noise 5%, feature dropout | Feature noise 10%, heavy SMOTE |
## Critical Rules
1. **Augment training data ONLY**. Validation and test data must be unaugmented.
2. **Start conservative, increase gradually**. Measure impact at each step.
3. **Domain matters**. No universal augmentation strategy exists.
4. **Preserve labels**. Do not apply transformations that change the class.
5. **Test incrementally**. Add one augmentation at a time, measure impact.
6. **Reproducibility**. Set random seeds for ablation studies.
7. **Avoid extremes**. If images/text unrecognizable, augmentation too strong.
8. **Know your domain**. Understand what invariances matter for your task.
9. **Measure impact**. Profile training time and accuracy improvement.
10. **Combine with regularization**. Augmentation works best with dropout, batch norm, weight decay.