Initial commit

Zhongwei Li
2025-11-29 18:23:41 +08:00
commit 016e36f3f3
20 changed files with 4365 additions and 0 deletions

__init__.py View File

@@ -0,0 +1,6 @@
"""
Consultant Python Implementation
LiteLLM-based tool for flexible multi-provider LLM consultations
"""
__version__ = "1.0.0"

config.py View File

@@ -0,0 +1,46 @@
"""
Configuration and constants for consultant Python implementation
"""
import os
from pathlib import Path
# Session storage location
DEFAULT_SESSIONS_DIR = Path.home() / ".consultant" / "sessions"
# Environment variable names
ENV_LITELLM_API_KEY = "LITELLM_API_KEY"
ENV_OPENAI_API_KEY = "OPENAI_API_KEY"
ENV_ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY"
ENV_OPENAI_BASE_URL = "OPENAI_BASE_URL"
# Token budget: Reserve this percentage for response
CONTEXT_RESERVE_RATIO = 0.2 # 20% reserved for response
# Retry configuration
MAX_RETRIES = 3
INITIAL_RETRY_DELAY = 2 # seconds
MAX_RETRY_DELAY = 60 # seconds
# Background job polling configuration
POLL_INTERVAL = 20 # seconds between polls (configurable)
POLL_TIMEOUT = 3600 # 1 hour max wait for background jobs
# Session polling
POLLING_INTERVAL_SECONDS = 2
def get_api_key() -> str | None:
"""Get API key from environment in priority order"""
return (
os.environ.get(ENV_LITELLM_API_KEY)
or os.environ.get(ENV_OPENAI_API_KEY)
or os.environ.get(ENV_ANTHROPIC_API_KEY)
)
def get_base_url() -> str | None:
"""Get base URL from OPENAI_BASE_URL environment variable if set"""
base_url = os.environ.get(ENV_OPENAI_BASE_URL)
# Only return if non-empty
return base_url if base_url and base_url.strip() else None
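For reference, a minimal sketch of how the helpers above resolve credentials and the base URL; the environment values are placeholders and not part of the commit:
import os
import config

# LITELLM_API_KEY takes priority over OPENAI_API_KEY and ANTHROPIC_API_KEY.
os.environ["OPENAI_API_KEY"] = "sk-openai-placeholder"
os.environ["LITELLM_API_KEY"] = "sk-litellm-placeholder"
assert config.get_api_key() == "sk-litellm-placeholder"

# A blank or whitespace-only OPENAI_BASE_URL is treated as unset.
os.environ["OPENAI_BASE_URL"] = "   "
assert config.get_base_url() is None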

consultant_cli.py View File

@@ -0,0 +1,501 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "litellm",
# "requests>=2.31.0",
# "tenacity",
# "markitdown>=0.1.0",
# ]
# ///
"""
Consultant CLI - LiteLLM-powered LLM consultation tool
Supports async invocation, custom base URLs, and flexible model selection
Run with: uv run consultant_cli.py [args]
This automatically installs/updates the declared dependencies (litellm, requests, tenacity, markitdown) on first run.
"""
import argparse
import json
import sys
from pathlib import Path
# Add scripts directory to path
SCRIPTS_DIR = Path(__file__).parent
sys.path.insert(0, str(SCRIPTS_DIR))
import config
from file_handler import (
FileHandler,
build_multimodal_content,
build_prompt_with_references,
has_images,
validate_vision_support,
)
from litellm_client import LiteLLMClient
from model_selector import ModelSelector
from session_manager import SessionManager
def validate_context_size(
full_prompt: str, model: str, client: LiteLLMClient, num_files: int
) -> bool:
"""
Validate that full prompt fits in model context.
Returns True if OK; raises ValueError if the input exceeds the model's context limit.
"""
# Count tokens for the complete prompt
total_tokens = client.count_tokens(full_prompt, model)
# Get limit
max_tokens = client.get_max_tokens(model)
# Reserve for response
available_tokens = int(max_tokens * (1 - config.CONTEXT_RESERVE_RATIO))
# Print summary
print("\n📊 Token Usage:")
print(f"- Input: {total_tokens:,} tokens ({num_files} files)")
print(f"- Limit: {max_tokens:,} tokens")
print(
f"- Available: {available_tokens:,} tokens ({int((available_tokens/max_tokens)*100)}%)\n"
)
if total_tokens > max_tokens:
raise ValueError(
f"Input exceeds context limit!\n"
f" Input: {total_tokens:,} tokens\n"
f" Limit: {max_tokens:,} tokens\n"
f" Overage: {total_tokens - max_tokens:,} tokens\n\n"
f"Suggestions:\n"
f"1. Reduce number of files (currently {num_files})\n"
f"2. Use a model with larger context\n"
f"3. Shorten the prompt"
)
if total_tokens > available_tokens:
print(f"⚠️ WARNING: Using {int((total_tokens/max_tokens)*100)}% of context")
print(" Consider reducing input size for better response quality\n")
return True
def handle_invocation(args: argparse.Namespace) -> int:
"""Handle main invocation command"""
# Determine base URL: --base-url flag > OPENAI_BASE_URL env var > None
base_url = args.base_url
if not base_url:
base_url = config.get_base_url()
if base_url:
print(f"Using base URL from OPENAI_BASE_URL: {base_url}")
# Initialize components
session_mgr = SessionManager()
client = LiteLLMClient(base_url=base_url, api_key=args.api_key)
# Process files using FileHandler
file_handler = FileHandler()
processed_files = []
multimodal_content = None
if args.files:
processed_files, file_errors = file_handler.process_files(args.files)
# If any files failed, report errors and exit
if file_errors:
print("\nERROR: Some files could not be processed:", file=sys.stderr)
for err in file_errors:
print(f" - {err.path}: {err.reason}", file=sys.stderr)
print(
"\nPlease fix or remove the problematic files and try again.",
file=sys.stderr,
)
return 1
# Validate vision support if images present
if has_images(processed_files):
validate_vision_support(args.model, has_images=True)
# Print file processing summary
text_count = sum(1 for f in processed_files if f.category.value == "text")
office_count = sum(1 for f in processed_files if f.category.value == "office")
image_count = sum(1 for f in processed_files if f.category.value == "image")
print("\nFile Processing Summary:")
print(f" - Text files: {text_count}")
print(f" - Office documents (converted): {office_count}")
print(f" - Images: {image_count}")
# Log model being used
print(f"Using model: {args.model}")
# Validate environment variables (only if no custom base URL)
if not base_url:
env_status = client.validate_environment(args.model)
if not env_status.get("keys_in_environment", False):
missing = env_status.get("missing_keys", [])
error = env_status.get("error", "")
print(
f"\n❌ ERROR: Missing required environment variables for model '{args.model}'",
file=sys.stderr,
)
print(f"\nMissing keys: {', '.join(missing)}", file=sys.stderr)
if error:
print(f"\nDetails: {error}", file=sys.stderr)
print("\n💡 To fix this:", file=sys.stderr)
print(" 1. Set the required environment variable(s):", file=sys.stderr)
for key in missing:
print(f" export {key}=your-api-key", file=sys.stderr)
print(
" 2. Or use --base-url to specify a custom LiteLLM endpoint",
file=sys.stderr,
)
print(
" 3. Or use --model to specify a different model\n", file=sys.stderr
)
return 1
# Build full prompt with reference files section
full_prompt = build_prompt_with_references(args.prompt, processed_files)
# Build multimodal content if we have images
if has_images(processed_files):
multimodal_content = build_multimodal_content(full_prompt, processed_files)
# Check context limits on the full prompt
try:
validate_context_size(full_prompt, args.model, client, len(processed_files))
except ValueError as e:
print(f"ERROR: {e}", file=sys.stderr)
return 1
# Create and start session
session_id = session_mgr.create_session(
slug=args.slug,
prompt=full_prompt,
model=args.model,
base_url=base_url,
api_key=args.api_key,
reasoning_effort=args.reasoning_effort,
multimodal_content=multimodal_content,
)
print(f"Session created: {session_id}")
print(f"Reattach via: python3 {__file__} session {args.slug}")
print("Waiting for completion...")
try:
result = session_mgr.wait_for_completion(session_id)
if result.get("status") == "completed":
print("\n" + "=" * 80)
print("RESPONSE:")
print("=" * 80)
print(result.get("output", "No output available"))
print("=" * 80)
# Print metadata section (model, reasoning effort, tokens, cost)
print("\n" + "=" * 80)
print("METADATA:")
print("=" * 80)
# Model info
print(f"model: {result.get('model', args.model)}")
print(
f"reasoning_effort: {result.get('reasoning_effort', args.reasoning_effort)}"
)
# Token usage and cost
usage = result.get("usage")
cost_info = result.get("cost_info")
if cost_info:
print(f"input_tokens: {cost_info.get('input_tokens', 0)}")
print(f"output_tokens: {cost_info.get('output_tokens', 0)}")
print(
f"total_tokens: {cost_info.get('input_tokens', 0) + cost_info.get('output_tokens', 0)}"
)
print(f"input_cost_usd: {cost_info.get('input_cost', 0):.6f}")
print(f"output_cost_usd: {cost_info.get('output_cost', 0):.6f}")
print(f"total_cost_usd: {cost_info.get('total_cost', 0):.6f}")
elif usage:
input_tokens = usage.get("prompt_tokens") or usage.get(
"input_tokens", 0
)
output_tokens = usage.get("completion_tokens") or usage.get(
"output_tokens", 0
)
print(f"input_tokens: {input_tokens}")
print(f"output_tokens: {output_tokens}")
print(f"total_tokens: {input_tokens + output_tokens}")
print("=" * 80)
return 0
else:
print(f"\nSession ended with status: {result.get('status')}")
if "error" in result:
print(f"Error: {result['error']}")
return 1
except TimeoutError as e:
print(f"\nERROR: {e}", file=sys.stderr)
return 1
def handle_session_status(args: argparse.Namespace) -> int:
"""Handle session status check"""
session_mgr = SessionManager()
status = session_mgr.get_session_status(args.slug)
if "error" in status and "No session found" in status["error"]:
print(f"ERROR: {status['error']}", file=sys.stderr)
return 1
# Pretty print status
print(json.dumps(status, indent=2))
return 0
def handle_list_sessions(args: argparse.Namespace) -> int:
"""Handle list sessions command"""
session_mgr = SessionManager()
sessions = session_mgr.list_sessions()
if not sessions:
print("No sessions found.")
return 0
print(f"\nFound {len(sessions)} session(s):\n")
for s in sessions:
status_icon = {
"running": "🔄",
"completed": "",
"error": "",
"calling_llm": "📞",
}.get(s.get("status", ""), "")
print(
f"{status_icon} {s.get('slug', 'unknown')} - {s.get('status', 'unknown')}"
)
print(f" Created: {s.get('created_at', 'unknown')}")
print(f" Model: {s.get('model', 'unknown')}")
if s.get("error"):
print(f" Error: {s['error'][:100]}...")
print()
return 0
def handle_list_models(args: argparse.Namespace) -> int:
"""Handle list models command"""
# Determine base URL: --base-url flag > OPENAI_BASE_URL env var > None
base_url = args.base_url
if not base_url:
base_url = config.get_base_url()
if base_url:
print(f"Using base URL from OPENAI_BASE_URL: {base_url}")
LiteLLMClient(base_url=base_url)
models = ModelSelector.list_models(base_url)
print(json.dumps(models, indent=2))
return 0
def main() -> int:
parser = argparse.ArgumentParser(
description="""
Consultant CLI - LiteLLM-powered LLM consultation tool
This CLI tool lets you consult powerful LLMs for code analysis, code reviews,
architectural decisions, and complex technical questions. It supports
100+ LLM providers via LiteLLM, with optional custom base URLs.
CORE WORKFLOW:
1. Provide a prompt describing your analysis task
2. Attach relevant files for context
3. The CLI sends everything to the LLM and waits for completion
4. Results are printed with full metadata (model, tokens, cost)
OUTPUT FORMAT:
The CLI prints structured output with clear sections:
- RESPONSE: The LLM's analysis/response
- METADATA: Model used, reasoning effort, token counts, costs
ENVIRONMENT VARIABLES:
LITELLM_API_KEY Primary API key (checked first)
OPENAI_API_KEY OpenAI API key (fallback)
ANTHROPIC_API_KEY Anthropic API key (fallback)
OPENAI_BASE_URL Default base URL for custom LiteLLM proxy
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
EXAMPLES:
Basic consultation with prompt and files:
%(prog)s -p "Review this code for bugs" -f src/main.py -s code-review
Multiple files:
%(prog)s -p "Analyze architecture" -f src/api.py -f src/db.py -f src/models.py -s arch-review
Specify model explicitly:
%(prog)s -p "Security audit" -f auth.py -s security -m claude-3-5-sonnet-20241022
Use custom LiteLLM proxy:
%(prog)s -p "Code review" -f app.py -s review --base-url http://localhost:8000
Lower reasoning effort (faster, cheaper):
%(prog)s -p "Quick check" -f code.py -s quick --reasoning-effort low
Check session status:
%(prog)s session my-review
List all sessions:
%(prog)s list
List available models from proxy:
%(prog)s models --base-url http://localhost:8000
SUBCOMMANDS:
session <slug> Check status of a session by its slug
list List all sessions with their status
models List available models (from proxy or known models)
For more information, see the consultant plugin documentation.
""",
)
# Subcommands
subparsers = parser.add_subparsers(dest="command", help="Available subcommands")
# Main invocation arguments
parser.add_argument(
"-p",
"--prompt",
metavar="TEXT",
help="""The analysis prompt to send to the LLM. This should describe
what you want the model to analyze or review. The prompt will
be combined with any attached files to form the full request.
REQUIRED for main invocation.""",
)
parser.add_argument(
"-f",
"--file",
action="append",
dest="files",
metavar="PATH",
help="""File to attach for analysis. Can be specified multiple times
to attach multiple files. Each file's contents will be included
in the prompt sent to the LLM. Supports any text file format.
Example: -f src/main.py -f src/utils.py -f README.md""",
)
parser.add_argument(
"-s",
"--slug",
metavar="NAME",
help="""Unique identifier for this session. Used to track and retrieve
session results. Should be descriptive (e.g., "pr-review-123",
"security-audit", "arch-analysis"). REQUIRED for main invocation.""",
)
parser.add_argument(
"-m",
"--model",
metavar="MODEL_ID",
default="gpt-5-pro",
help="""Specific LLM model to use. Default: gpt-5-pro. Examples:
"gpt-5.1", "claude-sonnet-4-5", "gemini/gemini-2.5-flash".
Use the "models" subcommand to see available models.""",
)
parser.add_argument(
"--base-url",
metavar="URL",
help="""Custom base URL for LiteLLM proxy server (e.g., "http://localhost:8000").
When set, all API calls go through this proxy. The proxy's /v1/models
endpoint will be queried for available models. If not set, uses
direct provider APIs based on the model prefix.""",
)
parser.add_argument(
"--api-key",
metavar="KEY",
help="""API key for the LLM provider. If not provided, the CLI will look
for keys in environment variables: LITELLM_API_KEY, OPENAI_API_KEY,
or ANTHROPIC_API_KEY (in that order).""",
)
parser.add_argument(
"--reasoning-effort",
choices=["low", "medium", "high"],
default="high",
metavar="LEVEL",
help="""Reasoning effort level for the LLM. Higher effort = more thorough
analysis but slower and more expensive. Choices: low, medium, high.
Default: high. Use "low" for quick checks, "high" for thorough reviews.""",
)
# Session status subcommand
session_parser = subparsers.add_parser(
"session",
help="Check the status of a session",
description="""Check the current status of a consultation session.
Returns JSON with session metadata, status, and output if completed.""",
)
session_parser.add_argument(
"slug", help="Session slug/identifier to check (the value passed to -s/--slug)"
)
# List sessions subcommand
subparsers.add_parser(
"list",
help="List all consultation sessions",
description="""List all consultation sessions with their status.
Shows session slug, status, creation time, model used, and any errors.""",
)
# List models subcommand
models_parser = subparsers.add_parser(
"models",
help="List available LLM models",
description="""List available LLM models. If --base-url is provided, queries
the proxy's /v1/models endpoint. Otherwise, returns known models
from LiteLLM's model registry.""",
)
models_parser.add_argument(
"--base-url",
metavar="URL",
help="Base URL of LiteLLM proxy to query for available models",
)
args = parser.parse_args()
# Handle commands
if args.command == "session":
return handle_session_status(args)
elif args.command == "list":
return handle_list_sessions(args)
elif args.command == "models":
return handle_list_models(args)
else:
# Main invocation
if not args.prompt or not args.slug:
parser.print_help()
print("\nERROR: --prompt and --slug are required", file=sys.stderr)
return 1
return handle_invocation(args)
if __name__ == "__main__":
sys.exit(main())
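To make the token-budget check in validate_context_size concrete, a small sketch with a stub client; the stub and its numbers are invented for illustration and assume consultant_cli is importable with its dependencies installed:
from consultant_cli import validate_context_size

class StubClient:
    """Hypothetical stand-in for LiteLLMClient, used only to show the arithmetic."""
    def count_tokens(self, text: str, model: str) -> int:
        return 90_000            # pretend the full prompt is 90k tokens
    def get_max_tokens(self, model: str) -> int:
        return 128_000           # pretend the model has a 128k context window

# With CONTEXT_RESERVE_RATIO = 0.2, available = 128_000 * 0.8 = 102_400 tokens.
# 90_000 <= 128_000 so no ValueError, and 90_000 <= 102_400 so no warning either.
validate_context_size("placeholder prompt", "placeholder-model", StubClient(), num_files=3)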

file_handler.py View File

@@ -0,0 +1,323 @@
"""
File handling for consultant CLI.
Categorizes and processes files: images, office documents, and text files.
"""
import base64
import mimetypes
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from markitdown import MarkItDown
class FileCategory(Enum):
"""Categories of files the CLI can handle"""
IMAGE = "image"
OFFICE = "office"
TEXT = "text"
@dataclass
class ProcessedFile:
"""Result of successfully processing a file"""
path: str
category: FileCategory
content: str = "" # For text/office: the text content
base64_data: str = "" # For images: base64 encoded data
mime_type: str = "" # For images: the MIME type
@dataclass
class FileError:
"""Error details for a file that failed processing"""
path: str
reason: str
# File extension constants
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
OFFICE_EXTENSIONS = frozenset({".xls", ".xlsx", ".docx", ".pptx"})
# Size limits
MAX_IMAGE_SIZE_BYTES = 20 * 1024 * 1024 # 20MB
class FileHandler:
"""Main file processing coordinator"""
def __init__(self) -> None:
self._markitdown = MarkItDown()
def process_files(
self, file_paths: list[str]
) -> tuple[list[ProcessedFile], list[FileError]]:
"""
Process a list of file paths and return categorized results.
Returns:
Tuple of (successfully processed files, errors)
"""
processed: list[ProcessedFile] = []
errors: list[FileError] = []
for file_path in file_paths:
path = Path(file_path)
# Validate file exists
if not path.exists():
errors.append(FileError(path=str(path), reason="File not found"))
continue
if not path.is_file():
errors.append(FileError(path=str(path), reason="Not a file"))
continue
# Categorize and process
category = self._categorize(path)
if category == FileCategory.IMAGE:
result = self._process_image(path)
elif category == FileCategory.OFFICE:
result = self._process_office(path)
else: # FileCategory.TEXT
result = self._process_text(path)
if isinstance(result, FileError):
errors.append(result)
else:
processed.append(result)
return processed, errors
def _categorize(self, path: Path) -> FileCategory:
"""Determine the category of a file based on extension"""
suffix = path.suffix.lower()
if suffix in IMAGE_EXTENSIONS:
return FileCategory.IMAGE
if suffix in OFFICE_EXTENSIONS:
return FileCategory.OFFICE
# Default: assume text, will validate during processing
return FileCategory.TEXT
def _process_image(self, path: Path) -> ProcessedFile | FileError:
"""Process an image file: validate size and encode to base64"""
try:
# Read binary content
data = path.read_bytes()
# Check size limit
if len(data) > MAX_IMAGE_SIZE_BYTES:
size_mb = len(data) / (1024 * 1024)
max_mb = MAX_IMAGE_SIZE_BYTES / (1024 * 1024)
return FileError(
path=str(path),
reason=f"Image too large: {size_mb:.1f}MB (max {max_mb:.0f}MB)",
)
# Encode to base64
base64_data = base64.b64encode(data).decode("utf-8")
# Determine MIME type
mime_type, _ = mimetypes.guess_type(str(path))
if not mime_type:
# Fallback based on extension
ext = path.suffix.lower()
mime_map = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
}
mime_type = mime_map.get(ext, "application/octet-stream")
return ProcessedFile(
path=str(path),
category=FileCategory.IMAGE,
base64_data=base64_data,
mime_type=mime_type,
)
except Exception as e:
return FileError(path=str(path), reason=f"Failed to process image: {e}")
def _process_office(self, path: Path) -> ProcessedFile | FileError:
"""Process an office document using markitdown"""
try:
result = self._markitdown.convert(str(path))
content = result.text_content
if not content or not content.strip():
return FileError(
path=str(path), reason="markitdown returned empty content"
)
return ProcessedFile(
path=str(path),
category=FileCategory.OFFICE,
content=content,
)
except Exception as e:
return FileError(
path=str(path), reason=f"markitdown conversion failed: {e}"
)
def _process_text(self, path: Path) -> ProcessedFile | FileError:
"""Process a text file: attempt UTF-8 decode"""
try:
content = path.read_text(encoding="utf-8")
# Check for empty or whitespace-only files
if not content or not content.strip():
return FileError(
path=str(path),
reason="File is empty or contains only whitespace",
)
return ProcessedFile(
path=str(path),
category=FileCategory.TEXT,
content=content,
)
except UnicodeDecodeError as e:
return FileError(
path=str(path),
reason=f"Not a valid UTF-8 text file: {e}",
)
except Exception as e:
return FileError(path=str(path), reason=f"Failed to read file: {e}")
def validate_vision_support(model: str, has_images: bool) -> None:
"""
Validate that the model supports vision if images are present.
Exits with code 2 if validation fails.
"""
if not has_images:
return
from litellm import supports_vision
if not supports_vision(model=model):
print(
f"\nERROR: Model '{model}' does not support vision/images.\n",
file=sys.stderr,
)
print(
"Image files were provided but the selected model cannot process them.",
file=sys.stderr,
)
print("\nSuggestions:", file=sys.stderr)
print(" 1. Use a vision-capable model:", file=sys.stderr)
print(" - gpt-5.1, gpt-5-vision (OpenAI)", file=sys.stderr)
print(
" - claude-sonnet-4-5, claude-opus-4 (Anthropic)",
file=sys.stderr,
)
print(
" - gemini/gemini-2.5-flash, gemini/gemini-3-pro-preview (Google)", file=sys.stderr
)
print(" 2. Remove image files from the request", file=sys.stderr)
print(" 3. Convert images to text descriptions first\n", file=sys.stderr)
sys.exit(2)
def build_prompt_with_references(prompt: str, files: list[ProcessedFile]) -> str:
"""
Build the text portion of the prompt with Reference Files section.
Does NOT include images (those go in the multimodal array).
Args:
prompt: The user's original prompt
files: List of successfully processed files
Returns:
The full prompt with reference files section appended
"""
# Filter to text and office files only (images handled separately)
text_content_files = [
f for f in files if f.category in (FileCategory.TEXT, FileCategory.OFFICE)
]
# Also get image files for the note
image_files = [f for f in files if f.category == FileCategory.IMAGE]
if not text_content_files and not image_files:
return prompt
parts = [prompt]
# Add reference files section if there are text/office files
if text_content_files:
parts.append("\n\n" + "=" * 80)
parts.append("\n\n## Reference Files\n")
for file in text_content_files:
parts.append(f"\n### {file.path}\n")
parts.append(f"```\n{file.content}\n```\n")
# Add note about images if present
if image_files:
parts.append("\n\n" + "-" * 40)
parts.append(
f"\n*Note: {len(image_files)} image(s) attached for visual analysis.*\n"
)
for img in image_files:
parts.append(f"- {img.path}\n")
return "".join(parts)
def build_multimodal_content(
text_prompt: str, files: list[ProcessedFile]
) -> list[dict[str, Any]]:
"""
Build multimodal content array for LLM APIs.
Uses the standard OpenAI Chat Completions format which is widely supported.
Response strategies will convert to API-specific formats as needed.
Format:
- Text: {"type": "text", "text": "..."}
- Image: {"type": "image_url", "image_url": {"url": "data:...", "detail": "auto"}}
Args:
text_prompt: The text portion of the prompt (with reference files)
files: List of successfully processed files
Returns:
Multimodal content array
"""
content: list[dict[str, Any]] = []
# Text content
content.append({"type": "text", "text": text_prompt})
# Images with base64 data URLs
for f in files:
if f.category == FileCategory.IMAGE:
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{f.mime_type};base64,{f.base64_data}",
"detail": "auto",
},
}
)
return content
def has_images(files: list[ProcessedFile]) -> bool:
"""Check if any processed files are images"""
return any(f.category == FileCategory.IMAGE for f in files)
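A short usage sketch for the module above, showing how the pieces compose; the file paths are hypothetical:
from file_handler import (
    FileHandler,
    build_multimodal_content,
    build_prompt_with_references,
    has_images,
)

handler = FileHandler()
processed, errors = handler.process_files(["notes.md", "diagram.png"])  # hypothetical paths
for err in errors:
    print(f"Skipped {err.path}: {err.reason}")

# Text and office content is inlined into the prompt; images become base64 data URLs.
prompt = build_prompt_with_references("Summarize the attached material.", processed)
if has_images(processed):
    content = build_multimodal_content(prompt, processed)  # OpenAI-style content array
else:
    content = prompt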

litellm_client.py View File

@@ -0,0 +1,241 @@
"""
LiteLLM client wrapper with token counting and error handling
"""
import os
from pathlib import Path
from typing import Any
import requests
from litellm import (
completion_cost,
get_max_tokens,
token_counter,
validate_environment,
)
from litellm.utils import get_model_info
import config
from response_strategy import ResponseStrategyFactory
class LiteLLMClient:
"""Wrapper around LiteLLM with enhanced functionality"""
def __init__(self, base_url: str | None = None, api_key: str | None = None) -> None:
self.base_url = base_url
self.api_key = api_key or config.get_api_key()
# Configure litellm
if self.api_key:
# Set API key in environment for litellm to pick up
if not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = self.api_key
def complete(
self,
model: str,
prompt: str,
session_dir: Path | None = None,
reasoning_effort: str = "high",
multimodal_content: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
Make a request using the responses API with automatic retry/background job handling.
Uses strategy pattern to:
- Use background jobs for OpenAI/Azure (resumable after network failures)
- Use sync with retries for other providers
Args:
model: Model identifier
prompt: Full prompt text
session_dir: Optional session directory for state persistence (enables resumability)
reasoning_effort: Reasoning effort level (low, medium, high) - default high
multimodal_content: Optional multimodal content array for images
**kwargs: Additional args passed to litellm.responses()
Returns:
Dict with 'content' and optional 'usage'
"""
# Add base_url if configured
if self.base_url:
kwargs["api_base"] = self.base_url
# Add reasoning_effort parameter
kwargs["reasoning_effort"] = reasoning_effort
# Select appropriate strategy based on model
strategy = ResponseStrategyFactory.get_strategy(model)
if session_dir:
api_type = ResponseStrategyFactory.get_api_type(model)
print(
f"Using {strategy.__class__.__name__} (resumable: {strategy.can_resume()})"
)
print(f"API: {api_type} | Reasoning effort: {reasoning_effort}")
try:
# Execute with strategy-specific retry/background logic
result: dict[str, Any] = strategy.execute(
model=model,
prompt=prompt,
session_dir=session_dir,
multimodal_content=multimodal_content,
**kwargs,
)
return result
except Exception as e:
# Map to standardized errors
error_msg = str(e)
if "context" in error_msg.lower() or "token" in error_msg.lower():
raise ValueError(f"Context limit exceeded: {error_msg}") from e
elif "auth" in error_msg.lower() or "key" in error_msg.lower():
raise PermissionError(f"Authentication failed: {error_msg}") from e
elif "not found" in error_msg.lower() or "404" in error_msg:
raise ValueError(f"Model not found: {error_msg}") from e
else:
raise RuntimeError(f"LLM request failed: {error_msg}") from e
def count_tokens(self, text: str, model: str) -> int:
"""
Count tokens for given text and model.
When base_url is set (proxy mode), uses the proxy's /utils/token_counter endpoint
for accurate tokenization of custom models. Otherwise uses local token_counter.
"""
# If using a proxy (base_url set), use the proxy's token counter endpoint
if self.base_url:
url = f"{self.base_url.rstrip('/')}/utils/token_counter"
payload = {"model": model, "text": text}
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(url, json=payload, headers=headers, timeout=30)
response.raise_for_status()
# Response typically has format: {"token_count": 123}
result = response.json()
token_count = result.get("token_count") or result.get("tokens")
if token_count is None:
raise RuntimeError(
f"Proxy token counter returned invalid response: {result}"
)
return int(token_count)
# Use local token counter (direct API mode)
return int(token_counter(model=model, text=text))
def get_max_tokens(self, model: str) -> int:
"""Get maximum context size for model"""
try:
return int(get_max_tokens(model))
except Exception as e:
# Try get_model_info as alternative method
try:
info = get_model_info(model=model)
max_tokens = info.get("max_tokens")
if max_tokens is None:
raise RuntimeError(
f"Could not determine max_tokens for model {model}"
)
return int(max_tokens)
except Exception as inner_e:
raise RuntimeError(
f"Could not get max tokens for model {model}: {e}, {inner_e}"
) from inner_e
def calculate_cost(
self,
model: str,
response: Any = None,
usage: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
"""
Calculate cost using LiteLLM's built-in completion_cost() function.
Args:
model: Model identifier
response: Optional response object from litellm.responses()
usage: Optional usage dict (fallback if response not available)
Returns:
Dict with input_tokens, output_tokens, costs, or None if unavailable
"""
try:
# Prefer using response object with built-in function
if response:
total_cost = completion_cost(completion_response=response)
# Extract token counts from response.usage if available
if hasattr(response, "usage"):
usage = response.usage
# Calculate from usage dict if provided
if usage:
input_tokens = usage.get("prompt_tokens") or usage.get(
"input_tokens", 0
)
output_tokens = usage.get("completion_tokens") or usage.get(
"output_tokens", 0
)
# Get per-token costs from model info
info = get_model_info(model=model)
input_cost_per_token = info.get("input_cost_per_token", 0)
output_cost_per_token = info.get("output_cost_per_token", 0)
input_cost = input_tokens * input_cost_per_token
output_cost = output_tokens * output_cost_per_token
# Use total_cost from completion_cost if available, else calculate
if not response:
total_cost = input_cost + output_cost
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"input_cost": input_cost,
"output_cost": output_cost,
"total_cost": total_cost,
"currency": "USD",
}
return None
except Exception:
# If we can't get pricing info, return None
return None
def validate_environment(self, model: str) -> dict[str, Any]:
"""
Check if required environment variables are set for the model.
Returns dict with 'keys_in_environment' (bool) and 'missing_keys' (list).
"""
try:
result: dict[str, Any] = validate_environment(model=model)
return result
except Exception as e:
# If validation fails, return a generic response
return {
"keys_in_environment": False,
"missing_keys": ["API_KEY"],
"error": str(e),
}
def test_connection(self, model: str) -> bool:
"""Test if we can connect to the model"""
try:
result = self.complete(model=model, prompt="Hello", max_tokens=5)
return result.get("content") is not None
except Exception:
return False
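A minimal sketch of driving the client directly; the model name is a placeholder and a matching API key is assumed to be set in the environment:
from litellm_client import LiteLLMClient

client = LiteLLMClient()          # direct provider mode (no proxy)
model = "gpt-5-pro"               # placeholder model id
print(client.count_tokens("Hello there", model), "of", client.get_max_tokens(model), "tokens")

result = client.complete(
    model=model,
    prompt="Reply with the single word: ok",
    reasoning_effort="low",       # session_dir omitted, so no background-job resumption
)
print(result["content"])
print(client.calculate_cost(model, response=result.get("response"), usage=result.get("usage")))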

model_selector.py View File

@@ -0,0 +1,143 @@
"""
Model discovery and selection logic
"""
from typing import Any
import requests
from litellm import model_cost
class ModelSelector:
"""Handles model discovery and automatic selection"""
@staticmethod
def list_models(base_url: str | None = None) -> list[dict[str, Any]]:
"""
Query available models.
Without base_url: Uses LiteLLM's model_cost dictionary for dynamic discovery
With base_url: Calls proxy's /models or /v1/models endpoint
"""
if not base_url:
# Use LiteLLM's model_cost for dynamic discovery
return ModelSelector._get_litellm_models()
# Try LiteLLM proxy /models endpoint first, then OpenAI-compatible /v1/models
last_error = None
for endpoint in ["/models", "/v1/models"]:
try:
models_url = f"{base_url.rstrip('/')}{endpoint}"
response = requests.get(models_url, timeout=10)
response.raise_for_status()
data = response.json()
models = data.get("data", [])
return [
{
"id": m.get("id"),
"created": m.get("created"),
"owned_by": m.get("owned_by"),
}
for m in models
]
except Exception as e:
last_error = e
continue
# If all endpoints fail, raise an error
raise RuntimeError(f"Could not fetch models from {base_url}: {last_error}")
@staticmethod
def select_best_model(base_url: str | None = None) -> str:
"""
Automatically select the best available model.
Heuristic: Prefer models with "large", "pro", or higher version numbers
"""
models = ModelSelector.list_models(base_url)
if not models:
raise RuntimeError("No models available - cannot auto-select model")
# Score models based on name heuristics
best_model = max(models, key=ModelSelector._score_model)
model_id = best_model.get("id")
if not model_id:
raise RuntimeError("Best model has no id - cannot auto-select model")
return str(model_id)
@staticmethod
def _score_model(model: dict[str, Any]) -> float:
"""Score a model based on capabilities (higher is better)"""
model_id = model.get("id", "").lower()
score = 0.0
# Version number scoring
if "gpt-5" in model_id or "o1" in model_id or "o3" in model_id:
score += 50
elif "gpt-4" in model_id:
score += 40
elif "gpt-3.5" in model_id:
score += 30
# Capability indicators
if any(x in model_id for x in ["pro", "turbo", "large", "xl", "ultra"]):
score += 20
# Context size indicators
if "128k" in model_id or "200k" in model_id:
score += 15
elif "32k" in model_id:
score += 12
elif "16k" in model_id:
score += 10
# Anthropic models
if "claude" in model_id:
if "opus" in model_id:
score += 50
elif "sonnet" in model_id:
if "3.5" in model_id or "3-5" in model_id:
score += 48
else:
score += 45
elif "haiku" in model_id:
score += 35
# Google models
if "gemini" in model_id:
if "2.0" in model_id or "2-0" in model_id:
score += 45
elif "pro" in model_id:
score += 40
return score
@staticmethod
def _get_litellm_models() -> list[dict[str, Any]]:
"""
Get models from LiteLLM's model_cost dictionary.
This provides dynamic model discovery without hardcoded lists.
"""
if not model_cost:
raise RuntimeError("LiteLLM model_cost is empty - cannot discover models")
# Convert model_cost dictionary to list format
models = []
for model_id, info in model_cost.items():
models.append(
{
"id": model_id,
"provider": info.get("litellm_provider", "unknown"),
"max_tokens": info.get("max_tokens"),
"input_cost_per_token": info.get("input_cost_per_token"),
"output_cost_per_token": info.get("output_cost_per_token"),
}
)
return models
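A quick sketch of the selector above; the proxy URL in the comment is a placeholder:
from model_selector import ModelSelector

# Without a base URL: enumerate models known to LiteLLM's model_cost registry.
models = ModelSelector.list_models()
print(f"{len(models)} models known to LiteLLM")

# With a proxy, /models is tried first, then /v1/models, e.g.:
#   ModelSelector.list_models("http://localhost:8000")

# Heuristic auto-selection based on the name scoring above.
print(ModelSelector.select_best_model())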

response_strategy.py View File

@@ -0,0 +1,646 @@
"""
Response strategies for different LLM providers.
Handles retries, background jobs, and provider-specific quirks.
Automatically detects responses API vs completions API support.
"""
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import litellm
from litellm import _should_retry, completion, responses
import config
def _is_responses_api_model(model_name: str) -> bool:
"""
Check if a model name indicates responses API support.
Uses general patterns that will work for future model versions:
- GPT-4+ (gpt-4, gpt-5, gpt-6, etc.)
- O-series reasoning models (o1, o2, o3, o4, etc.)
- Codex models
- Computer-use models
Args:
model_name: Model name without provider prefix (lowercase)
Returns:
True if model should use responses API
"""
import re
# GPT-4 and above (gpt-4, gpt-5, gpt-6, etc. but not gpt-3.5)
# Matches: gpt-4, gpt4, gpt-4-turbo, gpt-5.1, gpt-6-turbo, etc.
gpt_match = re.search(r"gpt-?(\d+)", model_name)
if gpt_match:
version = int(gpt_match.group(1))
if version >= 4:
return True
# O-series reasoning models (o1, o2, o3, o4, etc.)
# Matches: o1, o1-pro, o3-mini, o4-preview, etc.
if re.search(r"\bo\d+\b", model_name) or re.search(r"\bo\d+-", model_name):
return True
# Codex models (use responses API)
if "codex" in model_name:
return True
# Computer-use models
return "computer-use" in model_name
def get_responses_api_models() -> set[str]:
"""
Determine which models support the native OpenAI Responses API.
Uses litellm.models_by_provider to get OpenAI models, then filters
to those that support the responses API.
Returns:
Set of model identifiers that support the responses API natively.
"""
responses_models: set[str] = set()
# Get OpenAI models from litellm
openai_models = litellm.models_by_provider.get("openai", [])
azure_models = litellm.models_by_provider.get("azure", [])
for model in openai_models + azure_models:
if _is_responses_api_model(model.lower()):
responses_models.add(model)
responses_models.add(f"openai/{model}")
responses_models.add(f"azure/{model}")
return responses_models
def supports_responses_api(model: str) -> bool:
"""
Check if a model supports the native OpenAI Responses API.
Uses general patterns that work for current and future models:
- GPT-4+ series (gpt-4, gpt-5, gpt-6, etc.)
- O-series reasoning models (o1, o2, o3, etc.)
- Codex models
- Computer-use models
Args:
model: Model identifier (e.g., "openai/gpt-4", "gpt-5-mini")
Returns:
True if model supports responses API natively, False otherwise.
"""
model_lower = model.lower()
# Extract model name and provider
if "/" in model_lower:
provider, model_name = model_lower.split("/", 1)
else:
provider = "openai" # Default provider for bare model names
model_name = model_lower
# Only OpenAI and Azure support the responses API natively
if provider not in ("openai", "azure"):
return False
# Use the generalized pattern matching
return _is_responses_api_model(model_name)
class ResponseStrategy(ABC):
"""Base class for response strategies"""
@abstractmethod
def execute(
self,
model: str,
prompt: str,
session_dir: Path | None = None,
multimodal_content: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
Execute LLM request with provider-specific strategy.
Returns dict with 'content' and optional 'usage'.
Args:
model: Model identifier
prompt: Text prompt
session_dir: Optional session directory for state persistence
multimodal_content: Optional multimodal content array for images
**kwargs: Additional provider-specific arguments
"""
raise NotImplementedError
@abstractmethod
def can_resume(self) -> bool:
"""Whether this strategy supports resuming after failure"""
raise NotImplementedError
def _calculate_backoff_delay(
self, attempt: int, base_delay: int, max_delay: int
) -> float:
"""Calculate exponential backoff delay with jitter"""
import random
delay = min(base_delay * (2**attempt), max_delay)
# Add 10% jitter to avoid thundering herd
jitter = delay * 0.1 * random.random()
return float(delay + jitter)
def _extract_content(self, response: Any) -> str:
"""
Extract text content from response.output structure.
Handles different output item types:
- ResponseOutputMessage (type='message'): has content with text
- ResponseReasoningItem (type='reasoning'): has summary, no content
"""
content = ""
if hasattr(response, "output") and response.output:
for item in response.output:
# Check item type - only 'message' type has content
item_type = getattr(item, "type", None)
if item_type == "message":
# ResponseOutputMessage: extract text from content
if hasattr(item, "content") and item.content:
for content_item in item.content:
if hasattr(content_item, "text"):
content += content_item.text
# Skip 'reasoning' items (ResponseReasoningItem) - they have summary, not content
return content
def _serialize_usage(self, usage: Any) -> dict[str, Any] | None:
"""
Safely convert usage object to a JSON-serializable dict.
Handles Pydantic models (OpenAI), dataclasses, and plain dicts.
"""
if usage is None:
return None
# Already a dict - return as-is
if isinstance(usage, dict):
return dict(usage)
# Pydantic v2 model
if hasattr(usage, "model_dump"):
result: dict[str, Any] = usage.model_dump()
return result
# Pydantic v1 model
if hasattr(usage, "dict"):
result = usage.dict()
return dict(result)
# Dataclass or object with __dict__
if hasattr(usage, "__dict__"):
return dict(usage.__dict__)
# Last resort - try to convert directly
try:
return dict(usage)
except (TypeError, ValueError):
# If all else fails, return None rather than crash
return None
class BackgroundJobStrategy(ResponseStrategy):
"""
For OpenAI/Azure - uses background jobs with response_id polling.
Supports resuming after network failures by persisting response_id.
"""
def _convert_to_responses_api_format(
self, multimodal_content: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Convert multimodal content from Completions API format to Responses API format.
Completions format: [{"type": "text/image_url", ...}]
Responses format: [{"type": "input_text/input_image", ...}]
"""
converted: list[dict[str, Any]] = []
for item in multimodal_content:
item_type = item.get("type", "")
if item_type == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item_type == "image_url":
# Extract URL from nested object
image_url = item.get("image_url", {})
url = image_url.get("url", "") if isinstance(image_url, dict) else ""
converted.append({"type": "input_image", "image_url": url})
return converted
def execute(
self,
model: str,
prompt: str,
session_dir: Path | None = None,
multimodal_content: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Execute with background job and polling"""
response_id_file = session_dir / "response_id.txt" if session_dir else None
# Check if we're resuming an existing background job
if response_id_file and response_id_file.exists():
response_id = response_id_file.read_text().strip()
print(f"Resuming background job: {response_id}")
return self._poll_for_completion(response_id)
# Build input - convert multimodal to Responses API format if provided
input_content: str | list[dict[str, Any]]
if multimodal_content:
input_content = self._convert_to_responses_api_format(multimodal_content)
else:
input_content = prompt
# Start new background job
try:
response = responses(
model=model,
input=input_content,
background=True, # Returns immediately with response_id
num_retries=config.MAX_RETRIES, # Use LiteLLM's built-in retries
**kwargs,
)
response_id = response.id
# Persist response_id for resumability
if response_id_file:
response_id_file.write_text(response_id)
print(f"Started background job: {response_id}")
# Poll until complete
return self._poll_for_completion(response_id)
except Exception as e:
# If background mode fails, maybe not supported - raise for fallback
raise RuntimeError(f"Background job failed to start: {e}") from e
def _poll_for_completion(self, response_id: str) -> dict[str, Any]:
"""Poll for completion with exponential backoff and retries"""
start_time = time.time()
attempt = 0
while time.time() - start_time < config.POLL_TIMEOUT:
try:
# Retrieve the response by ID
result = litellm.get_response(response_id=response_id)
if hasattr(result, "status"):
if result.status == "completed":
content = self._extract_content(result)
if not content:
raise RuntimeError("No content in completed response")
return {
"content": content,
"usage": self._serialize_usage(
getattr(result, "usage", None)
),
"response": result, # Include full response for cost calculation
}
elif result.status == "failed":
error = getattr(result, "error", "Unknown error")
raise RuntimeError(f"Background job failed: {error}")
elif result.status in ["in_progress", "queued"]:
# Still processing, wait and retry
time.sleep(config.POLL_INTERVAL)
attempt += 1
continue
else:
# Unknown status, wait and retry
time.sleep(config.POLL_INTERVAL)
continue
else:
# No status field - might be complete already
content = self._extract_content(result)
if content:
return {
"content": content,
"usage": self._serialize_usage(
getattr(result, "usage", None)
),
"response": result, # Include full response for cost calculation
}
# No content, wait and retry
time.sleep(config.POLL_INTERVAL)
continue
except Exception as e:
error_msg = str(e).lower()
# Network errors - retry with backoff
if any(x in error_msg for x in ["network", "timeout", "connection"]):
if attempt < config.MAX_RETRIES:
delay = self._calculate_backoff_delay(
attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
)
print(
f"Network error polling job, retrying in {delay:.1f}s... (attempt {attempt + 1}/{config.MAX_RETRIES})"
)
time.sleep(delay)
attempt += 1
continue
else:
raise RuntimeError(
f"Network errors exceeded max retries: {e}"
) from e
# Other errors - raise immediately
raise
raise TimeoutError(
f"Background job {response_id} did not complete within {config.POLL_TIMEOUT}s"
)
def can_resume(self) -> bool:
return True
class SyncRetryStrategy(ResponseStrategy):
"""
For OpenAI/Azure models using the responses API - direct sync calls with retry logic.
Cannot resume - must retry from scratch if it fails.
"""
def _convert_to_responses_api_format(
self, multimodal_content: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Convert multimodal content from Completions API format to Responses API format.
Completions format: [{"type": "text/image_url", ...}]
Responses format: [{"type": "input_text/input_image", ...}]
"""
converted: list[dict[str, Any]] = []
for item in multimodal_content:
item_type = item.get("type", "")
if item_type == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item_type == "image_url":
# Extract URL from nested object
image_url = item.get("image_url", {})
url = image_url.get("url", "") if isinstance(image_url, dict) else ""
converted.append({"type": "input_image", "image_url": url})
return converted
def execute(
self,
model: str,
prompt: str,
session_dir: Path | None = None,
multimodal_content: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Execute with synchronous retries using responses API"""
# Build input - convert multimodal to Responses API format if provided
input_content: str | list[dict[str, Any]]
if multimodal_content:
input_content = self._convert_to_responses_api_format(multimodal_content)
else:
input_content = prompt
for attempt in range(config.MAX_RETRIES):
try:
response = responses(
model=model,
input=input_content,
stream=False,
num_retries=config.MAX_RETRIES, # Use LiteLLM's built-in retries
**kwargs,
)
content = self._extract_content(response)
if not content:
raise RuntimeError("No content in response from LLM")
return {
"content": content,
"usage": self._serialize_usage(getattr(response, "usage", None)),
"response": response, # Include full response for cost calculation
}
except Exception as e:
# Use LiteLLM's built-in retry logic for HTTP errors
if _should_retry and hasattr(e, "status_code"):
retryable = _should_retry(e.status_code)
else:
# Fallback to string matching for non-HTTP errors
error_msg = str(e).lower()
retryable = any(
x in error_msg
for x in [
"network",
"timeout",
"connection",
"429",
"rate limit",
"503",
"overloaded",
]
)
non_retryable = any(
x in error_msg
for x in [
"auth",
"key",
"context",
"token limit",
"not found",
"invalid",
]
)
if non_retryable:
raise
if retryable and attempt < config.MAX_RETRIES - 1:
delay = self._calculate_backoff_delay(
attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
)
print(
f"Retryable error, waiting {delay:.1f}s before retry {attempt + 2}/{config.MAX_RETRIES}..."
)
time.sleep(delay)
continue
raise
raise RuntimeError("Max retries exceeded")
def can_resume(self) -> bool:
return False
class CompletionsAPIStrategy(ResponseStrategy):
"""
For Anthropic/Google/other providers - uses the chat completions API directly.
More efficient than bridging through the responses API for non-OpenAI providers.
"""
def execute(
self,
model: str,
prompt: str,
session_dir: Path | None = None,
multimodal_content: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Execute with chat completions API"""
# Remove responses-specific kwargs that don't apply to completions
kwargs.pop("reasoning_effort", None)
kwargs.pop("background", None)
# Build message content - use multimodal content if provided, else plain prompt
message_content: str | list[dict[str, Any]] = (
multimodal_content if multimodal_content else prompt
)
for attempt in range(config.MAX_RETRIES):
try:
# Use chat completions API
response = completion(
model=model,
messages=[{"role": "user", "content": message_content}],
stream=False,
num_retries=config.MAX_RETRIES,
**kwargs,
)
# Extract content from chat completion response
content = self._extract_completion_content(response)
if not content:
raise RuntimeError("No content in response from LLM")
return {
"content": content,
"usage": self._serialize_usage(getattr(response, "usage", None)),
"response": response,
}
except Exception as e:
# Use LiteLLM's built-in retry logic for HTTP errors
if _should_retry and hasattr(e, "status_code"):
retryable = _should_retry(e.status_code)
else:
error_msg = str(e).lower()
retryable = any(
x in error_msg
for x in [
"network",
"timeout",
"connection",
"429",
"rate limit",
"503",
"overloaded",
]
)
non_retryable = any(
x in error_msg
for x in [
"auth",
"key",
"context",
"token limit",
"not found",
"invalid",
]
)
if non_retryable:
raise
if retryable and attempt < config.MAX_RETRIES - 1:
delay = self._calculate_backoff_delay(
attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
)
print(
f"Retryable error, waiting {delay:.1f}s before retry {attempt + 2}/{config.MAX_RETRIES}..."
)
time.sleep(delay)
continue
raise
raise RuntimeError("Max retries exceeded")
def _extract_completion_content(self, response: Any) -> str:
"""Extract text content from chat completions response"""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
return ""
def can_resume(self) -> bool:
return False
class ResponseStrategyFactory:
"""Factory to select appropriate strategy based on model/provider and API support"""
# Models/providers that support background jobs (OpenAI Responses API feature)
BACKGROUND_SUPPORTED = {
"openai/",
"azure/",
}
@staticmethod
def get_strategy(model: str) -> ResponseStrategy:
"""
Select strategy based on model capabilities and API support.
Decision tree:
1. If model supports responses API AND background jobs -> BackgroundJobStrategy
2. If model supports responses API (no background) -> SyncRetryStrategy
3. If model doesn't support responses API -> CompletionsAPIStrategy
Uses litellm.models_by_provider to determine support.
"""
# Check if model supports native responses API
if supports_responses_api(model):
# Check if it also supports background jobs
if ResponseStrategyFactory.supports_background(model):
return BackgroundJobStrategy()
return SyncRetryStrategy()
# For all other providers (Anthropic, Google, Bedrock, etc.)
# Use completions API directly - more efficient than bridging
return CompletionsAPIStrategy()
@staticmethod
def supports_background(model: str) -> bool:
"""Check if model supports background job execution (OpenAI/Azure only)"""
model_lower = model.lower()
return any(
model_lower.startswith(prefix)
for prefix in ResponseStrategyFactory.BACKGROUND_SUPPORTED
)
@staticmethod
def get_api_type(model: str) -> str:
"""
Determine which API type will be used for a given model.
Returns:
'responses' for models using OpenAI Responses API
'completions' for models using Chat Completions API
"""
if supports_responses_api(model):
return "responses"
return "completions"

session_manager.py View File

@@ -0,0 +1,274 @@
"""
Session management for async consultant executions
Handles background processes, session persistence, and status tracking
"""
import builtins
import contextlib
import json
import multiprocessing
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import config
class SessionManager:
"""Manages consultant sessions with async execution"""
def __init__(self, sessions_dir: Path | None = None) -> None:
self.sessions_dir = sessions_dir or config.DEFAULT_SESSIONS_DIR
self.sessions_dir.mkdir(parents=True, exist_ok=True)
def create_session(
self,
slug: str,
prompt: str,
model: str,
base_url: str | None = None,
api_key: str | None = None,
reasoning_effort: str = "high",
multimodal_content: list[dict[str, Any]] | None = None,
) -> str:
"""Create a new session and start background execution"""
session_id = f"{slug}-{int(time.time())}"
session_dir = self.sessions_dir / session_id
session_dir.mkdir(exist_ok=True)
# Save session metadata
metadata = {
"id": session_id,
"slug": slug,
"created_at": datetime.now().isoformat(),
"status": "running",
"model": model,
"base_url": base_url,
"reasoning_effort": reasoning_effort,
"prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt,
"has_images": multimodal_content is not None,
}
metadata_file = session_dir / "metadata.json"
metadata_file.write_text(json.dumps(metadata, indent=2))
# Save full prompt
prompt_file = session_dir / "prompt.txt"
prompt_file.write_text(prompt)
# Start background process
process = multiprocessing.Process(
target=self._execute_session,
args=(
session_id,
prompt,
model,
base_url,
api_key,
reasoning_effort,
multimodal_content,
),
)
process.start()
# Store PID for potential cleanup
(session_dir / "pid").write_text(str(process.pid))
return session_id
def _execute_session(
self,
session_id: str,
prompt: str,
model: str,
base_url: str | None,
api_key: str | None,
reasoning_effort: str = "high",
multimodal_content: list[dict[str, Any]] | None = None,
) -> None:
"""Background execution of LLM consultation"""
session_dir = self.sessions_dir / session_id
try:
# Import here to avoid issues with multiprocessing
from litellm_client import LiteLLMClient
# Initialize client
client = LiteLLMClient(base_url=base_url, api_key=api_key)
# Make LLM call with the full prompt (already includes file contents)
self._update_status(session_id, "calling_llm")
# Get full response (pass session_dir for resumability support)
result = client.complete(
model=model,
prompt=prompt,
session_dir=session_dir, # Enables background job resumption if supported
reasoning_effort=reasoning_effort,
multimodal_content=multimodal_content,
)
full_response = result.get("content", "")
usage = result.get("usage")
response_obj = result.get(
"response"
) # Full response object for cost calculation
# Save response to file
output_file = session_dir / "output.txt"
output_file.write_text(full_response)
# Calculate cost using response object (preferred) or usage dict (fallback)
cost_info = None
if response_obj or usage:
cost_info = client.calculate_cost(
model, response=response_obj, usage=usage
)
# Update metadata with usage and cost
self._update_status(
session_id,
"completed",
response=full_response,
usage=usage,
cost_info=cost_info,
reasoning_effort=reasoning_effort,
)
except Exception as e:
error_msg = f"Error: {str(e)}\n\nType: {type(e).__name__}"
(session_dir / "error.txt").write_text(error_msg)
self._update_status(session_id, "error", error=error_msg)
def _update_status(
self,
session_id: str,
status: str,
response: str | None = None,
error: str | None = None,
usage: dict[str, Any] | None = None,
cost_info: dict[str, Any] | None = None,
reasoning_effort: str | None = None,
) -> None:
"""Update session status in metadata"""
session_dir = self.sessions_dir / session_id
metadata_file = session_dir / "metadata.json"
if not metadata_file.exists():
return
metadata = json.loads(metadata_file.read_text())
metadata["status"] = status
metadata["updated_at"] = datetime.now().isoformat()
if response:
metadata["completed_at"] = datetime.now().isoformat()
metadata["output_length"] = len(response)
if error:
metadata["error"] = error[:500] # Truncate long errors
if usage:
metadata["usage"] = usage
if cost_info:
metadata["cost_info"] = cost_info
if reasoning_effort:
metadata["reasoning_effort"] = reasoning_effort
metadata_file.write_text(json.dumps(metadata, indent=2))
def get_session_status(self, slug: str) -> dict[str, Any]:
"""Get current status of a session by slug"""
# Find most recent session with this slug
matching_sessions = sorted(
[
d
for d in self.sessions_dir.iterdir()
if d.is_dir() and d.name.startswith(slug)
],
key=lambda x: x.stat().st_mtime,
reverse=True,
)
if not matching_sessions:
return {"error": f"No session found with slug: {slug}"}
session_dir = matching_sessions[0]
metadata_file = session_dir / "metadata.json"
if not metadata_file.exists():
return {"error": f"Session metadata not found: {slug}"}
metadata: dict[str, Any] = json.loads(metadata_file.read_text())
# Add output if completed
if metadata["status"] == "completed":
output_file = session_dir / "output.txt"
if output_file.exists():
metadata["output"] = output_file.read_text()
# Add error if failed
if metadata["status"] == "error":
error_file = session_dir / "error.txt"
if error_file.exists():
metadata["error_details"] = error_file.read_text()
return metadata
def wait_for_completion(
self, session_id: str, timeout: int = 3600
) -> dict[str, Any]:
"""Block until session completes or timeout"""
start_time = time.time()
while time.time() - start_time < timeout:
session_dir = self.sessions_dir / session_id
metadata_file = session_dir / "metadata.json"
if not metadata_file.exists():
time.sleep(1)
continue
metadata: dict[str, Any] = json.loads(metadata_file.read_text())
if metadata["status"] in ["completed", "error"]:
# Add output if completed
if metadata["status"] == "completed":
output_file = session_dir / "output.txt"
if output_file.exists():
metadata["output"] = output_file.read_text()
# Add error if failed
if metadata["status"] == "error":
error_file = session_dir / "error.txt"
if error_file.exists():
metadata["error_details"] = error_file.read_text()
return metadata
time.sleep(config.POLLING_INTERVAL_SECONDS)
raise TimeoutError(f"Session {session_id} did not complete within {timeout}s")
def list_sessions(self) -> list[dict[str, Any]]:
"""List all sessions"""
sessions = []
for session_dir in self.sessions_dir.iterdir():
if not session_dir.is_dir():
continue
metadata_file = session_dir / "metadata.json"
if metadata_file.exists():
with contextlib.suppress(builtins.BaseException):
sessions.append(json.loads(metadata_file.read_text()))
return sorted(sessions, key=lambda x: x.get("created_at", ""), reverse=True)
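Finally, a minimal sketch of the session lifecycle; the slug, prompt, and model are placeholders, and a valid API key is assumed to be available in the environment:
from session_manager import SessionManager

mgr = SessionManager()                       # sessions live under ~/.consultant/sessions
session_id = mgr.create_session(
    slug="demo",                             # placeholder slug
    prompt="Reply with the single word: ok",
    model="gpt-5-pro",                       # placeholder model id
)
print(mgr.get_session_status("demo").get("status"))       # e.g. "running" or "calling_llm"
result = mgr.wait_for_completion(session_id, timeout=600)
print(result.get("output") or result.get("error"))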