# LLM Fine-Tuning Strategies

## Context

You're considering fine-tuning an LLM or debugging a fine-tuning process. Common mistakes:

- **Fine-tuning when prompts would work** (unnecessary cost/time)
- **Full fine-tuning instead of LoRA** (100× less efficient)
- **Poor dataset quality** (garbage in, garbage out)
- **Wrong hyperparameters** (catastrophic forgetting)
- **No validation strategy** (overfitting undetected)

**This skill provides effective fine-tuning strategies: when to fine-tune, efficient methods (LoRA), data quality, hyperparameters, and evaluation.**

## Decision Tree: Prompt Engineering vs Fine-Tuning

**Start with prompt engineering. Fine-tuning is the last resort.**

### Step 1: Try Prompt Engineering

```python
# System message + few-shot examples
system = """
You are a {role} with {characteristics}.
{guidelines}
"""

few_shot = [
    # 3-5 examples of desired behavior
]

# Test quality (evaluate() is your own harness that scores outputs against a test set)
quality = evaluate(system, few_shot, test_set)
```

**If quality ≥ 90%:** ✅ STOP. Use prompts (no fine-tuning needed)

**If quality < 90%:** Continue to Step 2

### Step 2: Optimize Prompts

- Add more examples (5-10)
- Add chain-of-thought
- Specify output format more clearly
- Try different system messages
- Use temperature=0 for consistency

**If quality ≥ 90%:** ✅ STOP. Use optimized prompts

**If quality < 90%:** Continue to Step 3

### Step 3: Consider Fine-Tuning

**Fine-tune when:**
- ✅ **Prompts fail** (quality < 90% after optimization)
- ✅ **Have 1000+ examples** (minimum for meaningful fine-tuning)
- ✅ **Need consistency** (can't rely on prompt variations)
- ✅ **Reduce latency** (shorter prompts → faster inference)
- ✅ **Teach new capability** (not in base model)

**Don't fine-tune for:**
- ❌ **Tone/style matching** (use system message)
- ❌ **Output formatting** (use format specification in prompt)
- ❌ **Few examples** (< 100 examples insufficient)
- ❌ **Quick experiments** (prompts iterate faster)
- ❌ **Recent information** (use RAG, not fine-tuning)

## When to Fine-Tune: Detailed Criteria

### Criterion 1: Task Complexity

**Simple tasks (prompt engineering):**
- Classification (sentiment, category)
- Extraction (entities, dates, names)
- Formatting (JSON, CSV conversion)
- Tone matching (company voice)

**Complex tasks (consider fine-tuning):**
- Multi-step reasoning (not in base model)
- Domain-specific language (medical, legal)
- Consistent complex behavior (100+ edge cases)
- New capabilities (teach entirely new skill)

### Criterion 2: Dataset Size

```
< 100 examples: Prompts only (insufficient for fine-tuning)
100-1000:       Prompts preferred (fine-tuning risky - overfitting)
1000-10k:       Fine-tuning viable if prompts fail
> 10k:          Fine-tuning effective
```

### Criterion 3: Cost-Benefit

**Prompt engineering:**
- Cost: $0 (just dev time)
- Time: Minutes to hours (fast iteration)
- Maintenance: Easy (just update prompt)

**Fine-tuning:**
- Cost: $100-1000+ (compute + data prep)
- Time: Days to weeks (data prep + training + eval)
- Maintenance: Hard (need retraining for updates)

**ROI calculation:**

```python
# Prompt engineering cost
prompt_dev_hours = 4
hourly_rate = 100
prompt_cost = prompt_dev_hours * hourly_rate                    # $400

# Fine-tuning cost
data_prep_hours = 40
training_cost = 500
total_ft_cost = data_prep_hours * hourly_rate + training_cost   # $4,500

# Cost ratio: fine-tuning is ~11× more expensive
# Only worth it if the quality improvement exceeds ~10%
```

### Criterion 4: Performance Requirements

**Quality:**
- Need 90-95%: Prompts usually sufficient
- Need 95-98%: Fine-tuning may help
- Need 98%+: Fine-tuning + careful data curation

**Latency:**
- > 1 second acceptable: Prompts fine (long prompts OK)
- 200-1000ms: Fine-tuning may help (reduce prompt size)
- < 200ms: Fine-tuning + optimization required

**Consistency:**
- Variable outputs acceptable: Prompts OK (temperature > 0)
- High consistency needed: Prompts (temperature=0) or fine-tuning
- Perfect consistency: Fine-tuning + validation
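To tie the criteria above together, here is a minimal sketch of a decision helper. The function name, arguments, and thresholds are illustrative only (they simply restate the guidance above) and are not part of any library.

```python
def recommend_approach(prompt_quality: float, n_examples: int,
                       needs_low_latency: bool = False) -> str:
    """Rough decision helper mirroring the criteria above (illustrative only)."""
    # Steps 1-2: if optimized prompts already hit the quality bar, stop there.
    if prompt_quality >= 0.90 and not needs_low_latency:
        return "prompt engineering"
    # Criterion 2: too little data makes fine-tuning risky regardless of quality.
    if n_examples < 1000:
        return "prompt engineering (collect more data before fine-tuning)"
    # Otherwise fine-tuning is on the table; LoRA is the default method.
    return "fine-tune (LoRA by default)"

print(recommend_approach(prompt_quality=0.82, n_examples=5000))
# -> fine-tune (LoRA by default)
```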
## Fine-Tuning Methods

### 1. Full Fine-Tuning

**Updates all model parameters.**

**Pros:**
- Maximum flexibility (can change any behavior)
- Best quality (when you have massive data)

**Cons:**
- Expensive (7B model = 28GB memory for weights alone)
- Slow (hours to days)
- Risk of catastrophic forgetting
- Hard to merge multiple fine-tunes

**When to use:**
- Massive dataset (100k+ examples)
- Fundamental behavior change needed
- Have large compute resources (multi-GPU)

**Memory requirements:**

```
# 7B parameter model (FP32)
weights          = 7B × 4 bytes = 28 GB
gradients        = 28 GB
optimizer_states = 56 GB (Adam: 2× weights)
activations      = ~8 GB (batch_size=8)
------------------------------------------
total            ≈ 120 GB  → need multi-GPU!
```

### 2. LoRA (Low-Rank Adaptation)

**Freezes base model, trains small adapter matrices.**

**How it works** (see the sketch at the end of this subsection):

```
Original linear layer: W (d × k)
LoRA: W + (A × B) where A (d × r), B (r × k), r << d,k

Example:
W: 4096 × 4096 = 16.7M parameters
A: 4096 × 8    = 32K parameters
B: 8 × 4096    = 32K parameters
A + B = 64K parameters (0.4% of original!)
```

**Pros:**
- Extremely efficient (~1% of parameters)
- Fast training (10× faster than full FT)
- Low memory (fits single GPU)
- Easy to merge multiple LoRAs
- No catastrophic forgetting (base model frozen)

**Cons:**
- Slightly lower capacity than full FT (usually ~99% of the quality)
- Need to keep base model + adapters

**When to use:**
- 99% of fine-tuning cases
- Limited compute (single GPU)
- Fast iteration needed
- Multiple tasks (train separate LoRAs, swap as needed)

**Configuration:**

```python
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,                                  # Rank (4-16 typical, higher = more capacity)
    lora_alpha=32,                        # Scaling (usually 2-4× rank)
    target_modules=["q_proj", "v_proj"],  # Which layers get adapters
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()
# trainable params: 8.4M || all params: 7B || trainable%: 0.12%
```

**Rank selection:**

```
r=4:   Minimal (fast, low capacity) - simple tasks
r=8:   Standard (balanced) - most tasks
r=16:  High capacity (slower, better quality) - complex tasks
r=32+: Approaching full FT quality (diminishing returns)

Start with r=8, increase only if quality insufficient
```
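To make the `W + (A × B)` decomposition concrete, here is a from-scratch sketch of a LoRA-wrapped linear layer in PyTorch. This illustrates the idea only; it is not PEFT's actual implementation, and the class name and initialization choices are assumptions.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                           # freeze the pretrained layer
        self.scale = alpha / r
        self.A = nn.Linear(base.in_features, r, bias=False)   # d -> r
        self.B = nn.Linear(r, base.out_features, bias=False)  # r -> k
        nn.init.normal_(self.A.weight, std=0.01)
        nn.init.zeros_(self.B.weight)                         # B·A starts at 0, so training begins as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.B(self.A(x))

layer = LoRALinear(nn.Linear(4096, 4096), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"Trainable LoRA params: {trainable:,}")  # 65,536 vs ~16.8M in the frozen base layer
```

Zero-initializing `B` means the adapted layer starts out identical to the pretrained one, which is why LoRA training is stable from the first step.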
### 3. QLoRA (Quantized LoRA)

**LoRA + 4-bit quantization of the base model.**

**Pros:**
- Extremely memory efficient (4× less than LoRA)
- 7B model fits on a 16GB GPU
- Same quality as LoRA

**Cons:**
- Slower than LoRA (quantization overhead)
- More complex setup

**When to use:**
- Limited GPU memory (< 24GB)
- Large models on consumer GPUs
- Cost optimization (cheaper GPUs)

**Setup:**

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# Then add LoRA as usual (lora_config from the LoRA section above)
model = get_peft_model(model, lora_config)
```

**Memory comparison:**

```
Method  | 7B Model | 13B Model | 70B Model
--------|----------|-----------|----------
Full FT | 120 GB   | 200 GB    | 1000 GB
LoRA    | 40 GB    | 60 GB     | 300 GB
QLoRA   | 12 GB    | 20 GB     | 80 GB
```

### Method Selection

```python
# Rough selection logic; use_qlora() etc. stand in for the setups above
if gpu_memory < 24:
    use_qlora()
elif gpu_memory < 80:
    use_lora()
elif have_massive_data and multi_gpu_cluster:
    use_full_finetuning()
else:
    use_lora()  # Default choice
```

## Dataset Preparation

**Quality > Quantity. 1,000 clean examples > 10,000 noisy examples.**

### 1. Data Collection

**Good sources:**
- Human-labeled data (gold standard)
- Curated conversations (high-quality)
- Expert-written examples
- Validated user interactions

**Bad sources:**
- Raw logs (errors, incomplete, noise)
- Scraped data (quality varies wildly)
- Automated generation (may have artifacts)
- Untested user inputs (edge cases, adversarial)

### 2. Data Cleaning

```python
def clean_dataset(raw_data):
    clean = []
    for example in raw_data:
        # Filter 1: Remove examples whose output contains error text
        if any(err in example['output'].lower() for err in ['error', 'exception', 'failed']):
            continue

        # Filter 2: Length checks
        if len(example['input']) < 10 or len(example['output']) < 10:
            continue  # Too short
        if len(example['input']) > 2000 or len(example['output']) > 2000:
            continue  # Too long (may be malformed)

        # Filter 3: Completeness
        if not example['output'].strip().endswith(('.', '!', '?')):
            continue  # Incomplete response

        # Filter 4: Language check (is_valid_language is your own helper,
        # e.g. a wrapper around a language-detection library)
        if not is_valid_language(example['output']):
            continue  # Gibberish or wrong language

        # Filter 5: Duplicates (is_duplicate is your own helper; see the sketch below)
        if is_duplicate(example, clean):
            continue

        clean.append(example)
    return clean

cleaned = clean_dataset(raw_data)
print(f"Filtered: {len(raw_data)} → {len(cleaned)}")
# Example: 10,000 → 3,000 (but high quality!)
```
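The cleaning function above calls an `is_duplicate` helper without defining it. One simple (if O(n²)) implementation is a normalized fuzzy match; this sketch is illustrative, and the 0.9 threshold is an assumption you should tune for your data.

```python
import difflib

def _normalize(text: str) -> str:
    # Lowercase and collapse whitespace so trivial formatting differences don't matter
    return " ".join(text.lower().split())

def is_duplicate(example: dict, accepted: list, threshold: float = 0.9) -> bool:
    """Flag `example` as a (near-)duplicate of anything already accepted."""
    candidate = _normalize(example["input"])
    for kept in accepted:
        ratio = difflib.SequenceMatcher(None, candidate, _normalize(kept["input"])).ratio()
        if ratio >= threshold:
            return True
    return False
```

For large datasets, replace the pairwise comparison with hashing or MinHash-style deduplication; the quadratic loop above is only practical for a few thousand examples.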
### 3. Manual Validation

**Critical step: Spot check 100+ random examples.**

```python
import random

sample = random.sample(cleaned, min(100, len(cleaned)))

for i, ex in enumerate(sample):
    print(f"\n--- Example {i+1}/{len(sample)} ---")
    print(f"Input: {ex['input']}")
    print(f"Output: {ex['output']}")

    response = input("Quality (good/bad/skip)? ")
    if response == 'bad':
        # Investigate the pattern, then add a filtering rule for it
        reason = input("Why bad? ")
        # Update filtering logic based on `reason`
```

**What to check:**
- ☐ Output is correct and complete
- ☐ Output matches desired format/style
- ☐ No errors or hallucinations
- ☐ Appropriate length
- ☐ Natural language (not robotic)
- ☐ Consistent with other examples

### 4. Dataset Format

**OpenAI format (for GPT fine-tuning):**

```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
  ]
}
```

**Hugging Face format:**

```python
from datasets import Dataset

data = {
    'input': ["question 1", "question 2", ...],
    'output': ["answer 1", "answer 2", ...]
}
dataset = Dataset.from_dict(data)
```
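If your examples live as `input`/`output` pairs (as in the cleaning step earlier), a small converter to the chat-message JSONL layout shown above might look like the following sketch. The `to_chat_jsonl` helper and its field names are assumptions for illustration, not part of any SDK.

```python
import json

def to_chat_jsonl(pairs, path, system="You are a helpful assistant."):
    """Write {'input', 'output'} pairs as one chat-format JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for ex in pairs:
            record = {
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": ex["input"]},
                    {"role": "assistant", "content": ex["output"]},
                ]
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

to_chat_jsonl(cleaned, "train.jsonl")  # `cleaned` comes from the cleaning step above
```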
### 5. Train/Val/Test Split

```python
from sklearn.model_selection import train_test_split

# 70% train, 15% val, 15% test
train, temp = train_test_split(data, test_size=0.3, random_state=42)
val, test = train_test_split(temp, test_size=0.5, random_state=42)

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")
# Example: Train: 2100, Val: 450, Test: 450

# Stratified split for imbalanced data (requires a label column, e.g. in a DataFrame)
train, temp = train_test_split(
    data,
    test_size=0.3,
    stratify=data['label'],
    random_state=42
)
```

**Split guidelines:**
- Minimum validation: 100 examples
- Minimum test: 100 examples
- Large datasets (> 10k): 80/10/10 split
- Small datasets (< 5k): 70/15/15 split

### 6. Data Augmentation (Optional)

**When you need more data:**

```python
# Paraphrasing
#   "What's the weather?" → "How's the weather today?"

# Back-translation
#   English → French → English (introduces variation)

# Synthetic generation (use carefully!); `llm` is whatever generation client you use
few_shot_examples = [...]
new_examples = llm.generate(
    f"Generate 10 examples similar to: {few_shot_examples}"
)
# ALWAYS manually validate synthetic data!
```

**Warning:** Synthetic data can introduce artifacts. Always validate!

## Hyperparameters

### Learning Rate

**Most critical hyperparameter.**

```python
# Pre-training LR: 1e-3 to 3e-4
# Fine-tuning LR: 100-1000× smaller!

training_args = TrainingArguments(
    learning_rate=1e-5,    # Start here for 7B models
    # Or even more conservative:
    # learning_rate=1e-6,  # For larger models or small datasets
)
```

**Guidelines:**

```
Model size  | Pre-train LR | Fine-tune LR
------------|--------------|-------------
1B params   | 3e-4         | 3e-5 to 1e-5
7B params   | 3e-4         | 1e-5 to 1e-6
13B params  | 2e-4         | 5e-6 to 1e-6
70B+ params | 1e-4         | 1e-6 to 1e-7

Rule: Fine-tune LR ≈ Pre-train LR / 100
```

**LR scheduling:**

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,           # Gradual LR increase (~10% of training)
    num_training_steps=total_steps  # Steps per epoch × number of epochs
)
```

**Signs of wrong LR:**

Too high (LR > 1e-4):
- Training loss oscillates wildly
- Model generates gibberish
- Catastrophic forgetting (fails on general tasks)

Too low (LR < 1e-7):
- Training loss barely decreases
- Model doesn't adapt to new data
- Very slow convergence

### Epochs

```python
training_args = TrainingArguments(
    num_train_epochs=3,  # Standard: 3-5 epochs
)
```

**Guidelines:**

```
Dataset size | Epochs
-------------|-------
< 1k         | 5-10 (more passes needed)
1k-5k        | 3-5 (standard)
5k-10k       | 2-3
> 10k        | 1-2 (large dataset, fewer passes)

Rule: Smaller dataset → more epochs (but watch for overfitting!)
```

**Too many epochs:**
- Training loss → 0 but val loss increases (overfitting)
- Model memorizes training data
- Catastrophic forgetting

**Too few epochs:**
- Model hasn't fully adapted
- Training and val loss still decreasing

### Batch Size

```python
training_args = TrainingArguments(
    per_device_train_batch_size=8,  # Depends on GPU memory
    gradient_accumulation_steps=4,  # Effective batch = 8 × 4 = 32
)
```

**Guidelines:**

```
GPU Memory | Batch Size (7B model)
-----------|----------------------
16 GB      | 1-2 (use gradient accumulation!)
24 GB      | 2-4
40 GB      | 4-8
80 GB      | 8-16

Effective batch size (with accumulation): 16-64 typical
```

**Gradient accumulation:**

```python
# Simulate batch_size=32 with only 8 examples fitting in memory:
per_device_train_batch_size=8
gradient_accumulation_steps=4
# Effective batch = 8 × 4 = 32
```

### Weight Decay

```python
training_args = TrainingArguments(
    weight_decay=0.01,  # L2 regularization (prevent overfitting)
)
```

**Guidelines:**
- Standard: 0.01
- Strong regularization: 0.1 (small dataset, high overfitting risk)
- Light regularization: 0.001 (large dataset)

### Warmup

```python
training_args = TrainingArguments(
    warmup_steps=100,  # Or warmup_ratio=0.1 (10% of training)
)
```

**Why warmup:**
- Prevents initial instability (large gradients early)
- Gradual LR increase: 0 → target_LR over warmup steps

**Guidelines:**
- Warmup: 5-10% of total training steps
- Longer warmup for larger models

## Training

### Basic Training Loop

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",

    # Hyperparameters
    learning_rate=1e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    warmup_steps=100,

    # Evaluation
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # Logging
    logging_steps=10,
    logging_dir="./logs",

    # Optimization
    fp16=True,                    # Mixed precision (faster, less memory)
    gradient_checkpointing=True,  # Trade compute for memory
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()
```

### Monitoring Training

**Key metrics to watch:**

```python
# 1. Training loss (should decrease steadily)
# 2. Validation loss (should decrease, then plateau)
# 3. Validation metrics (accuracy, F1, BLEU, etc.)

# Warning signs:
# - Train loss → 0 but val loss increasing: Overfitting
# - Train loss oscillating: LR too high
# - Train loss not decreasing: LR too low or data issues
```

**Logging:**

```python
import wandb

wandb.init(project="fine-tuning")

training_args = TrainingArguments(
    report_to="wandb",  # Log to Weights & Biases
    logging_steps=10,
)
```
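As a rough illustration, the warning signs above can be encoded as a toy heuristic over the logged loss curves. The thresholds here are arbitrary assumptions; in practice you would eyeball the curves in wandb or TensorBoard.

```python
def diagnose(train_losses, val_losses):
    """Very rough automation of the warning signs above (illustrative thresholds)."""
    issues = []
    # Overfitting: training loss keeps falling while validation loss rises off its minimum
    if train_losses[-1] < train_losses[0] and val_losses[-1] > min(val_losses):
        issues.append("possible overfitting: val loss rising while train loss falls")
    # LR too high: training loss bounces around instead of decreasing smoothly
    rises = sum(b > a for a, b in zip(train_losses, train_losses[1:]))
    if rises > len(train_losses) // 2:
        issues.append("train loss oscillating: learning rate may be too high")
    # LR too low (or data issue): loss barely moves at all
    if abs(train_losses[0] - train_losses[-1]) < 0.01 * train_losses[0]:
        issues.append("train loss flat: learning rate may be too low or data may be wrong")
    return issues or ["no obvious problems from the loss curves"]

print(diagnose([2.1, 1.6, 1.2, 0.9, 0.6], [2.0, 1.7, 1.6, 1.7, 1.9]))
# -> ['possible overfitting: val loss rising while train loss falls']
```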
### Early Stopping

```python
from transformers import EarlyStoppingCallback

# Requires an eval strategy and load_best_model_at_end=True in TrainingArguments
trainer = Trainer(
    ...
    callbacks=[EarlyStoppingCallback(
        early_stopping_patience=3,      # Stop if no improvement for 3 evals
        early_stopping_threshold=0.01,  # Minimum improvement that counts
    )]
)
```

**Why early stopping:**
- Prevents overfitting (stops before val loss increases)
- Saves compute (don't train unnecessary epochs)
- Automatically finds the optimal epoch count

## Evaluation

### 1. Validation During Training

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Predictions arrive as logits: take the most likely token at each position
    predictions = np.argmax(predictions, axis=-1)
    # Replace the -100 padding used for ignored label positions before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode predictions
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute metrics (exact-match comparison of decoded strings)
    accuracy = accuracy_score(decoded_labels, decoded_preds)
    f1 = f1_score(decoded_labels, decoded_preds, average='weighted')

    return {'accuracy': accuracy, 'f1': f1}

trainer = Trainer(
    ...
    compute_metrics=compute_metrics,
)
```

### 2. Test Set Evaluation (Final)

```python
# After training completes, evaluate on the held-out test set ONCE
test_results = trainer.evaluate(test_dataset)

# Trainer prefixes returned metrics with "eval_"
print(f"Test accuracy: {test_results['eval_accuracy']:.2%}")
print(f"Test F1: {test_results['eval_f1']:.2%}")
```

### 3. Qualitative Evaluation

**Critical: Manually test on real examples!**

```python
def test_model(model, tokenizer, test_examples):
    for ex in test_examples:
        prompt = ex['input']
        expected = ex['output']

        # Generate
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=100)
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print(f"Input: {prompt}")
        print(f"Expected: {expected}")
        print(f"Generated: {generated}")
        print(f"Match: {'✓' if generated == expected else '✗'}")
        print("-" * 80)

# Test on 20-50 examples (including edge cases)
test_model(model, tokenizer, test_examples)
```

### 4. A/B Testing (Production)

```python
# Route 50% of traffic to the base model, 50% to the fine-tuned model
import random

def get_model():
    if random.random() < 0.5:
        return base_model
    else:
        return finetuned_model

# Measure:
# - User satisfaction (thumbs up/down)
# - Task success rate
# - Response time
# - Cost per request

# After 1000+ requests, analyze results
```

### 5. Catastrophic Forgetting Check

**Critical: Ensure fine-tuning didn't break base capabilities!**

```python
# Test on general knowledge tasks
general_tasks = [
    "What is the capital of France?",  # Basic knowledge
    "Translate to Spanish: Hello",     # Translation
    "2 + 2 = ?",                       # Basic math
    "Who wrote Hamlet?",               # Literature
]

def ask(model, prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=50)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

for task in general_tasks:
    before = ask(base_model, task)
    after = ask(finetuned_model, task)

    print(f"Task: {task}")
    print(f"Before: {before}")
    print(f"After: {after}")
    # Exact equality is a strict check; in practice, inspect the answers
    print(f"Preserved: {'✓' if before == after else '✗'}")
```
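Exact string equality is a blunt test. A slightly more robust (still simple) check is to compare how likely each model finds reference answers, assuming `base_model` and `finetuned_model` are Hugging Face causal LMs as above; the probe sentences here are illustrative.

```python
import torch

def mean_nll(model, tokenizer, text):
    """Mean token negative log-likelihood of `text` under `model` (lower = more fluent)."""
    enc = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model(**enc, labels=enc["input_ids"])
    return out.loss.item()

reference_facts = [
    "The capital of France is Paris.",
    "Hamlet was written by William Shakespeare.",
]

for fact in reference_facts:
    base = mean_nll(base_model, tokenizer, fact)
    tuned = mean_nll(finetuned_model, tokenizer, fact)
    # A large jump in NLL after fine-tuning suggests general knowledge degraded
    print(f"{fact!r}: base={base:.3f}, fine-tuned={tuned:.3f}, delta={tuned - base:+.3f}")
```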
## Common Issues and Solutions

### Issue 1: Overfitting

**Symptoms:**
- Train loss → 0, val loss increases
- Perfect on training data, poor on test data

**Solutions:**

```python
# 1. Reduce epochs
num_train_epochs=3  # Instead of 10

# 2. Increase regularization
weight_decay=0.1    # Instead of 0.01

# 3. Early stopping
early_stopping_patience=3

# 4. Collect more data
# 5. Data augmentation
# 6. Use LoRA (less prone to overfitting than full FT)
```

### Issue 2: Catastrophic Forgetting

**Symptoms:**
- Fine-tuned model fails on general tasks
- Lost pre-trained knowledge

**Solutions:**

```python
# 1. Lower learning rate (most important!)
learning_rate=1e-6  # Instead of 1e-4

# 2. Fewer epochs
num_train_epochs=2  # Instead of 10

# 3. Use LoRA (base model frozen, can't forget)
# 4. Add general examples to the training set (10-20% general data)
```

### Issue 3: Poor Quality

**Symptoms:**
- Model output is low quality (incorrect, incoherent)

**Solutions:**

```python
# 1. Check dataset quality (most common cause!)
#    - Manual validation
#    - Remove noise
#    - Fix labels

# 2. Increase model size
#    - 7B → 13B → 70B

# 3. Increase training data
#    - Need 1000+ high-quality examples

# 4. Adjust hyperparameters
#    - Try higher LR (1e-5 → 3e-5) if underfit
#    - Train longer (3 → 5 epochs)

# 5. Check if the base model has the capability
#    - If the base model can't do the task, fine-tuning won't help
```

### Issue 4: Slow Training

**Symptoms:**
- Training takes days/weeks

**Solutions:**

```python
# 1. Use LoRA (10× faster than full FT)

# 2. Mixed precision
fp16=True  # ~2× faster

# 3. Gradient checkpointing (trades speed for memory)
gradient_checkpointing=True

# 4. Smaller batch size + gradient accumulation
per_device_train_batch_size=2
gradient_accumulation_steps=16

# 5. Use multiple GPUs
# 6. Use a faster GPU (A100 > V100 > T4)
```

### Issue 5: Out of Memory

**Symptoms:**
- CUDA out of memory error

**Solutions:**

```python
# 1. Use QLoRA (4× less memory)

# 2. Reduce batch size
per_device_train_batch_size=1
gradient_accumulation_steps=32

# 3. Gradient checkpointing
gradient_checkpointing=True

# 4. Use a smaller model
#    7B → 3B → 1B

# 5. Reduce sequence length
max_seq_length=512  # Instead of 2048
```

## Best Practices Summary

### Before Fine-Tuning:
1. ☐ Try prompt engineering first (90% of cases, prompts work!)
2. ☐ Have 1000+ high-quality examples
3. ☐ Clean and validate the dataset (quality > quantity)
4. ☐ Create a train/val/test split (70/15/15)
5. ☐ Define success metrics (what does "good" mean?)

### During Fine-Tuning:
6. ☐ Use LoRA (unless there's a specific reason for full FT)
7. ☐ Set a tiny learning rate (1e-5 to 1e-6 for 7B models)
8. ☐ Train for 3-5 epochs (not 50!)
9. ☐ Monitor val loss (stop when it stops improving)
10. ☐ Log everything (wandb, TensorBoard)

### After Fine-Tuning:
11. ☐ Evaluate on the test set (quantitative metrics)
12. ☐ Manual testing (qualitative, 20-50 examples)
13. ☐ Check for catastrophic forgetting (general tasks)
14. ☐ A/B test in production (before full rollout)
15. ☐ Document hyperparameters (for reproducibility)

## Quick Reference

| Task | Method | Dataset | LR | Epochs |
|------|--------|---------|------|--------|
| Tone matching | Prompts | N/A | N/A | N/A |
| Simple classification | Prompts | N/A | N/A | N/A |
| Complex domain task | LoRA | 1k-10k | 1e-5 | 3-5 |
| Fundamental change | Full FT | 100k+ | 1e-5 | 1-3 |
| Limited GPU | QLoRA | 1k-10k | 1e-5 | 3-5 |

**Default recommendation:** Try prompts first. If that fails, use LoRA with LR=1e-5, epochs=3, and a high-quality dataset.

## Summary

**Core principles:**
1. **Prompt engineering first**: 90% of tasks don't need fine-tuning
2. **LoRA by default**: 100× more efficient than full fine-tuning, same quality
3. **Data quality matters**: 1,000 clean examples > 10,000 noisy examples
4. **Tiny learning rate**: Fine-tune LR = Pre-train LR / 100 to / 1000
5. **Validation essential**: Train/val/test split + early stopping + catastrophic forgetting check

**Decision tree:**
1. Try prompts (system message + few-shot)
2. If quality < 90%, optimize prompts
3. If still < 90% and you have 1000+ examples, consider fine-tuning
4. Use LoRA (default), QLoRA (limited GPU), or full FT (rare)
5. Set LR = 1e-5, epochs = 3-5, monitor val loss
6. Evaluate on the test set + manual testing + general tasks

**Key insight**: Fine-tuning is powerful but expensive and slow. Start with prompts, fine-tune only when prompts demonstrably fail and you have high-quality data.