LLM Fine-Tuning Strategies
Context
You're considering fine-tuning an LLM or debugging a fine-tuning process. Common mistakes:
- Fine-tuning when prompts would work (unnecessary cost/time)
- Full fine-tuning when LoRA would do (trains roughly 100× more parameters)
- Poor dataset quality (garbage in, garbage out)
- Wrong hyperparameters (catastrophic forgetting)
- No validation strategy (overfitting undetected)
This skill provides effective fine-tuning strategies: when to fine-tune, efficient methods (LoRA), data quality, hyperparameters, and evaluation.
Decision Tree: Prompt Engineering vs Fine-Tuning
Start with prompt engineering. Fine-tuning is a last resort.
Step 1: Try Prompt Engineering
# System message + few-shot examples
system = """
You are a {role} with {characteristics}.
{guidelines}
"""
few_shot = [
# 3-5 examples of desired behavior
]
# Test quality
quality = evaluate(system, few_shot, test_set)
If quality ≥ 90%: ✅ STOP. Use prompts (no fine-tuning needed)
If quality < 90%: Continue to Step 2
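The `evaluate` call in Step 1 is not defined in this guide. A minimal sketch of what it could look like, assuming exact-match scoring and a hypothetical `generate_fn` callable that wraps whichever model API you use (the extra parameter, `build_prompt`, and `fake_model` are illustrative names, not part of any library):

```python
from typing import Callable, Dict, List


def build_prompt(system: str, few_shot: List[Dict[str, str]], query: str) -> str:
    """Join the system message, few-shot examples, and the new query into one prompt."""
    parts = [system.strip()]
    for ex in few_shot:
        parts.append(f"Input: {ex['input']}\nOutput: {ex['output']}")
    parts.append(f"Input: {query}\nOutput:")
    return "\n\n".join(parts)


def evaluate(system: str, few_shot: List[Dict[str, str]],
             test_set: List[Dict[str, str]],
             generate_fn: Callable[[str], str]) -> float:
    """Return the fraction of test examples answered exactly right (whitespace trimmed)."""
    if not test_set:
        raise ValueError("test_set is empty")
    correct = 0
    for ex in test_set:
        prediction = generate_fn(build_prompt(system, few_shot, ex["input"]))
        if prediction.strip() == ex["output"].strip():
            correct += 1
    return correct / len(test_set)


def fake_model(prompt: str) -> str:
    """Stand-in for a real model call: a keyword rule applied to the final query."""
    query = prompt.rsplit("Input: ", 1)[1].split("\n")[0]
    return "positive" if "love" in query else "negative"


system = "You are a sentiment classifier. Answer 'positive' or 'negative'."
few_shot = [{"input": "Great product", "output": "positive"}]
test_set = [
    {"input": "I love this phone", "output": "positive"},
    {"input": "The battery died in a day", "output": "negative"},
    {"input": "Would love a refund", "output": "negative"},
    {"input": "Screen cracked on arrival", "output": "negative"},
]
quality = evaluate(system, few_shot, test_set, fake_model)
print(f"Quality: {quality:.0%}")  # Quality: 75%
```

Exact match suits classification and extraction. For free-form outputs you would likely swap in a task-specific scorer.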
Step 2: Optimize Prompts
- Add more examples (5-10)
- Add chain-of-thought
- Specify output format more clearly
- Try different system messages
- Use temperature=0 for consistency
If quality ≥ 90%: ✅ STOP. Use optimized prompts
If quality < 90%: Continue to Step 3
Step 3: Consider Fine-Tuning
Fine-tune when:
✅ Prompts fail (quality < 90% after optimization)
✅ Have 1000+ examples (minimum for meaningful fine-tuning)
✅ Need consistency (can't rely on prompt variations)
✅ Reduce latency (shorter prompts → faster inference)
✅ Teach new capability (not in base model)
Don't fine-tune for:
❌ Tone/style matching (use system message)
❌ Output formatting (use format specification in prompt)
❌ Few examples (< 100 examples insufficient)
❌ Quick experiments (prompts iterate faster)
❌ Recent information (use RAG, not fine-tuning)
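The three steps can be condensed into a small helper. This is a sketch only: `next_step` and its return strings are made up for illustration, and the 90% and 1000-example thresholds are the ones stated above.

```python
def next_step(prompt_quality: float, prompts_optimized: bool, n_examples: int) -> str:
    """Map the decision tree to a recommendation.

    prompt_quality is the measured fraction of acceptable outputs (0.0 to 1.0).
    """
    if not 0.0 <= prompt_quality <= 1.0:
        raise ValueError("prompt_quality must be between 0 and 1")
    if prompt_quality >= 0.90:
        return "use prompts"            # Steps 1-2: good enough, stop here
    if not prompts_optimized:
        return "optimize prompts"       # Step 2 has not been tried yet
    if n_examples < 1000:
        return "collect more data or keep iterating on prompts"
    return "consider fine-tuning"       # Step 3


print(next_step(0.93, False, 0))      # use prompts
print(next_step(0.80, False, 5000))   # optimize prompts
print(next_step(0.85, True, 400))     # collect more data or keep iterating on prompts
print(next_step(0.85, True, 5000))    # consider fine-tuning
```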
When to Fine-Tune: Detailed Criteria
Criterion 1: Task Complexity
Simple tasks (prompt engineering):
- Classification (sentiment, category)
- Extraction (entities, dates, names)
- Formatting (JSON, CSV conversion)
- Tone matching (company voice)
Complex tasks (consider fine-tuning):
- Multi-step reasoning (not in base model)
- Domain-specific language (medical, legal)
- Consistent complex behavior (100+ edge cases)
- New capabilities (teach entirely new skill)
Criterion 2: Dataset Size
< 100 examples: Prompts only (insufficient for fine-tuning)
100-1000: Prompts preferred (fine-tuning risky - overfitting)
1000-10k: Fine-tuning viable if prompts fail
> 10k: Fine-tuning effective
Criterion 3: Cost-Benefit
Prompt engineering:
- Cost: $0 (just dev time)
- Time: Minutes to hours (fast iteration)
- Maintenance: Easy (just update prompt)
Fine-tuning:
- Cost: $100-1000+ (compute + data prep)
- Time: Days to weeks (data prep + training + eval)
- Maintenance: Hard (need retraining for updates)
ROI calculation:
# Prompt engineering cost
prompt_dev_hours = 4
hourly_rate = 100
prompt_cost = prompt_dev_hours * hourly_rate  # $400
# Fine-tuning cost
data_prep_hours = 40
training_cost = 500
total_ft_cost = data_prep_hours * hourly_rate + training_cost  # $4,500
# Cost ratio: 4,500 / 400 ≈ 11, so fine-tuning is about 11× more expensive
# Only worth it if quality improvement > 10%
Criterion 4: Performance Requirements
Quality:
- Need 90-95%: Prompts usually sufficient
- Need 95-98%: Fine-tuning may help
- Need 98%+: Fine-tuning + careful data curation
Latency:
- > 1 second acceptable: Prompts fine (long prompts OK)
- 200-1000ms: Fine-tuning may help (reduce prompt size)
- < 200ms: Fine-tuning + optimization required
Consistency:
- Variable outputs acceptable: Prompts OK (temperature > 0)
- High consistency needed: Prompts (temperature=0) or fine-tuning
- Perfect consistency: Fine-tuning + validation
Fine-Tuning Methods
1. Full Fine-Tuning
Updates all model parameters.
Pros:
- Maximum flexibility (can change any behavior)
- Best quality (when you have massive data)
Cons:
- Expensive (7B model = 28GB memory for weights alone)
- Slow (hours to days)
- Risk of catastrophic forgetting
- Hard to merge multiple fine-tunes
When to use:
- Massive dataset (100k+ examples)
- Fundamental behavior change needed
- Have large compute resources (multi-GPU)
Memory requirements:
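# 7B-parameter model in FP32 (4 bytes per parameter), sizes in GB
weights = 7 * 4                 # 28 GB
gradients = weights             # 28 GB
optimizer_states = 2 * weights  # 56 GB (Adam keeps two moment estimates per weight)
activations = 8                 # ~8 GB at batch_size=8
total = weights + gradients + optimizer_states + activations  # 120 GB: needs multiple GPUs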
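The same arithmetic generalizes to other model sizes and precisions. A rough estimator, assuming Adam-style optimizer state (two values per weight) and a fixed activation budget; the function name and defaults are illustrative, and real usage depends on sequence length, batch size, and framework overhead:

```python
def full_finetune_memory_gb(params_billion: float,
                            bytes_per_param: int = 4,
                            activations_gb: float = 8.0) -> dict:
    """Estimate full fine-tuning memory in GB (1 GB = 1e9 bytes)."""
    weights = params_billion * bytes_per_param   # 1e9 params × bytes, divided by 1e9
    gradients = weights                          # one gradient per weight
    optimizer_states = 2 * weights               # Adam: first and second moments
    total = weights + gradients + optimizer_states + activations_gb
    return {
        "weights": weights,
        "gradients": gradients,
        "optimizer_states": optimizer_states,
        "activations": activations_gb,
        "total": total,
    }


estimate = full_finetune_memory_gb(7)
print(estimate["total"])  # 120.0
```

Passing `bytes_per_param=2` gives a lower bound for pure half-precision training. Mixed-precision setups often keep FP32 master weights, so their footprint lands between the two.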
2. LoRA (Low-Rank Adaptation)
Freezes base model, trains small adapter matrices.
How it works:
Original linear layer: W (d × k)
LoRA: W + (A × B)
where A (d × r), B (r × k), r << d,k
Example:
W: 4096 × 4096 = 16.7M parameters
A: 4096 × 8 = 32K parameters
B: 8 × 4096 = 32K parameters
A + B = 64K parameters (0.4% of original!)
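The parameter arithmetic, plus a toy version of the update `W + (A × B)`, can be checked in plain Python. This is a sketch using nested lists rather than a tensor library, and it omits the `lora_alpha / r` scaling factor that real implementations apply to the adapter term:

```python
def lora_param_counts(d: int, k: int, r: int):
    """Return (full layer params, adapter params, adapter / full ratio)."""
    full = d * k
    adapter = d * r + r * k
    return full, adapter, adapter / full


def matmul(X, Y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def add(X, Y):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


full, adapter, ratio = lora_param_counts(4096, 4096, 8)
print(f"{full:,} vs {adapter:,} ({ratio:.2%})")  # 16,777,216 vs 65,536 (0.39%)

# Toy layer: d = k = 3, rank r = 1
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # frozen base weights
A = [[0.5], [1.0], [1.5]]                                 # d × r, trainable
B = [[1.0, 0.0, 0.0]]                                     # r × k, trainable
adapted = add(W, matmul(A, B))                            # same shape as W
print(adapted)  # [[1.5, 0.0, 0.0], [1.0, 1.0, 0.0], [1.5, 0.0, 1.0]]
```

Only `A` and `B` receive gradients. `W` never changes, which is why the adapter can be stored and swapped separately.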
Pros:
- Extremely efficient (1% of parameters)
- Fast training (10× faster than full FT)
- Low memory (fits single GPU)
- Easy to merge multiple LoRAs
- Less catastrophic forgetting (base weights stay frozen; removing the adapter restores the base model)
Cons:
- Slightly lower capacity than full FT (99% quality usually)
- Need to keep base model + adapters
When to use:
- 99% of fine-tuning cases
- Limited compute (single GPU)
- Fast iteration needed
- Multiple tasks (train separate LoRAs, swap as needed)
Configuration:
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=8,  # Rank (4-16 typical, higher = more capacity)
    lora_alpha=16,  # Scaling (usually 2× rank)
    target_modules=["q_proj", "v_proj"],  # Which layers
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # Prints the summary itself and returns None
# For a Llama-2-7B-style model (32 layers, hidden size 4096), roughly:
# trainable params: 4.2M || all params: 6.7B || trainable%: 0.06
Rank selection:
r=4: Minimal (fast, low capacity) - simple tasks
r=8: Standard (balanced) - most tasks
r=16: High capacity (slower, better quality) - complex tasks
r=32+: Approaching full FT quality (diminishing returns)
Start with r=8, increase only if quality insufficient
3. QLoRA (Quantized LoRA)
LoRA + 4-bit quantization of base model.
Pros:
- Extremely memory efficient (4× less than LoRA)
- 7B model fits on 16GB GPU
- Same quality as LoRA
Cons:
- Slower than LoRA (quantization overhead)
- More complex setup
When to use:
- Limited GPU memory (< 24GB)
- Large models on consumer GPUs
- Cost optimization (cheaper GPUs)
Setup:
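import torch
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Also quantize the quantization constants
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)
# Prepare the quantized model for training, then add LoRA as usual
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)  # lora_config: a LoraConfig like the one in the LoRA section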
Memory comparison:
Method | 7B Model | 13B Model | 70B Model
---------------|----------|-----------|----------
Full FT | 120 GB | 200 GB | 1000 GB
LoRA | 40 GB | 60 GB | 300 GB
QLoRA | 12 GB | 20 GB | 80 GB
Method Selection:
if gpu_memory < 24:
    use_qlora()
elif gpu_memory < 80:
    use_lora()
elif have_massive_data and multi_gpu_cluster:
    use_full_finetuning()
else:
    use_lora()  # Default choice
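The memory comparison table can also drive the choice directly. A sketch that encodes the table's approximate figures and lists which methods fit a given memory budget; `MEMORY_GB` and `methods_that_fit` are illustrative names, and `available_gb` is assumed to be the total across all GPUs you can shard over:

```python
# Approximate requirements in GB, copied from the memory comparison table
MEMORY_GB = {
    "Full FT": {"7B": 120, "13B": 200, "70B": 1000},
    "LoRA": {"7B": 40, "13B": 60, "70B": 300},
    "QLoRA": {"7B": 12, "13B": 20, "70B": 80},
}


def methods_that_fit(model_size: str, available_gb: float) -> list:
    """Return the methods whose estimated requirement fits in available_gb."""
    fits = []
    for method, needs in MEMORY_GB.items():
        if model_size not in needs:
            raise KeyError(f"No estimate for model size {model_size!r}")
        if needs[model_size] <= available_gb:
            fits.append(method)
    return fits


print(methods_that_fit("7B", 16))   # ['QLoRA']
print(methods_that_fit("7B", 48))   # ['LoRA', 'QLoRA']
print(methods_that_fit("70B", 24))  # []
```

An empty result means the model is too large for the hardware with any of the three methods, so a smaller base model is the next thing to try.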
Dataset Preparation
Quality > Quantity. 1,000 clean examples > 10,000 noisy examples.
1. Data Collection
Good sources:
- Human-labeled data (gold standard)
- Curated conversations (high-quality)
- Expert-written examples
- Validated user interactions
Bad sources:
- Raw logs (errors, incomplete, noise)
- Scraped data (quality varies wildly)
- Automated generation (may have artifacts)
- Untested user inputs (edge cases, adversarial)
2. Data Cleaning
def clean_dataset(raw_data):
    # is_valid_language() and is_duplicate() are project-specific helpers (not defined here)
    clean = []
    for example in raw_data:
        # Filter 1: Remove errors (search the text itself; `in` on a dict only checks its keys)
        text = f"{example['input']} {example['output']}".lower()
        if any(err in text for err in ['error', 'exception', 'failed']):
            continue
        # Filter 2: Length checks
        if len(example['input']) < 10 or len(example['output']) < 10:
            continue  # Too short
        if len(example['input']) > 2000 or len(example['output']) > 2000:
            continue  # Too long (may be malformed)
        # Filter 3: Completeness
        if not example['output'].strip().endswith(('.', '!', '?')):
            continue  # Incomplete response
        # Filter 4: Language check
        if not is_valid_language(example['output']):
            continue  # Gibberish or wrong language
        # Filter 5: Duplicates
        if is_duplicate(example, clean):
            continue
        clean.append(example)
    return clean
cleaned = clean_dataset(raw_data)
print(f"Filtered: {len(raw_data)} → {len(cleaned)}")
# Example: 10,000 → 3,000 (but high quality!)
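The duplicate filter is left undefined above, and comparing each example against the whole `clean` list grows quadratically. One possible alternative, sketched here under the assumption that two examples count as duplicates when their input and output match after lowercasing and whitespace normalization (all names below are illustrative):

```python
import hashlib
import re


def _normalize(text: str) -> str:
    """Lowercase and collapse runs of whitespace so trivial variants compare equal."""
    return re.sub(r"\s+", " ", text).strip().lower()


def example_key(example: dict) -> str:
    """Stable fingerprint of an example's normalized input and output."""
    joined = _normalize(example["input"]) + "\x1f" + _normalize(example["output"])
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()


def deduplicate(examples: list) -> list:
    """Keep the first occurrence of each fingerprint, preserving order."""
    seen = set()
    unique = []
    for ex in examples:
        key = example_key(ex)
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique


raw = [
    {"input": "Reset my password", "output": "Go to Settings and choose Reset."},
    {"input": "reset  my password ", "output": "Go to Settings and choose Reset."},
    {"input": "Cancel my order", "output": "Open Orders and select Cancel."},
]
print(len(deduplicate(raw)))  # 2
```

This only catches exact duplicates after normalization. Near-duplicates (paraphrases) need a similarity measure instead.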
3. Manual Validation
Critical step: Spot check 100+ random examples.
import random
sample = random.sample(cleaned, min(100, len(cleaned)))
for i, ex in enumerate(sample):
    print(f"\n--- Example {i+1}/{len(sample)} ---")
    print(f"Input: {ex['input']}")
    print(f"Output: {ex['output']}")
    response = input("Quality (good/bad/skip)? ")
    if response == 'bad':
        # Investigate pattern, add filtering rule
        print("Why bad?")
        reason = input()
        # Update filtering logic
What to check:
- ☐ Output is correct and complete
- ☐ Output matches desired format/style
- ☐ No errors or hallucinations
- ☐ Appropriate length
- ☐ Natural language (not robotic)
- ☐ Consistent with other examples
4. Dataset Format
OpenAI format (for GPT fine-tuning):
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris."}
]
}
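Fine-tuning files in this chat format are uploaded as JSONL, one JSON object per line. A small sketch that converts input/output pairs into that layout; the helper names and the default system prompt are illustrative:

```python
import json


def to_chat_record(example: dict, system_prompt: str = "You are a helpful assistant.") -> dict:
    """Wrap one input/output pair in the chat 'messages' structure."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": example["output"]},
        ]
    }


def to_jsonl(examples: list, system_prompt: str = "You are a helpful assistant.") -> str:
    """Serialize examples as JSONL text: one compact JSON object per line."""
    return "\n".join(
        json.dumps(to_chat_record(ex, system_prompt), ensure_ascii=False)
        for ex in examples
    )


pairs = [
    {"input": "What is the capital of France?", "output": "The capital of France is Paris."},
    {"input": "What is the capital of Japan?", "output": "The capital of Japan is Tokyo."},
]
jsonl_text = to_jsonl(pairs)
print(jsonl_text.count("\n") + 1)  # 2
```

Writing `jsonl_text` to a file with UTF-8 encoding gives a training file in this layout.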
Hugging Face format:
from datasets import Dataset
data = {
'input': ["question 1", "question 2", ...],
'output': ["answer 1", "answer 2", ...]
}
dataset = Dataset.from_dict(data)
5. Train/Val/Test Split
from sklearn.model_selection import train_test_split
# data: a list of examples or a pandas DataFrame (not the dict of lists shown above)
# 70% train, 15% val, 15% test
train, temp = train_test_split(data, test_size=0.3, random_state=42)
val, test = train_test_split(temp, test_size=0.5, random_state=42)
print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")
# Example: Train: 2100, Val: 450, Test: 450
# Stratified split for imbalanced data (assumes a DataFrame with a 'label' column)
train, temp = train_test_split(
    data, test_size=0.3, stratify=data['label'], random_state=42
)
Split guidelines:
- Minimum validation: 100 examples
- Minimum test: 100 examples
- Large datasets (> 10k): 80/10/10 split
- Small datasets (< 5k): 70/15/15 split
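The split guidelines can be applied without scikit-learn. A sketch using only the standard library; `pick_ratio` and `split_dataset` are illustrative names, and treating the unspecified 5k-10k range as 70/15/15 is an assumption:

```python
import random


def pick_ratio(n_examples: int):
    """Return (train %, val %); the remainder is the test share."""
    return (80, 10) if n_examples > 10_000 else (70, 15)


def split_dataset(examples, seed: int = 42):
    """Shuffle deterministically, then cut into train/val/test by the chosen ratio."""
    items = list(examples)
    random.Random(seed).shuffle(items)
    train_pct, val_pct = pick_ratio(len(items))
    n_train = len(items) * train_pct // 100
    n_val = len(items) * val_pct // 100
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    if len(val) < 100 or len(test) < 100:
        print("Warning: validation or test set is below the 100-example minimum")
    return train, val, test


train, val, test = split_dataset(range(3000))
print(len(train), len(val), len(test))  # 2100 450 450
```

Integer percentages avoid floating-point rounding when computing the cut points.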
6. Data Augmentation (Optional)
When you need more data:
# Paraphrasing
"What's the weather?" → "How's the weather today?"
# Back-translation
English → French → English (introduces variation)
# Synthetic generation (use carefully!)
few_shot_examples = [...]
new_examples = llm.generate(
f"Generate 10 examples similar to: {few_shot_examples}"
)
# ALWAYS manually validate synthetic data!
Warning: Synthetic data can introduce artifacts. Always validate!
Hyperparameters
Learning Rate
Most critical hyperparameter.
# Pre-training LR: 1e-3 to 3e-4
# Fine-tuning LR: typically 10-1000× smaller (see the table below)
training_args = TrainingArguments(
    learning_rate=1e-5,  # Start here for 7B models
    # Or even more conservative, for larger models or small datasets:
    # learning_rate=1e-6,
)
Guidelines:
Model size | Pre-train LR | Fine-tune LR
---------------|--------------|-------------
1B params | 3e-4 | 3e-5 to 1e-5
7B params | 3e-4 | 1e-5 to 1e-6
13B params | 2e-4 | 5e-6 to 1e-6
70B+ params | 1e-4 | 1e-6 to 1e-7
Rule of thumb: Fine-tune LR ≈ Pre-train LR / 100 (the table spans roughly / 10 to / 1000, shrinking further as models grow)
LR scheduling:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # Gradual LR increase over the first 10% of training
    num_training_steps=total_steps  # total_steps = optimizer steps per epoch × epochs
)
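A linear warmup schedule ramps the learning rate from 0 to its peak, then decays it linearly back to 0. The shape can be sketched by hand, which helps when sanity-checking logged LR values; `linear_warmup_decay_lr` is an illustrative name, not a library function:

```python
import math


def linear_warmup_decay_lr(step: int, peak_lr: float,
                           warmup_steps: int, total_steps: int) -> float:
    """Learning rate at a given optimizer step under linear warmup and linear decay."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)        # ramp up from 0
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)  # decay to 0


peak, warmup, total = 1e-5, 100, 1000
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_decay_lr(step, peak, warmup, total))
```

The rate reaches its peak exactly at the end of warmup (step 100 here) and is back at zero on the final step.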
Signs of wrong LR:
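Too high (LR > 1e-4 for full fine-tuning; LoRA adapters are often trained at around 1e-4, so judge them by the symptoms below):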
- Training loss oscillates wildly
- Model generates gibberish
- Catastrophic forgetting (fails on general tasks)
Too low (LR < 1e-7):
- Training loss barely decreases
- Model doesn't adapt to new data
- Very slow convergence
Epochs
training_args = TrainingArguments(
num_train_epochs=3, # Standard: 3-5 epochs
)
Guidelines:
Dataset size | Epochs
-------------|-------
< 1k | 5-10 (more passes needed)
1k-5k | 3-5 (standard)
5k-10k | 2-3
> 10k | 1-2 (large dataset, fewer passes)
Rule: Smaller dataset → more epochs (but watch for overfitting!)
Too many epochs:
- Training loss → 0 but val loss increases (overfitting)
- Model memorizes training data
- Catastrophic forgetting
Too few epochs:
- Model hasn't fully adapted
- Training and val loss still decreasing
Batch Size
training_args = TrainingArguments(
per_device_train_batch_size=8, # Depends on GPU memory
gradient_accumulation_steps=4, # Effective batch = 8 × 4 = 32
)
Guidelines:
GPU Memory | Batch Size (7B model)
-----------|----------------------
16 GB | 1-2 (use gradient accumulation!)
24 GB | 2-4
40 GB | 4-8
80 GB | 8-16
Effective batch size (with accumulation): 16-64 typical
Gradient accumulation:
# Simulate batch_size=32 with only 8 examples fitting in memory:
per_device_train_batch_size=8
gradient_accumulation_steps=4
# Effective batch = 8 × 4 = 32
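Batch size, accumulation, and epochs together determine how many optimizer steps a run takes, which in turn sets the warmup length. A back-of-the-envelope sketch; `training_plan` is an illustrative helper, and the Trainer's exact step count may differ slightly depending on how it rounds the last partial batch:

```python
import math


def training_plan(n_examples: int, per_device_batch: int, accumulation_steps: int,
                  epochs: int, n_devices: int = 1, warmup_ratio: float = 0.1) -> dict:
    """Estimate effective batch size, optimizer steps, and warmup steps."""
    effective_batch = per_device_batch * accumulation_steps * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


plan = training_plan(n_examples=2100, per_device_batch=8, accumulation_steps=4, epochs=3)
print(plan)
# {'effective_batch': 32, 'steps_per_epoch': 66, 'total_steps': 198, 'warmup_steps': 20}
```

With fewer than 200 optimizer steps in total, a fixed `warmup_steps=100` would spend half the run warming up, so a ratio is the safer setting on small datasets.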
Weight Decay
training_args = TrainingArguments(
    weight_decay=0.01,  # Decoupled weight decay in AdamW (acts like L2 regularization to limit overfitting)
)
Guidelines:
- Standard: 0.01
- Strong regularization: 0.1 (small dataset, high overfitting risk)
- Light regularization: 0.001 (large dataset)
Warmup
training_args = TrainingArguments(
warmup_steps=100, # Or warmup_ratio=0.1 (10% of training)
)
Why warmup:
- Prevents initial instability (large gradients early)
- Gradual LR increase: 0 → target_LR over warmup steps
Guidelines:
- Warmup: 5-10% of total training steps
- Longer warmup for larger models
Training
Basic Training Loop
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir="./results",
    # Hyperparameters
    learning_rate=1e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    warmup_steps=100,
    # Evaluation
    evaluation_strategy="steps",  # Renamed to eval_strategy in newer transformers releases
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,  # Keep aligned with eval_steps so the best checkpoint can be reloaded
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Logging
    logging_steps=10,
    logging_dir="./logs",
    # Optimization
    fp16=True,  # Mixed precision (faster, less memory)
    gradient_checkpointing=True,  # Trade compute for memory
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,  # Newer transformers releases use processing_class=tokenizer
)
trainer.train()
Monitoring Training
Key metrics to watch:
# 1. Training loss (should decrease steadily)
# 2. Validation loss (should decrease, then plateau)
# 3. Validation metrics (accuracy, F1, BLEU, etc.)
# Warning signs:
# - Train loss → 0 but val loss increasing: Overfitting
# - Train loss oscillating: LR too high
# - Train loss not decreasing: LR too low or data issues
Logging:
import wandb
wandb.init(project="fine-tuning")
training_args = TrainingArguments(
report_to="wandb", # Log to Weights & Biases
logging_steps=10,
)
Early Stopping
from transformers import EarlyStoppingCallback
trainer = Trainer(
    ...
    callbacks=[EarlyStoppingCallback(
        early_stopping_patience=3,  # Stop if no improvement for 3 evals
        early_stopping_threshold=0.01,  # Minimum improvement
    )]
)
# Requires load_best_model_at_end=True, metric_for_best_model, and periodic evaluation
Why early stopping:
- Prevents overfitting (stops before val loss increases)
- Saves compute (don't train unnecessary epochs)
- Automatically finds optimal epoch count
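How patience and threshold interact is easiest to see on a list of validation losses. This is a simplified sketch of patience-based stopping, not the library's implementation; `early_stopping_index` is an illustrative name:

```python
def early_stopping_index(val_losses, patience: int = 3, threshold: float = 0.01):
    """Return the index of the evaluation at which training would stop, or None.

    An evaluation counts as an improvement only if it beats the best loss
    so far by more than `threshold`.
    """
    best = float("inf")
    bad_evals = 0
    for i, loss in enumerate(val_losses):
        if best - loss > threshold:
            best = loss
            bad_evals = 0          # improvement: reset the counter
        else:
            bad_evals += 1         # no meaningful improvement
            if bad_evals >= patience:
                return i
    return None


losses = [1.00, 0.80, 0.70, 0.695, 0.70, 0.71, 0.72]
print(early_stopping_index(losses))  # 5
```

The drop from 0.70 to 0.695 is below the threshold, so it counts as the first of three stalled evaluations and training stops at index 5. The best checkpoint is still the one from index 2.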
Evaluation
1. Validation During Training
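import numpy as np
from sklearn.metrics import accuracy_score, f1_score
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Causal LM: logits at position t predict token t+1, so shift before comparing
    predictions = np.argmax(logits, axis=-1)[:, :-1]
    labels = labels[:, 1:]
    # -100 marks positions ignored by the loss; replace them with the pad token in both arrays
    mask = labels != -100
    predictions = np.where(mask, predictions, tokenizer.pad_token_id)
    labels = np.where(mask, labels, tokenizer.pad_token_id)
    # Decode predictions
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Compute metrics (exact match per example, under teacher forcing)
    accuracy = accuracy_score(decoded_labels, decoded_preds)
    f1 = f1_score(decoded_labels, decoded_preds, average='weighted')
    return {'accuracy': accuracy, 'f1': f1}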
trainer = Trainer(
...
compute_metrics=compute_metrics,
)
2. Test Set Evaluation (Final)
# After training completes, evaluate on held-out test set ONCE
test_results = trainer.evaluate(test_dataset)
# Trainer prefixes metric names with "eval_"
print(f"Test accuracy: {test_results['eval_accuracy']:.2%}")
print(f"Test F1: {test_results['eval_f1']:.2%}")
3. Qualitative Evaluation
Critical: Manually test on real examples!
def test_model(model, tokenizer, test_examples):
    for ex in test_examples:
        prompt = ex['input']
        expected = ex['output']
        # Generate
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=100)
        # generate() returns prompt + continuation for causal LMs, so drop the prompt tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        print(f"Input: {prompt}")
        print(f"Expected: {expected}")
        print(f"Generated: {generated}")
        print(f"Match: {'✓' if generated == expected.strip() else '✗'}")
        print("-" * 80)
# Test on 20-50 examples (including edge cases)
test_model(model, tokenizer, test_examples)
4. A/B Testing (Production)
# Route 50% traffic to base model, 50% to fine-tuned
import random
def get_model():
    if random.random() < 0.5:
        return base_model
    else:
        return finetuned_model
# Measure:
# - User satisfaction (thumbs up/down)
# - Task success rate
# - Response time
# - Cost per request
# After 1000+ requests, analyze results
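Random routing per request can show the same user both models. A common alternative is deterministic assignment by hashing a user ID, sketched below together with a simple success-rate tally; `assign_variant`, `success_rates`, and the event format are illustrative assumptions:

```python
import hashlib


def assign_variant(user_id: str, finetuned_share: float = 0.5) -> str:
    """Map a user to 'finetuned' or 'base' deterministically from their ID."""
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64   # uniform value in [0, 1)
    return "finetuned" if bucket < finetuned_share else "base"


def success_rates(events) -> dict:
    """Compute the task success rate per variant from (variant, success) pairs."""
    totals, wins = {}, {}
    for variant, success in events:
        totals[variant] = totals.get(variant, 0) + 1
        wins[variant] = wins.get(variant, 0) + (1 if success else 0)
    return {variant: wins[variant] / totals[variant] for variant in totals}


events = [("base", True), ("base", False), ("finetuned", True), ("finetuned", True)]
print(assign_variant("user-123") == assign_variant("user-123"))  # True
print(success_rates(events))  # {'base': 0.5, 'finetuned': 1.0}
```

With real traffic, compare the two rates with a significance test before drawing conclusions, since a handful of events like the four above proves nothing.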
5. Catastrophic Forgetting Check
Critical: Ensure fine-tuning didn't break base capabilities!
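# Test on general knowledge tasks
general_tasks = [
    "What is the capital of France?",  # Basic knowledge
    "Translate to Spanish: Hello",  # Translation
    "2 + 2 = ?",  # Basic math
    "Who wrote Hamlet?",  # Literature
]
def answer(model, prompt):
    # Tokenize, generate greedily, and return only the newly generated text
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
for task in general_tasks:
    before = answer(base_model, task)
    after = answer(finetuned_model, task)
    print(f"Task: {task}")
    print(f"Before: {before}")
    print(f"After: {after}")
    # Exact match is a strict signal; read mismatches manually before calling them regressions
    print(f"Preserved: {'✓' if before == after else '✗'}")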
Common Issues and Solutions
Issue 1: Overfitting
Symptoms:
- Train loss → 0, val loss increases
- Perfect on training data, poor on test data
Solutions:
# 1. Reduce epochs
num_train_epochs=3 # Instead of 10
# 2. Increase regularization
weight_decay=0.1 # Instead of 0.01
# 3. Early stopping
early_stopping_patience=3
# 4. Collect more data
# 5. Data augmentation
# 6. Use LoRA (less prone to overfitting than full FT)
Issue 2: Catastrophic Forgetting
Symptoms:
- Fine-tuned model fails on general tasks
- Lost pre-trained knowledge
Solutions:
# 1. Lower learning rate (most important!)
learning_rate=1e-6 # Instead of 1e-4
# 2. Fewer epochs
num_train_epochs=2 # Instead of 10
# 3. Use LoRA (base weights stay frozen; disabling the adapter restores base behavior)
# 4. Add general examples to training set (10-20% general data)
Issue 3: Poor Quality
Symptoms:
- Model output is low quality (incorrect, incoherent)
Solutions:
# 1. Check dataset quality (most common cause!)
# - Manual validation
# - Remove noise
# - Fix labels
# 2. Increase model size
# - 7B → 13B → 70B
# 3. Increase training data
# - Need 1000+ high-quality examples
# 4. Adjust hyperparameters
# - Try higher LR (1e-5 → 3e-5) if underfit
# - Train longer (3 → 5 epochs)
# 5. Check if base model has capability
# - If base model can't do task, fine-tuning won't help
Issue 4: Slow Training
Symptoms:
- Training takes days/weeks
Solutions:
# 1. Use LoRA (10× faster than full FT)
# 2. Mixed precision
fp16=True # 2× faster
# 3. Turn off gradient checkpointing if memory allows (it saves memory but slows each step)
gradient_checkpointing=False
# 4. Use the largest per-device batch that fits, with less gradient accumulation
per_device_train_batch_size=8
gradient_accumulation_steps=4
# 5. Use multiple GPUs
# 6. Use faster GPU (A100 > V100 > T4)
Issue 5: Out of Memory
Symptoms:
- CUDA out of memory error
Solutions:
# 1. Use QLoRA (4× less memory)
# 2. Reduce batch size
per_device_train_batch_size=1
gradient_accumulation_steps=32
# 3. Gradient checkpointing
gradient_checkpointing=True
# 4. Use smaller model
# 7B → 3B → 1B
# 5. Reduce sequence length
max_seq_length=512 # Instead of 2048
Best Practices Summary
Before Fine-Tuning:
- ☐ Try prompt engineering first (90% of cases, prompts work!)
- ☐ Have 1000+ high-quality examples
- ☐ Clean and validate dataset (quality > quantity)
- ☐ Create train/val/test split (70/15/15)
- ☐ Define success metrics (what does "good" mean?)
During Fine-Tuning:
- ☐ Use LoRA (unless specific reason for full FT)
- ☐ Set tiny learning rate (1e-5 to 1e-6 for 7B models)
- ☐ Train for 3-5 epochs (not 50!)
- ☐ Monitor val loss (stop when it stops improving)
- ☐ Log everything (wandb, tensorboard)
After Fine-Tuning:
- ☐ Evaluate on test set (quantitative metrics)
- ☐ Manual testing (qualitative, 20-50 examples)
- ☐ Check for catastrophic forgetting (general tasks)
- ☐ A/B test in production (before full rollout)
- ☐ Document hyperparameters (for reproducibility)
Quick Reference
| Task | Method | Dataset | LR | Epochs |
|---|---|---|---|---|
| Tone matching | Prompts | N/A | N/A | N/A |
| Simple classification | Prompts | N/A | N/A | N/A |
| Complex domain task | LoRA | 1k-10k | 1e-5 | 3-5 |
| Fundamental change | Full FT | 100k+ | 1e-5 | 1-3 |
| Limited GPU | QLoRA | 1k-10k | 1e-5 | 3-5 |
Default recommendation: Try prompts first. If that fails, use LoRA with LR=1e-5, epochs=3, and high-quality dataset.
Summary
Core principles:
- Prompt engineering first: 90% of tasks don't need fine-tuning
- LoRA by default: trains roughly 100× fewer parameters than full fine-tuning at near-equal quality
- Data quality matters: 1,000 clean examples > 10,000 noisy examples
- Tiny learning rate: Fine-tune LR ≈ Pre-train LR / 100 (roughly / 10 to / 1000 depending on model size)
- Validation essential: Train/val/test split + early stopping + catastrophic forgetting check
Decision tree:
- Try prompts (system message + few-shot)
- If quality < 90%, optimize prompts
- If still < 90% and have 1000+ examples, consider fine-tuning
- Use LoRA (default), QLoRA (limited GPU), or full FT (rare)
- Set LR = 1e-5, epochs = 3-5, monitor val loss
- Evaluate on test set + manual testing + general tasks
Key insight: Fine-tuning is powerful but expensive and slow. Start with prompts, fine-tune only when prompts demonstrably fail and you have high-quality data.