
LLM Fine-Tuning Strategies

Context

You're considering fine-tuning an LLM or debugging a fine-tuning process. Common mistakes:

  • Fine-tuning when prompts would work (unnecessary cost/time)
  • Full fine-tuning instead of LoRA (100× less efficient)
  • Poor dataset quality (garbage in, garbage out)
  • Wrong hyperparameters (catastrophic forgetting)
  • No validation strategy (overfitting undetected)

This skill provides effective fine-tuning strategies: when to fine-tune, efficient methods (LoRA), data quality, hyperparameters, and evaluation.

Decision Tree: Prompt Engineering vs Fine-Tuning

Start with prompt engineering. Fine-tuning is a last resort.

Step 1: Try Prompt Engineering

# System message + few-shot examples
system = """
You are a {role} with {characteristics}.
{guidelines}
"""

few_shot = [
    # 3-5 examples of desired behavior
]

# Test quality
quality = evaluate(system, few_shot, test_set)

If quality ≥ 90%: STOP. Use prompts (no fine-tuning needed)

If quality < 90%: Continue to Step 2
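
The evaluate call above stands for your own harness, not a library function. A minimal sketch, assuming the OpenAI Python client, a test set of dicts with 'input'/'expected' keys, and exact-match scoring (the model name and keys are illustrative; swap in whatever defines "quality" for your task):

from openai import OpenAI

client = OpenAI()

def evaluate(system, few_shot, test_set, model="gpt-4o-mini"):
    correct = 0
    for case in test_set:
        messages = [{"role": "system", "content": system}]
        messages += few_shot  # assumed: alternating user/assistant example messages
        messages.append({"role": "user", "content": case["input"]})

        response = client.chat.completions.create(
            model=model, messages=messages, temperature=0
        )
        answer = response.choices[0].message.content.strip()

        # Exact match is the simplest scorer; use task-specific checks in practice
        if answer == case["expected"]:
            correct += 1

    return correct / len(test_set)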

Step 2: Optimize Prompts

  • Add more examples (5-10)
  • Add chain-of-thought
  • Specify output format more clearly
  • Try different system messages
  • Use temperature=0 for consistency
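
As a concrete illustration of these optimizations (the product and task here are invented for the example), an improved prompt might combine a tighter role, chain-of-thought, and an explicit output format:

system = """
You are a senior support engineer for AcmeDB.

Guidelines:
- Answer only from the provided documentation excerpt; if it isn't covered, say "I don't know".
- Think through the problem step by step before answering, but output only the final answer.
- Respond as JSON: {"answer": "...", "confidence": "high" | "medium" | "low"}
"""

# Pair this with 5-10 few-shot examples that cover common cases AND known
# failure modes, then re-run the evaluation with temperature=0.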

If quality ≥ 90%: STOP. Use optimized prompts

If quality < 90%: Continue to Step 3

Step 3: Consider Fine-Tuning

Fine-tune when:

  • Prompts fail (quality < 90% after optimization)
  • Have 1000+ examples (minimum for meaningful fine-tuning)
  • Need consistency (can't rely on prompt variations)
  • Reduce latency (shorter prompts → faster inference)
  • Teach new capability (not in base model)

Don't fine-tune for:

  • Tone/style matching (use system message)
  • Output formatting (use format specification in prompt)
  • Few examples (< 100 examples insufficient)
  • Quick experiments (prompts iterate faster)
  • Recent information (use RAG, not fine-tuning)

When to Fine-Tune: Detailed Criteria

Criterion 1: Task Complexity

Simple tasks (prompt engineering):

  • Classification (sentiment, category)
  • Extraction (entities, dates, names)
  • Formatting (JSON, CSV conversion)
  • Tone matching (company voice)

Complex tasks (consider fine-tuning):

  • Multi-step reasoning (not in base model)
  • Domain-specific language (medical, legal)
  • Consistent complex behavior (100+ edge cases)
  • New capabilities (teach entirely new skill)

Criterion 2: Dataset Size

< 100 examples: Prompts only (insufficient for fine-tuning)
100-1000: Prompts preferred (fine-tuning risky - overfitting)
1000-10k: Fine-tuning viable if prompts fail
> 10k: Fine-tuning effective

Criterion 3: Cost-Benefit

Prompt engineering:

  • Cost: $0 (just dev time)
  • Time: Minutes to hours (fast iteration)
  • Maintenance: Easy (just update prompt)

Fine-tuning:

  • Cost: $100-1000+ (compute + data prep)
  • Time: Days to weeks (data prep + training + eval)
  • Maintenance: Hard (need retraining for updates)

ROI calculation:

# Prompt engineering cost
prompt_dev_hours = 4
hourly_rate = 100
prompt_cost = prompt_dev_hours * hourly_rate  # 4 * 100 = $400

# Fine-tuning cost
data_prep_hours = 40
training_cost = 500
total_ft_cost = data_prep_hours * hourly_rate + training_cost  # 40 * 100 + 500 = $4,500

# Cost ratio: Fine-tuning is 11× more expensive
# Only worth it if quality improvement > 10%

Criterion 4: Performance Requirements

Quality:

  • Need 90-95%: Prompts usually sufficient
  • Need 95-98%: Fine-tuning may help
  • Need 98%+: Fine-tuning + careful data curation

Latency:

  • > 1 second acceptable: Prompts fine (long prompts OK)
  • 200-1000 ms: Fine-tuning may help (reduce prompt size)
  • < 200 ms: Fine-tuning + optimization required

Consistency:

  • Variable outputs acceptable: Prompts OK (temperature > 0)
  • High consistency needed: Prompts (temperature=0) or fine-tuning
  • Perfect consistency: Fine-tuning + validation

Fine-Tuning Methods

1. Full Fine-Tuning

Updates all model parameters.

Pros:

  • Maximum flexibility (can change any behavior)
  • Best quality (when you have massive data)

Cons:

  • Expensive (7B model = 28GB memory for weights alone)
  • Slow (hours to days)
  • Risk of catastrophic forgetting
  • Hard to merge multiple fine-tunes

When to use:

  • Massive dataset (100k+ examples)
  • Fundamental behavior change needed
  • Have large compute resources (multi-GPU)

Memory requirements:

# 7B parameter model (FP32)
weights = 7B * 4 bytes = 28 GB
gradients = 28 GB
optimizer_states = 56 GB (Adam: 2× weights)
activations = ~8 GB (batch_size=8)
total = 120 GB  # Need multi-GPU!

2. LoRA (Low-Rank Adaptation)

Freezes base model, trains small adapter matrices.

How it works:

Original linear layer: W (d × k)
LoRA: W + (A × B)
  where A (d × r), B (r × k), r << d,k

Example:
W: 4096 × 4096 = 16.7M parameters
A: 4096 × 8 = 32K parameters
B: 8 × 4096 = 32K parameters
A + B = 64K parameters (0.4% of original!)

Pros:

  • Extremely efficient (1% of parameters)
  • Fast training (10× faster than full FT)
  • Low memory (fits single GPU)
  • Easy to merge multiple LoRAs
  • No catastrophic forgetting (base model frozen)

Cons:

  • Slightly lower capacity than full FT (99% quality usually)
  • Need to keep base model + adapters

When to use:

  • 99% of fine-tuning cases
  • Limited compute (single GPU)
  • Fast iteration needed
  • Multiple tasks (train separate LoRAs, swap as needed)

Configuration:

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,  # Rank (4-16 typical, higher = more capacity)
    lora_alpha=32,  # Scaling factor (commonly 2-4× the rank)
    target_modules=["q_proj", "v_proj"],  # Which layers
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()
# trainable params: 8.4M || all params: 7B || trainable%: 0.12%
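
Because the base model stays frozen, adapters are small files you can save, swap per task, or merge into the weights for deployment. A minimal sketch using PEFT (the paths are illustrative):

# Save just the adapter (a few MB, not the full model)
model.save_pretrained("adapters/support-bot")

# Later: load the base model once, then attach whichever adapter the task needs
from peft import PeftModel
model = PeftModel.from_pretrained(base_model, "adapters/support-bot")

# For deployment, fold the adapter into the base weights (no PEFT needed at inference)
merged = model.merge_and_unload()
merged.save_pretrained("models/support-bot-merged")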

Rank selection:

r=4: Minimal (fast, low capacity) - simple tasks
r=8: Standard (balanced) - most tasks
r=16: High capacity (slower, better quality) - complex tasks
r=32+: Approaching full FT quality (diminishing returns)

Start with r=8, increase only if quality insufficient

3. QLoRA (Quantized LoRA)

LoRA + 4-bit quantization of base model.

Pros:

  • Extremely memory efficient (4× less than LoRA)
  • 7B model fits on 16GB GPU
  • Same quality as LoRA

Cons:

  • Slower than LoRA (quantization overhead)
  • More complex setup

When to use:

  • Limited GPU memory (< 24GB)
  • Large models on consumer GPUs
  • Cost optimization (cheaper GPUs)

Setup:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# Then add LoRA as usual
model = get_peft_model(model, lora_config)

Memory comparison:

Method         | 7B Model | 13B Model | 70B Model
---------------|----------|-----------|----------
Full FT        | 120 GB   | 200 GB    | 1000 GB
LoRA           | 40 GB    | 60 GB     | 300 GB
QLoRA          | 12 GB    | 20 GB     | 80 GB

Method Selection:

if gpu_memory < 24:
    use_qlora()
elif gpu_memory < 80:
    use_lora()
elif have_massive_data and multi_gpu_cluster:
    use_full_finetuning()
else:
    use_lora()  # Default choice

Dataset Preparation

Quality > Quantity. 1,000 clean examples > 10,000 noisy examples.

1. Data Collection

Good sources:

  • Human-labeled data (gold standard)
  • Curated conversations (high-quality)
  • Expert-written examples
  • Validated user interactions

Bad sources:

  • Raw logs (errors, incomplete, noise)
  • Scraped data (quality varies wildly)
  • Automated generation (may have artifacts)
  • Untested user inputs (edge cases, adversarial)

2. Data Cleaning

def clean_dataset(raw_data):
    clean = []

    for example in raw_data:
        # Filter 1: Remove examples whose output contains obvious error text
        if any(err in example['output'].lower() for err in ['error', 'exception', 'failed']):
            continue

        # Filter 2: Length checks
        if len(example['input']) < 10 or len(example['output']) < 10:
            continue  # Too short
        if len(example['input']) > 2000 or len(example['output']) > 2000:
            continue  # Too long (may be malformed)

        # Filter 3: Completeness
        if not example['output'].strip().endswith(('.', '!', '?')):
            continue  # Incomplete response

        # Filter 4: Language check
        if not is_valid_language(example['output']):
            continue  # Gibberish or wrong language

        # Filter 5: Duplicates
        if is_duplicate(example, clean):
            continue

        clean.append(example)

    return clean

cleaned = clean_dataset(raw_data)
print(f"Filtered: {len(raw_data)} → {len(cleaned)}")
# Example: 10,000 → 3,000 (but high quality!)
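
The is_valid_language and is_duplicate helpers above are yours to define. A minimal sketch, assuming the langdetect package for the language check and normalized exact match for deduplication (for large datasets, track seen inputs in a set instead of the linear scan):

from langdetect import detect

def is_valid_language(text, expected="en"):
    try:
        return detect(text) == expected
    except Exception:
        return False  # Empty strings or gibberish make detection fail

def is_duplicate(example, accepted):
    key = example['input'].strip().lower()
    return any(key == ex['input'].strip().lower() for ex in accepted)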

3. Manual Validation

Critical step: Spot check 100+ random examples.

import random

sample = random.sample(cleaned, min(100, len(cleaned)))

for i, ex in enumerate(sample):
    print(f"\n--- Example {i+1}/{len(sample)} ---")
    print(f"Input: {ex['input']}")
    print(f"Output: {ex['output']}")

    response = input("Quality (good/bad/skip)? ")
    if response == 'bad':
        # Investigate pattern, add filtering rule
        print("Why bad?")
        reason = input()
        # Update filtering logic

What to check:

  • ☐ Output is correct and complete
  • ☐ Output matches desired format/style
  • ☐ No errors or hallucinations
  • ☐ Appropriate length
  • ☐ Natural language (not robotic)
  • ☐ Consistent with other examples

4. Dataset Format

OpenAI format (for GPT fine-tuning):

{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
  ]
}
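
To turn cleaned input/output pairs into an upload-ready training file (OpenAI fine-tuning expects JSONL: one such object per line), a short sketch (the system message and the 'input'/'output' field names are assumptions about your data):

import json

with open("train.jsonl", "w") as f:
    for ex in cleaned:
        record = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": ex["input"]},
                {"role": "assistant", "content": ex["output"]},
            ]
        }
        f.write(json.dumps(record) + "\n")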

Hugging Face format:

from datasets import Dataset

data = {
    'input': ["question 1", "question 2", ...],
    'output': ["answer 1", "answer 2", ...]
}

dataset = Dataset.from_dict(data)

5. Train/Val/Test Split

from sklearn.model_selection import train_test_split

# 70% train, 15% val, 15% test
train, temp = train_test_split(data, test_size=0.3, random_state=42)
val, test = train_test_split(temp, test_size=0.5, random_state=42)

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")
# Example: Train: 2100, Val: 450, Test: 450

# Stratified split for imbalanced data
train, temp = train_test_split(
    data, test_size=0.3, stratify=data['label'], random_state=42
)

Split guidelines:

  • Minimum validation: 100 examples
  • Minimum test: 100 examples
  • Large datasets (> 10k): 80/10/10 split
  • Small datasets (< 5k): 70/15/15 split

6. Data Augmentation (Optional)

When you need more data:

# Paraphrasing
"What's the weather?" → "How's the weather today?"

# Back-translation
English → French → English (introduces variation)

# Synthetic generation (use carefully!)
few_shot_examples = [...]
new_examples = llm.generate(
    f"Generate 10 examples similar to: {few_shot_examples}"
)
# ALWAYS manually validate synthetic data!

Warning: Synthetic data can introduce artifacts. Always validate!
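
A sketch of the back-translation idea using two public translation models from the Hugging Face Hub (the Helsinki-NLP OPUS-MT checkpoints; any EN↔FR pair works, and the example output is only indicative):

from transformers import pipeline

en_fr = pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr")
fr_en = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-en")

def back_translate(text):
    french = en_fr(text)[0]["translation_text"]
    return fr_en(french)[0]["translation_text"]

print(back_translate("What's the weather like today?"))
# e.g. "What is the weather today?" (a natural paraphrase of the original)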

Hyperparameters

Learning Rate

Most critical hyperparameter.

# Pre-training LR: 1e-3 to 3e-4
# Fine-tuning LR: 100-1000× smaller!

training_args = TrainingArguments(
    learning_rate=1e-5,  # Start here for 7B models
    # Or more conservative: 1e-6 for larger models or small datasets
)

Guidelines:

Model size     | Pre-train LR | Fine-tune LR
---------------|--------------|-------------
1B params      | 3e-4         | 3e-5 to 1e-5
7B params      | 3e-4         | 1e-5 to 1e-6
13B params     | 2e-4         | 5e-6 to 1e-6
70B+ params    | 1e-4         | 1e-6 to 1e-7

Rule: Fine-tune LR ≈ Pre-train LR / 100

LR scheduling:

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,  # Gradual LR increase (10% of training)
    num_training_steps=total_steps
)

Signs of wrong LR:

Too high (LR > 1e-4):

  • Training loss oscillates wildly
  • Model generates gibberish
  • Catastrophic forgetting (fails on general tasks)

Too low (LR < 1e-7):

  • Training loss barely decreases
  • Model doesn't adapt to new data
  • Very slow convergence

Epochs

training_args = TrainingArguments(
    num_train_epochs=3,  # Standard: 3-5 epochs
)

Guidelines:

Dataset size | Epochs
-------------|-------
< 1k         | 5-10 (more passes needed)
1k-5k        | 3-5 (standard)
5k-10k       | 2-3
> 10k        | 1-2 (large dataset, fewer passes)

Rule: Smaller dataset → more epochs (but watch for overfitting!)

Too many epochs:

  • Training loss → 0 but val loss increases (overfitting)
  • Model memorizes training data
  • Catastrophic forgetting

Too few epochs:

  • Model hasn't fully adapted
  • Training and val loss still decreasing

Batch Size

training_args = TrainingArguments(
    per_device_train_batch_size=8,  # Depends on GPU memory
    gradient_accumulation_steps=4,   # Effective batch = 8 × 4 = 32
)

Guidelines:

GPU Memory | Batch Size (7B model)
-----------|----------------------
16 GB      | 1-2 (use gradient accumulation!)
24 GB      | 2-4
40 GB      | 4-8
80 GB      | 8-16

Effective batch size (with accumulation): 16-64 typical

Gradient accumulation:

# Simulate batch_size=32 with only 8 examples fitting in memory:
per_device_train_batch_size=8
gradient_accumulation_steps=4
# Effective batch = 8 × 4 = 32

Weight Decay

training_args = TrainingArguments(
    weight_decay=0.01,  # L2 regularization (prevent overfitting)
)

Guidelines:

  • Standard: 0.01
  • Strong regularization: 0.1 (small dataset, high overfitting risk)
  • Light regularization: 0.001 (large dataset)

Warmup

training_args = TrainingArguments(
    warmup_steps=100,  # Or warmup_ratio=0.1 (10% of training)
)

Why warmup:

  • Prevents initial instability (large gradients early)
  • Gradual LR increase: 0 → target_LR over warmup steps

Guidelines:

  • Warmup: 5-10% of total training steps
  • Longer warmup for larger models
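
To turn the 5-10% guideline into a number, compute total optimizer steps from your dataset size and effective batch size (the values below are illustrative):

train_examples = 3000
effective_batch = 8 * 4        # per_device_batch_size × gradient_accumulation_steps
epochs = 3

steps_per_epoch = train_examples // effective_batch   # 93
total_steps = steps_per_epoch * epochs                 # 279
warmup_steps = int(0.1 * total_steps)                  # 27

print(warmup_steps)  # pass this to TrainingArguments(warmup_steps=...)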

Training

Basic Training Loop

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",

    # Hyperparameters
    learning_rate=1e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    warmup_steps=100,

    # Evaluation
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # Logging
    logging_steps=10,
    logging_dir="./logs",

    # Optimization
    fp16=True,  # Mixed precision (faster, less memory)
    gradient_checkpointing=True,  # Trade compute for memory
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()

Monitoring Training

Key metrics to watch:

# 1. Training loss (should decrease steadily)
# 2. Validation loss (should decrease, then plateau)
# 3. Validation metrics (accuracy, F1, BLEU, etc.)

# Warning signs:
# - Train loss → 0 but val loss increasing: Overfitting
# - Train loss oscillating: LR too high
# - Train loss not decreasing: LR too low or data issues

Logging:

import wandb

wandb.init(project="fine-tuning")

training_args = TrainingArguments(
    report_to="wandb",  # Log to Weights & Biases
    logging_steps=10,
)

Early Stopping

from transformers import EarlyStoppingCallback

trainer = Trainer(
    ...
    callbacks=[EarlyStoppingCallback(
        early_stopping_patience=3,  # Stop if no improvement for 3 evals
        early_stopping_threshold=0.01,  # Minimum improvement
    )]
)

Why early stopping:

  • Prevents overfitting (stops before val loss increases)
  • Saves compute (don't train unnecessary epochs)
  • Automatically finds optimal epoch count

Evaluation

1. Validation During Training

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Causal LM predictions are logits; take the argmax to get token IDs
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # Replace -100 (masked positions) with the pad token before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode predictions and references
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Exact-match metrics over the decoded strings
    accuracy = accuracy_score(decoded_labels, decoded_preds)
    f1 = f1_score(decoded_labels, decoded_preds, average='weighted')

    return {'accuracy': accuracy, 'f1': f1}

trainer = Trainer(
    ...
    compute_metrics=compute_metrics,
)

2. Test Set Evaluation (Final)

# After training completes, evaluate on the held-out test set ONCE
test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

print(f"Test accuracy: {test_results['test_accuracy']:.2%}")
print(f"Test F1: {test_results['test_f1']:.2%}")

3. Qualitative Evaluation

Critical: Manually test on real examples!

def test_model(model, tokenizer, test_examples):
    for ex in test_examples:
        prompt = ex['input']
        expected = ex['output']

        # Generate, then decode only the new tokens (skip the echoed prompt)
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=100)
        generated = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

        print(f"Input: {prompt}")
        print(f"Expected: {expected}")
        print(f"Generated: {generated}")
        print(f"Match: {'✓' if generated == expected else '✗'}")
        print("-" * 80)

# Test on 20-50 examples (including edge cases)
test_model(model, tokenizer, test_examples)

4. A/B Testing (Production)

# Route 50% traffic to base model, 50% to fine-tuned
import random

def get_model():
    if random.random() < 0.5:
        return base_model
    else:
        return finetuned_model

# Measure:
# - User satisfaction (thumbs up/down)
# - Task success rate
# - Response time
# - Cost per request

# After 1000+ requests, analyze results

5. Catastrophic Forgetting Check

Critical: Ensure fine-tuning didn't break base capabilities!

# Test on general knowledge tasks
general_tasks = [
    "What is the capital of France?",  # Basic knowledge
    "Translate to Spanish: Hello",    # Translation
    "2 + 2 = ?",                       # Basic math
    "Who wrote Hamlet?",               # Literature
]

def generate_text(model, prompt):
    # Helper: tokenize, generate greedily, decode
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

for task in general_tasks:
    before = generate_text(base_model, task)
    after = generate_text(finetuned_model, task)

    print(f"Task: {task}")
    print(f"Before: {before}")
    print(f"After: {after}")
    # Exact match is a strict proxy; what matters is that the answer is still correct
    print(f"Preserved: {'✓' if before == after else '✗'}")

Common Issues and Solutions

Issue 1: Overfitting

Symptoms:

  • Train loss → 0, val loss increases
  • Perfect on training data, poor on test data

Solutions:

# 1. Reduce epochs
num_train_epochs=3  # Instead of 10

# 2. Increase regularization
weight_decay=0.1  # Instead of 0.01

# 3. Early stopping
early_stopping_patience=3

# 4. Collect more data
# 5. Data augmentation

# 6. Use LoRA (less prone to overfitting than full FT)

Issue 2: Catastrophic Forgetting

Symptoms:

  • Fine-tuned model fails on general tasks
  • Lost pre-trained knowledge

Solutions:

# 1. Lower learning rate (most important!)
learning_rate=1e-6  # Instead of 1e-4

# 2. Fewer epochs
num_train_epochs=2  # Instead of 10

# 3. Use LoRA (base model frozen, can't forget)

# 4. Add general examples to training set (10-20% general data)

Issue 3: Poor Quality

Symptoms:

  • Model output is low quality (incorrect, incoherent)

Solutions:

# 1. Check dataset quality (most common cause!)
# - Manual validation
# - Remove noise
# - Fix labels

# 2. Increase model size
# - 7B → 13B → 70B

# 3. Increase training data
# - Need 1000+ high-quality examples

# 4. Adjust hyperparameters
# - Try higher LR (1e-5 → 3e-5) if underfit
# - Train longer (3 → 5 epochs)

# 5. Check if base model has capability
# - If base model can't do task, fine-tuning won't help

Issue 4: Slow Training

Symptoms:

  • Training takes days/weeks

Solutions:

# 1. Use LoRA (10× faster than full FT)

# 2. Mixed precision
fp16=True  # 2× faster

# 3. Disable gradient checkpointing if you have memory headroom
#    (it saves memory but makes each step slower)
gradient_checkpointing=False

# 4. Smaller batch size + gradient accumulation
per_device_train_batch_size=2
gradient_accumulation_steps=16

# 5. Use multiple GPUs
# 6. Use faster GPU (A100 > V100 > T4)

Issue 5: Out of Memory

Symptoms:

  • CUDA out of memory error

Solutions:

# 1. Use QLoRA (4× less memory)

# 2. Reduce batch size
per_device_train_batch_size=1
gradient_accumulation_steps=32

# 3. Gradient checkpointing
gradient_checkpointing=True

# 4. Use smaller model
# 7B → 3B → 1B

# 5. Reduce sequence length
max_seq_length=512  # Instead of 2048

Best Practices Summary

Before Fine-Tuning:

  1. ☐ Try prompt engineering first (90% of cases, prompts work!)
  2. ☐ Have 1000+ high-quality examples
  3. ☐ Clean and validate dataset (quality > quantity)
  4. ☐ Create train/val/test split (70/15/15)
  5. ☐ Define success metrics (what does "good" mean?)

During Fine-Tuning:

  1. ☐ Use LoRA (unless specific reason for full FT)
  2. ☐ Set tiny learning rate (1e-5 to 1e-6 for 7B models)
  3. ☐ Train for 3-5 epochs (not 50!)
  4. ☐ Monitor val loss (stop when it stops improving)
  5. ☐ Log everything (wandb, tensorboard)

After Fine-Tuning:

  1. ☐ Evaluate on test set (quantitative metrics)
  2. ☐ Manual testing (qualitative, 20-50 examples)
  3. ☐ Check for catastrophic forgetting (general tasks)
  4. ☐ A/B test in production (before full rollout)
  5. ☐ Document hyperparameters (for reproducibility)

Quick Reference

Task                  | Method  | Dataset | LR   | Epochs
----------------------|---------|---------|------|-------
Tone matching         | Prompts | N/A     | N/A  | N/A
Simple classification | Prompts | N/A     | N/A  | N/A
Complex domain task   | LoRA    | 1k-10k  | 1e-5 | 3-5
Fundamental change    | Full FT | 100k+   | 1e-5 | 1-3
Limited GPU           | QLoRA   | 1k-10k  | 1e-5 | 3-5

Default recommendation: Try prompts first. If that fails, use LoRA with LR=1e-5, epochs=3, and high-quality dataset.

Summary

Core principles:

  1. Prompt engineering first: 90% of tasks don't need fine-tuning
  2. LoRA by default: ~100× more parameter-efficient than full fine-tuning, with near-identical quality
  3. Data quality matters: 1,000 clean examples > 10,000 noisy examples
  4. Tiny learning rate: Fine-tune LR = Pre-train LR / 100 to / 1000
  5. Validation essential: Train/val/test split + early stopping + catastrophic forgetting check

Decision tree:

  1. Try prompts (system message + few-shot)
  2. If quality < 90%, optimize prompts
  3. If still < 90% and have 1000+ examples, consider fine-tuning
  4. Use LoRA (default), QLoRA (limited GPU), or full FT (rare)
  5. Set LR = 1e-5, epochs = 3-5, monitor val loss
  6. Evaluate on test set + manual testing + general tasks

Key insight: Fine-tuning is powerful but expensive and slow. Start with prompts, fine-tune only when prompts demonstrably fail and you have high-quality data.