---
name: model-selection
description: Automatically applies when choosing LLM models and providers. Ensures proper model comparison, provider selection, cost optimization, fallback patterns, and multi-model strategies.
category: ai-llm
---

# Model Selection and Provider Management

When selecting LLM models and managing providers, follow these patterns for optimal cost, performance, and reliability.

**Trigger Keywords**: model selection, provider, model comparison, fallback, OpenAI, Anthropic, model routing, cost optimization, model capabilities, provider failover

**Agent Integration**: Used by `ml-system-architect`, `performance-and-cost-engineer-llm`, `llm-app-engineer`

## ✅ Correct Pattern: Model Registry

```python
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class ModelProvider(str, Enum):
    """Supported LLM providers."""
    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GOOGLE = "google"
    LOCAL = "local"


class ModelCapabilities(BaseModel):
    """Model capabilities and constraints."""

    max_context_tokens: int
    max_output_tokens: int
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
    supports_json_mode: bool = False


class ModelPricing(BaseModel):
    """Model pricing information."""

    input_price_per_mtok: float  # USD per million tokens
    output_price_per_mtok: float
    cache_write_price_per_mtok: Optional[float] = None
    cache_read_price_per_mtok: Optional[float] = None


class ModelConfig(BaseModel):
    """Complete model configuration."""

    id: str
    name: str
    provider: ModelProvider
    capabilities: ModelCapabilities
    pricing: ModelPricing
    recommended_use_cases: List[str] = Field(default_factory=list)
    quality_tier: str  # "flagship", "balanced", "fast"


class ModelRegistry:
    """Registry of available models with metadata."""

    def __init__(self):
        self.models: Dict[str, ModelConfig] = {}
        self._register_default_models()

    def _register_default_models(self):
        """Register commonly used models."""

        # Claude models
        self.register(ModelConfig(
            id="claude-sonnet-4-20250514",
            name="Claude Sonnet 4",
            provider=ModelProvider.ANTHROPIC,
            capabilities=ModelCapabilities(
                max_context_tokens=200_000,
                max_output_tokens=8_192,
                supports_streaming=True,
                supports_vision=True,
                supports_json_mode=True
            ),
            pricing=ModelPricing(
                input_price_per_mtok=3.00,
                output_price_per_mtok=15.00,
                cache_write_price_per_mtok=3.75,
                cache_read_price_per_mtok=0.30
            ),
            recommended_use_cases=[
                "complex reasoning",
                "long context",
                "code generation"
            ],
            quality_tier="flagship"
        ))

        self.register(ModelConfig(
            id="claude-haiku-3-5-20250514",
            name="Claude Haiku 3.5",
            provider=ModelProvider.ANTHROPIC,
            capabilities=ModelCapabilities(
                max_context_tokens=200_000,
                max_output_tokens=8_192,
                supports_streaming=True,
                supports_vision=True
            ),
            pricing=ModelPricing(
                input_price_per_mtok=0.80,
                output_price_per_mtok=4.00
            ),
            recommended_use_cases=[
                "high throughput",
                "cost-sensitive",
                "simple tasks"
            ],
            quality_tier="fast"
        ))

        # OpenAI models
        self.register(ModelConfig(
            id="gpt-4-turbo",
            name="GPT-4 Turbo",
            provider=ModelProvider.OPENAI,
            capabilities=ModelCapabilities(
                max_context_tokens=128_000,
                max_output_tokens=4_096,
                supports_streaming=True,
                supports_function_calling=True,
                supports_vision=True,
                supports_json_mode=True
            ),
            pricing=ModelPricing(
                input_price_per_mtok=10.00,
                output_price_per_mtok=30.00
            ),
            recommended_use_cases=[
                "function calling",
                "complex reasoning",
                "structured output"
            ],
            quality_tier="flagship"
        ))

    def register(self, model: ModelConfig):
        """Register a model."""
        self.models[model.id] = model

    def get(self, model_id: str) -> Optional[ModelConfig]:
        """Get model by ID."""
        return self.models.get(model_id)

    def find_by_criteria(
        self,
        max_cost_per_mtok: Optional[float] = None,
        min_context_tokens: Optional[int] = None,
        requires_streaming: bool = False,
        requires_vision: bool = False,
        quality_tier: Optional[str] = None,
        provider: Optional[ModelProvider] = None
    ) -> List[ModelConfig]:
        """
        Find models matching criteria.

        Args:
            max_cost_per_mtok: Maximum input cost
            min_context_tokens: Minimum context window
            requires_streaming: Must support streaming
            requires_vision: Must support vision
            quality_tier: Quality tier filter
            provider: Provider filter

        Returns:
            List of matching models
        """
        matches = []

        for model in self.models.values():
            # Check cost
            if max_cost_per_mtok is not None:
                if model.pricing.input_price_per_mtok > max_cost_per_mtok:
                    continue

            # Check context
            if min_context_tokens is not None:
                if model.capabilities.max_context_tokens < min_context_tokens:
                    continue

            # Check capabilities
            if requires_streaming and not model.capabilities.supports_streaming:
                continue
            if requires_vision and not model.capabilities.supports_vision:
                continue

            # Check tier
            if quality_tier and model.quality_tier != quality_tier:
                continue

            # Check provider
            if provider and model.provider != provider:
                continue

            matches.append(model)

        # Sort by cost (cheapest first)
        matches.sort(key=lambda m: m.pricing.input_price_per_mtok)

        return matches


# Usage
registry = ModelRegistry()

# Find cost-effective models with vision
models = registry.find_by_criteria(
    max_cost_per_mtok=5.00,
    requires_vision=True
)

for model in models:
    print(f"{model.name}: ${model.pricing.input_price_per_mtok}/MTok")
```
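
The registry is not limited to the defaults; anything you serve yourself can be registered at runtime the same way. A minimal sketch, assuming a hypothetical self-hosted model; the `local-llama-70b` ID, context size, and zero prices are illustrative placeholders, not real metadata:

```python
# Hypothetical example: register a self-hosted model at runtime.
# The ID, context size, and prices below are illustrative placeholders.
registry.register(ModelConfig(
    id="local-llama-70b",
    name="Local Llama 70B",
    provider=ModelProvider.LOCAL,
    capabilities=ModelCapabilities(
        max_context_tokens=32_000,
        max_output_tokens=4_096,
        supports_streaming=True
    ),
    pricing=ModelPricing(
        input_price_per_mtok=0.0,   # self-hosted: no per-token provider fee
        output_price_per_mtok=0.0
    ),
    recommended_use_cases=["offline", "privacy-sensitive"],
    quality_tier="balanced"
))
```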

## Model Router

```python
from typing import Callable


class ModelRouter:
    """Route requests to appropriate models based on task."""

    def __init__(self, registry: ModelRegistry):
        self.registry = registry
        self.routing_rules = []

    def add_rule(
        self,
        name: str,
        condition: Callable[[str], bool],
        model_id: str,
        priority: int = 0
    ):
        """
        Add routing rule.

        Args:
            name: Rule name
            condition: Function that checks if rule applies
            model_id: Model to route to
            priority: Higher-priority rules are checked first
        """
        self.routing_rules.append({
            "name": name,
            "condition": condition,
            "model_id": model_id,
            "priority": priority
        })

        # Keep rules sorted so higher-priority rules are evaluated first
        self.routing_rules.sort(key=lambda r: r["priority"], reverse=True)

    def route(self, prompt: str) -> str:
        """
        Determine which model to use for a prompt.

        Args:
            prompt: Input prompt

        Returns:
            Model ID to use
        """
        for rule in self.routing_rules:
            if rule["condition"](prompt):
                return rule["model_id"]

        # Default fallback
        return "claude-sonnet-4-20250514"


# Example routing rules
router = ModelRouter(registry)

# Route simple tasks to a fast model
router.add_rule(
    name="simple_tasks",
    condition=lambda p: len(p) < 100 and "?" in p,
    model_id="claude-haiku-3-5-20250514",
    priority=1
)

# Route code tasks to Sonnet
router.add_rule(
    name="code_tasks",
    condition=lambda p: any(kw in p.lower() for kw in ["code", "function", "class"]),
    model_id="claude-sonnet-4-20250514",
    priority=2
)

# Route long context to Claude
router.add_rule(
    name="long_context",
    condition=lambda p: len(p) > 50_000,
    model_id="claude-sonnet-4-20250514",
    priority=3
)

# Use router
prompt = "Write a Python function to sort a list"
model_id = router.route(prompt)
print(f"Using model: {model_id}")
```
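
The `long_context` rule above keys on character count, which only loosely tracks tokens. A sketch of a token-aware condition, assuming a rough four-characters-per-token heuristic (swap in the provider's real tokenizer, e.g. `tiktoken` for OpenAI models, if exact counts matter):

```python
def estimate_tokens(text: str) -> int:
    # Crude heuristic: roughly 4 characters per token for English text.
    # For exact counts, use the provider's tokenizer instead.
    return len(text) // 4


router.add_rule(
    name="long_context_tokens",
    condition=lambda p: estimate_tokens(p) > 12_000,
    model_id="claude-sonnet-4-20250514",
    priority=4
)
```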

## Fallback Chain

```python
import logging
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


class FallbackChain:
    """Implement fallback chain for reliability."""

    def __init__(
        self,
        primary_model: str,
        fallback_models: List[str],
        registry: ModelRegistry
    ):
        self.primary_model = primary_model
        self.fallback_models = fallback_models
        self.registry = registry

    async def complete_with_fallback(
        self,
        prompt: str,
        clients: Dict[str, Any],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Try primary model, fall back on failure.

        Args:
            prompt: Input prompt
            clients: Dict mapping provider name to client
            **kwargs: Additional model parameters

        Returns:
            Dict with response and metadata
        """
        models_to_try = [self.primary_model] + self.fallback_models

        last_error = None

        for model_id in models_to_try:
            model_config = self.registry.get(model_id)
            if not model_config:
                logger.warning(f"Model {model_id} not in registry")
                continue

            try:
                logger.info(f"Attempting request with {model_id}")

                # Get client for provider
                client = clients.get(model_config.provider.value)
                if not client:
                    logger.warning(f"No client for {model_config.provider}")
                    continue

                # Make request
                response = await client.complete(
                    prompt,
                    model=model_id,
                    **kwargs
                )

                return {
                    "response": response,
                    "model_used": model_id,
                    "fallback_occurred": model_id != self.primary_model
                }

            except Exception as e:
                logger.error(f"Request failed for {model_id}: {e}")
                last_error = e
                continue

        # All models failed
        raise Exception(f"All models failed. Last error: {last_error}")


# Usage
fallback_chain = FallbackChain(
    primary_model="claude-sonnet-4-20250514",
    fallback_models=[
        "gpt-4-turbo",
        "claude-haiku-3-5-20250514"
    ],
    registry=registry
)

result = await fallback_chain.complete_with_fallback(
    prompt="What is Python?",
    clients={
        "anthropic": anthropic_client,
        "openai": openai_client
    }
)

print(f"Response from: {result['model_used']}")
```
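
Transient errors such as rate limits and timeouts often clear within seconds, so retrying the same model before switching providers can be cheaper than falling over immediately. A minimal sketch of per-model retries with exponential backoff; `complete_with_retry` is a hypothetical helper assuming the same `client.complete` interface used above:

```python
import asyncio


async def complete_with_retry(client, prompt: str, model_id: str,
                              max_retries: int = 2, **kwargs):
    """Retry a single model with exponential backoff before giving up."""
    for attempt in range(max_retries + 1):
        try:
            return await client.complete(prompt, model=model_id, **kwargs)
        except Exception:
            if attempt == max_retries:
                raise  # let the fallback chain move on to the next model
            await asyncio.sleep(2 ** attempt)  # 1s, 2s, ...
```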

## Cost Optimization

```python
from typing import Dict, List, Optional


class CostOptimizer:
    """Optimize costs by selecting appropriate models."""

    def __init__(self, registry: ModelRegistry):
        self.registry = registry

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """
        Estimate cost for a request.

        Args:
            model_id: Model identifier
            input_tokens: Input token count
            output_tokens: Expected output tokens

        Returns:
            Estimated cost in USD
        """
        model = self.registry.get(model_id)
        if not model:
            raise ValueError(f"Unknown model: {model_id}")

        input_cost = (
            input_tokens / 1_000_000
        ) * model.pricing.input_price_per_mtok

        output_cost = (
            output_tokens / 1_000_000
        ) * model.pricing.output_price_per_mtok

        return input_cost + output_cost

    def find_cheapest_model(
        self,
        input_tokens: int,
        output_tokens: int,
        quality_tier: Optional[str] = None,
        **criteria
    ) -> tuple[str, float]:
        """
        Find the cheapest model meeting criteria.

        Args:
            input_tokens: Input token count
            output_tokens: Expected output tokens
            quality_tier: Optional quality requirement
            **criteria: Additional model criteria

        Returns:
            Tuple of (model_id, estimated_cost)
        """
        # Find matching models
        candidates = self.registry.find_by_criteria(
            quality_tier=quality_tier,
            **criteria
        )

        if not candidates:
            raise ValueError("No models match criteria")

        # Calculate costs
        costs = [
            (
                model.id,
                self.estimate_cost(model.id, input_tokens, output_tokens)
            )
            for model in candidates
        ]

        # Return cheapest
        return min(costs, key=lambda x: x[1])

    def batch_cost_analysis(
        self,
        requests: List[Dict[str, int]],
        model_ids: List[str]
    ) -> Dict[str, Dict[str, float]]:
        """
        Analyze costs for a batch of requests across models.

        Args:
            requests: List of dicts with input_tokens, output_tokens
            model_ids: Models to compare

        Returns:
            Cost breakdown per model
        """
        analysis = {}

        for model_id in model_ids:
            total_cost = 0.0
            total_tokens = 0

            for req in requests:
                cost = self.estimate_cost(
                    model_id,
                    req["input_tokens"],
                    req["output_tokens"]
                )
                total_cost += cost
                total_tokens += req["input_tokens"] + req["output_tokens"]

            analysis[model_id] = {
                "total_cost_usd": total_cost,
                "avg_cost_per_request": total_cost / len(requests),
                "total_tokens": total_tokens,
                "cost_per_1k_tokens": (total_cost / total_tokens) * 1000
            }

        return analysis


# Usage
optimizer = CostOptimizer(registry)

# Find cheapest model for task
model_id, cost = optimizer.find_cheapest_model(
    input_tokens=1000,
    output_tokens=500,
    quality_tier="fast"  # note: no "balanced" model exists in the default registry
)

print(f"Cheapest model: {model_id} (${cost:.4f})")

# Compare costs across models
requests = [
    {"input_tokens": 1000, "output_tokens": 500},
    {"input_tokens": 2000, "output_tokens": 1000},
]

cost_analysis = optimizer.batch_cost_analysis(
    requests,
    model_ids=[
        "claude-sonnet-4-20250514",
        "claude-haiku-3-5-20250514",
        "gpt-4-turbo"
    ]
)

for model_id, stats in cost_analysis.items():
    print(f"{model_id}: ${stats['total_cost_usd']:.4f}")
```
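
To sanity-check `estimate_cost`, take Claude Sonnet 4's registry pricing ($3.00/MTok input, $15.00/MTok output): a request with 1,000 input tokens and 500 output tokens costs (1,000 / 1,000,000) × $3.00 + (500 / 1,000,000) × $15.00 = $0.003 + $0.0075 = $0.0105. The same request on Claude Haiku 3.5 ($0.80/$4.00) comes to $0.0028, about a quarter of the cost, which is why routing simple tasks to cheaper tiers pays off at scale.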

## Multi-Model Ensemble

```python
import asyncio
from collections import Counter
from typing import Any, Dict, List


class ModelEnsemble:
    """Ensemble multiple models for improved results."""

    def __init__(
        self,
        models: List[str],
        voting_strategy: str = "majority"
    ):
        self.models = models
        self.voting_strategy = voting_strategy

    async def complete_ensemble(
        self,
        prompt: str,
        clients: Dict[str, Any],
        registry: ModelRegistry
    ) -> str:
        """
        Get predictions from multiple models and combine them.

        Args:
            prompt: Input prompt
            clients: Provider clients
            registry: Model registry

        Returns:
            Combined prediction
        """
        # Get predictions from all models concurrently
        tasks = []
        for model_id in self.models:
            model_config = registry.get(model_id)
            if not model_config:
                continue  # skip models missing from the registry
            client = clients.get(model_config.provider.value)
            if not client:
                continue  # skip providers without a configured client

            task = client.complete(prompt, model=model_id)
            tasks.append(task)

        predictions = await asyncio.gather(*tasks)

        # Combine predictions
        if self.voting_strategy == "majority":
            return self._majority_vote(predictions)
        elif self.voting_strategy == "longest":
            return max(predictions, key=len)
        elif self.voting_strategy == "first":
            return predictions[0]
        else:
            raise ValueError(f"Unknown strategy: {self.voting_strategy}")

    def _majority_vote(self, predictions: List[str]) -> str:
        """Select the most common prediction."""
        counter = Counter(predictions)
        return counter.most_common(1)[0][0]


# Usage
ensemble = ModelEnsemble(
    models=[
        "claude-sonnet-4-20250514",
        "gpt-4-turbo",
        "claude-opus-4-20250514"
    ],
    voting_strategy="majority"
)

result = await ensemble.complete_ensemble(
    prompt="Is Python case-sensitive? Answer yes or no.",
    clients=clients,
    registry=registry
)
```
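
Exact-string majority voting only works when answers are short and constrained, as in the yes/no prompt above; free-form outputs almost never match verbatim. A sketch of a normalization step that makes the vote more forgiving (an assumption layered on top of the class above, not part of it):

```python
import string
from collections import Counter
from typing import List


def normalize(answer: str) -> str:
    """Lowercase, trim, and strip punctuation so near-identical answers vote together."""
    return answer.strip().lower().translate(str.maketrans("", "", string.punctuation))


def normalized_majority_vote(predictions: List[str]) -> str:
    counter = Counter(normalize(p) for p in predictions)
    winner = counter.most_common(1)[0][0]
    # Return the first original prediction that normalizes to the winner
    return next(p for p in predictions if normalize(p) == winner)
```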

## ❌ Anti-Patterns

```python
# ❌ Hardcoded model ID everywhere
response = client.complete("prompt", model="claude-sonnet-4-20250514")

# ✅ Better: Use model registry and routing
model_id = router.route(prompt)
response = client.complete(prompt, model=model_id)


# ❌ No fallback on failure
try:
    response = await primary_client.complete(prompt)
except:
    raise  # App crashes!

# ✅ Better: Implement fallback chain
result = await fallback_chain.complete_with_fallback(prompt, clients)


# ❌ Always using most expensive model
model = "claude-opus-4-20250514"  # Always flagship!

# ✅ Better: Route based on task complexity
model_id = router.route(prompt)  # Uses appropriate model


# ❌ No cost tracking
response = await client.complete(prompt)  # No idea of cost!

# ✅ Better: Track and optimize costs
cost = optimizer.estimate_cost(model_id, input_tokens, output_tokens)
logger.info(f"Request cost: ${cost:.4f}")
```

## Best Practices Checklist

- ✅ Use model registry to centralize model metadata
- ✅ Implement routing to select appropriate models
- ✅ Set up fallback chains for reliability
- ✅ Track and optimize costs per model
- ✅ Consider quality tier for each use case
- ✅ Monitor model performance metrics
- ✅ Use cheaper models for simple tasks
- ✅ Cache model responses when appropriate (see the sketch after this list)
- ✅ Test fallback chains regularly
- ✅ Document model selection rationale
- ✅ Review costs regularly and optimize
- ✅ Keep model metadata up to date
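
Caching is the one checklist item without code elsewhere in this skill, so here is a minimal in-memory sketch keyed on model ID plus a prompt hash. It assumes deterministic-enough settings (e.g., temperature 0); TTLs and persistent stores are out of scope:

```python
import hashlib
from typing import Any, Dict, Optional


class ResponseCache:
    """Minimal in-memory response cache keyed on (model_id, prompt hash)."""

    def __init__(self):
        self._cache: Dict[str, Any] = {}

    def _key(self, model_id: str, prompt: str) -> str:
        digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        return f"{model_id}:{digest}"

    def get(self, model_id: str, prompt: str) -> Optional[Any]:
        return self._cache.get(self._key(model_id, prompt))

    def put(self, model_id: str, prompt: str, response: Any) -> None:
        self._cache[self._key(model_id, prompt)] = response
```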

## Auto-Apply

When selecting models:

1. Register models in ModelRegistry with capabilities and pricing
2. Implement ModelRouter for task-based routing
3. Set up FallbackChain for reliability
4. Use CostOptimizer to find cost-effective options
5. Track costs per model and request
6. Document routing logic and fallback strategy
7. Monitor model performance and costs
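
Taken together, steps 1 through 4 compose directly from the classes defined above. A sketch, reusing the `registry`, `router`, `optimizer`, `logger`, and hypothetical `clients` objects from earlier sections; the 500-token output estimate is an illustrative assumption:

```python
# Sketch wiring registry, router, cost estimate, and fallback together.
# `clients` is assumed to expose the same `complete` interface used above.
async def answer(prompt: str) -> str:
    model_id = router.route(prompt)  # step 2: route by task
    est = optimizer.estimate_cost(
        model_id,
        input_tokens=len(prompt) // 4,  # crude token heuristic from above
        output_tokens=500               # assumed output budget
    )
    logger.info(f"Routing to {model_id}, est. ${est:.4f}")  # step 5: track cost

    chain = FallbackChain(
        primary_model=model_id,
        fallback_models=["claude-haiku-3-5-20250514"],  # step 3: reliability
        registry=registry                               # step 1: metadata
    )
    result = await chain.complete_with_fallback(prompt, clients)
    return result["response"]
```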

## Related Skills

- `llm-app-architecture` - For LLM integration
- `evaluation-metrics` - For model comparison
- `monitoring-alerting` - For tracking performance
- `performance-profiling` - For optimization
- `prompting-patterns` - For model-specific prompts