""" LiteLLM client wrapper with token counting and error handling """ import os from pathlib import Path from typing import Any import requests from litellm import ( completion_cost, get_max_tokens, token_counter, validate_environment, ) from litellm.utils import get_model_info import config from response_strategy import ResponseStrategyFactory class LiteLLMClient: """Wrapper around LiteLLM with enhanced functionality""" def __init__(self, base_url: str | None = None, api_key: str | None = None) -> None: self.base_url = base_url self.api_key = api_key or config.get_api_key() # Configure litellm if self.api_key: # Set API key in environment for litellm to pick up if not os.environ.get("OPENAI_API_KEY"): os.environ["OPENAI_API_KEY"] = self.api_key def complete( self, model: str, prompt: str, session_dir: Path | None = None, reasoning_effort: str = "high", multimodal_content: list[dict[str, Any]] | None = None, **kwargs: Any, ) -> dict[str, Any]: """ Make a request using the responses API with automatic retry/background job handling. Uses strategy pattern to: - Use background jobs for OpenAI/Azure (resumable after network failures) - Use sync with retries for other providers Args: model: Model identifier prompt: Full prompt text session_dir: Optional session directory for state persistence (enables resumability) reasoning_effort: Reasoning effort level (low, medium, high) - default high multimodal_content: Optional multimodal content array for images **kwargs: Additional args passed to litellm.responses() Returns: Dict with 'content' and optional 'usage' """ # Add base_url if configured if self.base_url: kwargs["api_base"] = self.base_url # Add reasoning_effort parameter kwargs["reasoning_effort"] = reasoning_effort # Select appropriate strategy based on model strategy = ResponseStrategyFactory.get_strategy(model) if session_dir: api_type = ResponseStrategyFactory.get_api_type(model) print( f"Using {strategy.__class__.__name__} (resumable: {strategy.can_resume()})" ) print(f"API: {api_type} | Reasoning effort: {reasoning_effort}") try: # Execute with strategy-specific retry/background logic result: dict[str, Any] = strategy.execute( model=model, prompt=prompt, session_dir=session_dir, multimodal_content=multimodal_content, **kwargs, ) return result except Exception as e: # Map to standardized errors error_msg = str(e) if "context" in error_msg.lower() or "token" in error_msg.lower(): raise ValueError(f"Context limit exceeded: {error_msg}") from e elif "auth" in error_msg.lower() or "key" in error_msg.lower(): raise PermissionError(f"Authentication failed: {error_msg}") from e elif "not found" in error_msg.lower() or "404" in error_msg: raise ValueError(f"Model not found: {error_msg}") from e else: raise RuntimeError(f"LLM request failed: {error_msg}") from e def count_tokens(self, text: str, model: str) -> int: """ Count tokens for given text and model. When base_url is set (proxy mode), uses the proxy's /utils/token_counter endpoint for accurate tokenization of custom models. Otherwise uses local token_counter. 
""" # If using a proxy (base_url set), use the proxy's token counter endpoint if self.base_url: url = f"{self.base_url.rstrip('/')}/utils/token_counter" payload = {"model": model, "text": text} headers = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" response = requests.post(url, json=payload, headers=headers, timeout=30) response.raise_for_status() # Response typically has format: {"token_count": 123} result = response.json() token_count = result.get("token_count") or result.get("tokens") if token_count is None: raise RuntimeError( f"Proxy token counter returned invalid response: {result}" ) return int(token_count) # Use local token counter (direct API mode) return int(token_counter(model=model, text=text)) def get_max_tokens(self, model: str) -> int: """Get maximum context size for model""" try: return int(get_max_tokens(model)) except Exception as e: # Try get_model_info as alternative method try: info = get_model_info(model=model) max_tokens = info.get("max_tokens") if max_tokens is None: raise RuntimeError( f"Could not determine max_tokens for model {model}" ) return int(max_tokens) except Exception as inner_e: raise RuntimeError( f"Could not get max tokens for model {model}: {e}, {inner_e}" ) from inner_e def calculate_cost( self, model: str, response: Any = None, usage: dict[str, Any] | None = None, ) -> dict[str, Any] | None: """ Calculate cost using LiteLLM's built-in completion_cost() function. Args: model: Model identifier response: Optional response object from litellm.responses() usage: Optional usage dict (fallback if response not available) Returns: Dict with input_tokens, output_tokens, costs, or None if unavailable """ try: # Prefer using response object with built-in function if response: total_cost = completion_cost(completion_response=response) # Extract token counts from response.usage if available if hasattr(response, "usage"): usage = response.usage # Calculate from usage dict if provided if usage: input_tokens = usage.get("prompt_tokens") or usage.get( "input_tokens", 0 ) output_tokens = usage.get("completion_tokens") or usage.get( "output_tokens", 0 ) # Get per-token costs from model info info = get_model_info(model=model) input_cost_per_token = info.get("input_cost_per_token", 0) output_cost_per_token = info.get("output_cost_per_token", 0) input_cost = input_tokens * input_cost_per_token output_cost = output_tokens * output_cost_per_token # Use total_cost from completion_cost if available, else calculate if not response: total_cost = input_cost + output_cost return { "input_tokens": input_tokens, "output_tokens": output_tokens, "input_cost": input_cost, "output_cost": output_cost, "total_cost": total_cost, "currency": "USD", } return None except Exception: # If we can't get pricing info, return None return None def validate_environment(self, model: str) -> dict[str, Any]: """ Check if required environment variables are set for the model. Returns dict with 'keys_in_environment' (bool) and 'missing_keys' (list). """ try: result: dict[str, Any] = validate_environment(model=model) return result except Exception as e: # If validation fails, return a generic response return { "keys_in_environment": False, "missing_keys": ["API_KEY"], "error": str(e), } def test_connection(self, model: str) -> bool: """Test if we can connect to the model""" try: result = self.complete(model=model, prompt="Hello", max_tokens=5) return result.get("content") is not None except Exception: return False