Initial commit
skills/consultant/scripts/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""
Consultant Python Implementation
LiteLLM-based tool for flexible multi-provider LLM consultations
"""

__version__ = "1.0.0"
skills/consultant/scripts/config.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""
Configuration and constants for consultant Python implementation
"""

import os
from pathlib import Path

# Session storage location
DEFAULT_SESSIONS_DIR = Path.home() / ".consultant" / "sessions"

# Environment variable names
ENV_LITELLM_API_KEY = "LITELLM_API_KEY"
ENV_OPENAI_API_KEY = "OPENAI_API_KEY"
ENV_ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY"
ENV_OPENAI_BASE_URL = "OPENAI_BASE_URL"

# Token budget: Reserve this percentage for response
CONTEXT_RESERVE_RATIO = 0.2  # 20% reserved for response

# Retry configuration
MAX_RETRIES = 3
INITIAL_RETRY_DELAY = 2  # seconds
MAX_RETRY_DELAY = 60  # seconds

# Background job polling configuration
POLL_INTERVAL = 20  # seconds between polls (configurable)
POLL_TIMEOUT = 3600  # 1 hour max wait for background jobs

# Session polling
POLLING_INTERVAL_SECONDS = 2


def get_api_key() -> str | None:
    """Get API key from environment in priority order"""
    return (
        os.environ.get(ENV_LITELLM_API_KEY)
        or os.environ.get(ENV_OPENAI_API_KEY)
        or os.environ.get(ENV_ANTHROPIC_API_KEY)
    )


def get_base_url() -> str | None:
    """Get base URL from OPENAI_BASE_URL environment variable if set"""
    base_url = os.environ.get(ENV_OPENAI_BASE_URL)
    # Only return if non-empty
    return base_url if base_url and base_url.strip() else None
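The two helpers above are the module's whole public surface. A minimal usage sketch, with placeholder values rather than real credentials:

import os

import config

# Hypothetical environment, for illustration only.
os.environ["LITELLM_API_KEY"] = "sk-example-litellm"
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000"

# LITELLM_API_KEY wins over OPENAI_API_KEY and ANTHROPIC_API_KEY.
print(config.get_api_key())   # -> "sk-example-litellm"
# An unset or whitespace-only OPENAI_BASE_URL would yield None instead.
print(config.get_base_url())  # -> "http://localhost:8000"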
skills/consultant/scripts/consultant_cli.py (new file, 501 lines)
@@ -0,0 +1,501 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "litellm",
#     "requests>=2.31.0",
#     "tenacity",
#     "markitdown>=0.1.0",
# ]
# ///
"""
Consultant CLI - LiteLLM-powered LLM consultation tool
Supports async invocation, custom base URLs, and flexible model selection

Run with: uv run consultant_cli.py [args]
This automatically installs/updates dependencies (litellm, requests) on first run.
"""

import argparse
import json
import sys
from pathlib import Path

# Add scripts directory to path
SCRIPTS_DIR = Path(__file__).parent
sys.path.insert(0, str(SCRIPTS_DIR))

import config
from file_handler import (
    FileHandler,
    build_multimodal_content,
    build_prompt_with_references,
    has_images,
    validate_vision_support,
)
from litellm_client import LiteLLMClient
from model_selector import ModelSelector
from session_manager import SessionManager


def validate_context_size(
    full_prompt: str, model: str, client: LiteLLMClient, num_files: int
) -> bool:
    """
    Validate that full prompt fits in model context.
    Returns True if OK, raises ValueError if exceeds.
    """

    # Count tokens for the complete prompt
    total_tokens = client.count_tokens(full_prompt, model)

    # Get limit
    max_tokens = client.get_max_tokens(model)

    # Reserve for response
    available_tokens = int(max_tokens * (1 - config.CONTEXT_RESERVE_RATIO))

    # Print summary
    print("\n📊 Token Usage:")
    print(f"- Input: {total_tokens:,} tokens ({num_files} files)")
    print(f"- Limit: {max_tokens:,} tokens")
    print(
        f"- Available: {available_tokens:,} tokens ({int((available_tokens/max_tokens)*100)}%)\n"
    )

    if total_tokens > max_tokens:
        raise ValueError(
            f"Input exceeds context limit!\n"
            f"  Input: {total_tokens:,} tokens\n"
            f"  Limit: {max_tokens:,} tokens\n"
            f"  Overage: {total_tokens - max_tokens:,} tokens\n\n"
            f"Suggestions:\n"
            f"1. Reduce number of files (currently {num_files})\n"
            f"2. Use a model with larger context\n"
            f"3. Shorten the prompt"
        )

    if total_tokens > available_tokens:
        print(f"⚠️ WARNING: Using {int((total_tokens/max_tokens)*100)}% of context")
        print("   Consider reducing input size for better response quality\n")

    return True


def handle_invocation(args: argparse.Namespace) -> int:
    """Handle main invocation command"""

    # Determine base URL: --base-url flag > OPENAI_BASE_URL env var > None
    base_url = args.base_url
    if not base_url:
        base_url = config.get_base_url()
        if base_url:
            print(f"Using base URL from OPENAI_BASE_URL: {base_url}")

    # Initialize components
    session_mgr = SessionManager()
    client = LiteLLMClient(base_url=base_url, api_key=args.api_key)

    # Process files using FileHandler
    file_handler = FileHandler()
    processed_files = []
    multimodal_content = None

    if args.files:
        processed_files, file_errors = file_handler.process_files(args.files)

        # If any files failed, report errors and exit
        if file_errors:
            print("\nERROR: Some files could not be processed:", file=sys.stderr)
            for err in file_errors:
                print(f"  - {err.path}: {err.reason}", file=sys.stderr)
            print(
                "\nPlease fix or remove the problematic files and try again.",
                file=sys.stderr,
            )
            return 1

        # Validate vision support if images present
        if has_images(processed_files):
            validate_vision_support(args.model, has_images=True)

        # Print file processing summary
        text_count = sum(1 for f in processed_files if f.category.value == "text")
        office_count = sum(1 for f in processed_files if f.category.value == "office")
        image_count = sum(1 for f in processed_files if f.category.value == "image")

        print("\nFile Processing Summary:")
        print(f"  - Text files: {text_count}")
        print(f"  - Office documents (converted): {office_count}")
        print(f"  - Images: {image_count}")

    # Log model being used
    print(f"Using model: {args.model}")

    # Validate environment variables (only if no custom base URL)
    if not base_url:
        env_status = client.validate_environment(args.model)
        if not env_status.get("keys_in_environment", False):
            missing = env_status.get("missing_keys", [])
            error = env_status.get("error", "")

            print(
                f"\n❌ ERROR: Missing required environment variables for model '{args.model}'",
                file=sys.stderr,
            )
            print(f"\nMissing keys: {', '.join(missing)}", file=sys.stderr)

            if error:
                print(f"\nDetails: {error}", file=sys.stderr)

            print("\n💡 To fix this:", file=sys.stderr)
            print("   1. Set the required environment variable(s):", file=sys.stderr)
            for key in missing:
                print(f"      export {key}=your-api-key", file=sys.stderr)
            print(
                "   2. Or use --base-url to specify a custom LiteLLM endpoint",
                file=sys.stderr,
            )
            print(
                "   3. Or use --model to specify a different model\n", file=sys.stderr
            )

            return 1

    # Build full prompt with reference files section
    full_prompt = build_prompt_with_references(args.prompt, processed_files)

    # Build multimodal content if we have images
    if has_images(processed_files):
        multimodal_content = build_multimodal_content(full_prompt, processed_files)

    # Check context limits on the full prompt
    try:
        validate_context_size(full_prompt, args.model, client, len(processed_files))
    except ValueError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    # Create and start session
    session_id = session_mgr.create_session(
        slug=args.slug,
        prompt=full_prompt,
        model=args.model,
        base_url=base_url,
        api_key=args.api_key,
        reasoning_effort=args.reasoning_effort,
        multimodal_content=multimodal_content,
    )

    print(f"Session created: {session_id}")
    print(f"Reattach via: python3 {__file__} session {args.slug}")
    print("Waiting for completion...")

    try:
        result = session_mgr.wait_for_completion(session_id)

        if result.get("status") == "completed":
            print("\n" + "=" * 80)
            print("RESPONSE:")
            print("=" * 80)
            print(result.get("output", "No output available"))
            print("=" * 80)

            # Print metadata section (model, reasoning effort, tokens, cost)
            print("\n" + "=" * 80)
            print("METADATA:")
            print("=" * 80)

            # Model info
            print(f"model: {result.get('model', args.model)}")
            print(
                f"reasoning_effort: {result.get('reasoning_effort', args.reasoning_effort)}"
            )

            # Token usage and cost
            usage = result.get("usage")
            cost_info = result.get("cost_info")

            if cost_info:
                print(f"input_tokens: {cost_info.get('input_tokens', 0)}")
                print(f"output_tokens: {cost_info.get('output_tokens', 0)}")
                print(
                    f"total_tokens: {cost_info.get('input_tokens', 0) + cost_info.get('output_tokens', 0)}"
                )
                print(f"input_cost_usd: {cost_info.get('input_cost', 0):.6f}")
                print(f"output_cost_usd: {cost_info.get('output_cost', 0):.6f}")
                print(f"total_cost_usd: {cost_info.get('total_cost', 0):.6f}")
            elif usage:
                input_tokens = usage.get("prompt_tokens") or usage.get(
                    "input_tokens", 0
                )
                output_tokens = usage.get("completion_tokens") or usage.get(
                    "output_tokens", 0
                )
                print(f"input_tokens: {input_tokens}")
                print(f"output_tokens: {output_tokens}")
                print(f"total_tokens: {input_tokens + output_tokens}")

            print("=" * 80)

            return 0
        else:
            print(f"\nSession ended with status: {result.get('status')}")
            if "error" in result:
                print(f"Error: {result['error']}")
            return 1

    except TimeoutError as e:
        print(f"\nERROR: {e}", file=sys.stderr)
        return 1


def handle_session_status(args: argparse.Namespace) -> int:
    """Handle session status check"""

    session_mgr = SessionManager()
    status = session_mgr.get_session_status(args.slug)

    if "error" in status and "No session found" in status["error"]:
        print(f"ERROR: {status['error']}", file=sys.stderr)
        return 1

    # Pretty print status
    print(json.dumps(status, indent=2))
    return 0


def handle_list_sessions(args: argparse.Namespace) -> int:
    """Handle list sessions command"""

    session_mgr = SessionManager()
    sessions = session_mgr.list_sessions()

    if not sessions:
        print("No sessions found.")
        return 0

    print(f"\nFound {len(sessions)} session(s):\n")
    for s in sessions:
        status_icon = {
            "running": "🔄",
            "completed": "✅",
            "error": "❌",
            "calling_llm": "📞",
        }.get(s.get("status", ""), "❓")

        print(
            f"{status_icon} {s.get('slug', 'unknown')} - {s.get('status', 'unknown')}"
        )
        print(f"   Created: {s.get('created_at', 'unknown')}")
        print(f"   Model: {s.get('model', 'unknown')}")
        if s.get("error"):
            print(f"   Error: {s['error'][:100]}...")
        print()

    return 0


def handle_list_models(args: argparse.Namespace) -> int:
    """Handle list models command"""

    # Determine base URL: --base-url flag > OPENAI_BASE_URL env var > None
    base_url = args.base_url
    if not base_url:
        base_url = config.get_base_url()
        if base_url:
            print(f"Using base URL from OPENAI_BASE_URL: {base_url}")

    LiteLLMClient(base_url=base_url)
    models = ModelSelector.list_models(base_url)

    print(json.dumps(models, indent=2))
    return 0


def main() -> int:
    parser = argparse.ArgumentParser(
        description="""
Consultant CLI - LiteLLM-powered LLM consultation tool

This CLI tool allows you to consult powerful LLM models for code analysis,
reviews, architectural decisions, and complex technical questions. It supports
100+ LLM providers via LiteLLM with custom base URLs.

CORE WORKFLOW:
  1. Provide a prompt describing your analysis task
  2. Attach relevant files for context
  3. The CLI sends everything to the LLM and waits for completion
  4. Results are printed with full metadata (model, tokens, cost)

OUTPUT FORMAT:
  The CLI prints structured output with clear sections:
  - RESPONSE: The LLM's analysis/response
  - METADATA: Model used, reasoning effort, token counts, costs

ENVIRONMENT VARIABLES:
  LITELLM_API_KEY     Primary API key (checked first)
  OPENAI_API_KEY      OpenAI API key (fallback)
  ANTHROPIC_API_KEY   Anthropic API key (fallback)
  OPENAI_BASE_URL     Default base URL for custom LiteLLM proxy
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
EXAMPLES:

  Basic consultation with prompt and files:
    %(prog)s -p "Review this code for bugs" -f src/main.py -s code-review

  Multiple files:
    %(prog)s -p "Analyze architecture" -f src/api.py -f src/db.py -f src/models.py -s arch-review

  Specify model explicitly:
    %(prog)s -p "Security audit" -f auth.py -s security -m claude-3-5-sonnet-20241022

  Use custom LiteLLM proxy:
    %(prog)s -p "Code review" -f app.py -s review --base-url http://localhost:8000

  Lower reasoning effort (faster, cheaper):
    %(prog)s -p "Quick check" -f code.py -s quick --reasoning-effort low

  Check session status:
    %(prog)s session my-review

  List all sessions:
    %(prog)s list

  List available models from proxy:
    %(prog)s models --base-url http://localhost:8000

SUBCOMMANDS:
  session <slug>   Check status of a session by its slug
  list             List all sessions with their status
  models           List available models (from proxy or known models)

For more information, see the consultant plugin documentation.
""",
    )

    # Subcommands
    subparsers = parser.add_subparsers(dest="command", help="Available subcommands")

    # Main invocation arguments
    parser.add_argument(
        "-p",
        "--prompt",
        metavar="TEXT",
        help="""The analysis prompt to send to the LLM. This should describe
        what you want the model to analyze or review. The prompt will
        be combined with any attached files to form the full request.
        REQUIRED for main invocation.""",
    )
    parser.add_argument(
        "-f",
        "--file",
        action="append",
        dest="files",
        metavar="PATH",
        help="""File to attach for analysis. Can be specified multiple times
        to attach multiple files. Each file's contents will be included
        in the prompt sent to the LLM. Supports any text file format.
        Example: -f src/main.py -f src/utils.py -f README.md""",
    )
    parser.add_argument(
        "-s",
        "--slug",
        metavar="NAME",
        help="""Unique identifier for this session. Used to track and retrieve
        session results. Should be descriptive (e.g., "pr-review-123",
        "security-audit", "arch-analysis"). REQUIRED for main invocation.""",
    )
    parser.add_argument(
        "-m",
        "--model",
        metavar="MODEL_ID",
        default="gpt-5-pro",
        help="""Specific LLM model to use. Default: gpt-5-pro. Examples:
        "gpt-5.1", "claude-sonnet-4-5", "gemini/gemini-2.5-flash".
        Use the "models" subcommand to see available models.""",
    )
    parser.add_argument(
        "--base-url",
        metavar="URL",
        help="""Custom base URL for LiteLLM proxy server (e.g., "http://localhost:8000").
        When set, all API calls go through this proxy. The proxy's /v1/models
        endpoint will be queried for available models. If not set, uses
        direct provider APIs based on the model prefix.""",
    )
    parser.add_argument(
        "--api-key",
        metavar="KEY",
        help="""API key for the LLM provider. If not provided, the CLI will look
        for keys in environment variables: LITELLM_API_KEY, OPENAI_API_KEY,
        or ANTHROPIC_API_KEY (in that order).""",
    )
    parser.add_argument(
        "--reasoning-effort",
        choices=["low", "medium", "high"],
        default="high",
        metavar="LEVEL",
        help="""Reasoning effort level for the LLM. Higher effort = more thorough
        analysis but slower and more expensive. Choices: low, medium, high.
        Default: high. Use "low" for quick checks, "high" for thorough reviews.""",
    )

    # Session status subcommand
    session_parser = subparsers.add_parser(
        "session",
        help="Check the status of a session",
        description="""Check the current status of a consultation session.
        Returns JSON with session metadata, status, and output if completed.""",
    )
    session_parser.add_argument(
        "slug", help="Session slug/identifier to check (the value passed to -s/--slug)"
    )

    # List sessions subcommand
    subparsers.add_parser(
        "list",
        help="List all consultation sessions",
        description="""List all consultation sessions with their status.
        Shows session slug, status, creation time, model used, and any errors.""",
    )

    # List models subcommand
    models_parser = subparsers.add_parser(
        "models",
        help="List available LLM models",
        description="""List available LLM models. If --base-url is provided, queries
        the proxy's /v1/models endpoint. Otherwise, returns known models
        from LiteLLM's model registry.""",
    )
    models_parser.add_argument(
        "--base-url",
        metavar="URL",
        help="Base URL of LiteLLM proxy to query for available models",
    )

    args = parser.parse_args()

    # Handle commands
    if args.command == "session":
        return handle_session_status(args)

    elif args.command == "list":
        return handle_list_sessions(args)

    elif args.command == "models":
        return handle_list_models(args)

    else:
        # Main invocation
        if not args.prompt or not args.slug:
            parser.print_help()
            print("\nERROR: --prompt and --slug are required", file=sys.stderr)
            return 1

        return handle_invocation(args)


if __name__ == "__main__":
    sys.exit(main())
skills/consultant/scripts/file_handler.py (new file, 323 lines)
@@ -0,0 +1,323 @@
"""
File handling for consultant CLI.
Categorizes and processes files: images, office documents, and text files.
"""

import base64
import mimetypes
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

from markitdown import MarkItDown


class FileCategory(Enum):
    """Categories of files the CLI can handle"""

    IMAGE = "image"
    OFFICE = "office"
    TEXT = "text"


@dataclass
class ProcessedFile:
    """Result of successfully processing a file"""

    path: str
    category: FileCategory
    content: str = ""  # For text/office: the text content
    base64_data: str = ""  # For images: base64 encoded data
    mime_type: str = ""  # For images: the MIME type


@dataclass
class FileError:
    """Error details for a file that failed processing"""

    path: str
    reason: str


# File extension constants
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
OFFICE_EXTENSIONS = frozenset({".xls", ".xlsx", ".docx", ".pptx"})

# Size limits
MAX_IMAGE_SIZE_BYTES = 20 * 1024 * 1024  # 20MB


class FileHandler:
    """Main file processing coordinator"""

    def __init__(self) -> None:
        self._markitdown = MarkItDown()

    def process_files(
        self, file_paths: list[str]
    ) -> tuple[list[ProcessedFile], list[FileError]]:
        """
        Process a list of file paths and return categorized results.

        Returns:
            Tuple of (successfully processed files, errors)
        """
        processed: list[ProcessedFile] = []
        errors: list[FileError] = []

        for file_path in file_paths:
            path = Path(file_path)

            # Validate file exists
            if not path.exists():
                errors.append(FileError(path=str(path), reason="File not found"))
                continue

            if not path.is_file():
                errors.append(FileError(path=str(path), reason="Not a file"))
                continue

            # Categorize and process
            category = self._categorize(path)

            if category == FileCategory.IMAGE:
                result = self._process_image(path)
            elif category == FileCategory.OFFICE:
                result = self._process_office(path)
            else:  # FileCategory.TEXT
                result = self._process_text(path)

            if isinstance(result, FileError):
                errors.append(result)
            else:
                processed.append(result)

        return processed, errors

    def _categorize(self, path: Path) -> FileCategory:
        """Determine the category of a file based on extension"""
        suffix = path.suffix.lower()

        if suffix in IMAGE_EXTENSIONS:
            return FileCategory.IMAGE

        if suffix in OFFICE_EXTENSIONS:
            return FileCategory.OFFICE

        # Default: assume text, will validate during processing
        return FileCategory.TEXT

    def _process_image(self, path: Path) -> ProcessedFile | FileError:
        """Process an image file: validate size and encode to base64"""
        try:
            # Read binary content
            data = path.read_bytes()

            # Check size limit
            if len(data) > MAX_IMAGE_SIZE_BYTES:
                size_mb = len(data) / (1024 * 1024)
                max_mb = MAX_IMAGE_SIZE_BYTES / (1024 * 1024)
                return FileError(
                    path=str(path),
                    reason=f"Image too large: {size_mb:.1f}MB (max {max_mb:.0f}MB)",
                )

            # Encode to base64
            base64_data = base64.b64encode(data).decode("utf-8")

            # Determine MIME type
            mime_type, _ = mimetypes.guess_type(str(path))
            if not mime_type:
                # Fallback based on extension
                ext = path.suffix.lower()
                mime_map = {
                    ".png": "image/png",
                    ".jpg": "image/jpeg",
                    ".jpeg": "image/jpeg",
                    ".gif": "image/gif",
                    ".webp": "image/webp",
                }
                mime_type = mime_map.get(ext, "application/octet-stream")

            return ProcessedFile(
                path=str(path),
                category=FileCategory.IMAGE,
                base64_data=base64_data,
                mime_type=mime_type,
            )
        except Exception as e:
            return FileError(path=str(path), reason=f"Failed to process image: {e}")

    def _process_office(self, path: Path) -> ProcessedFile | FileError:
        """Process an office document using markitdown"""
        try:
            result = self._markitdown.convert(str(path))
            content = result.text_content

            if not content or not content.strip():
                return FileError(
                    path=str(path), reason="markitdown returned empty content"
                )

            return ProcessedFile(
                path=str(path),
                category=FileCategory.OFFICE,
                content=content,
            )
        except Exception as e:
            return FileError(
                path=str(path), reason=f"markitdown conversion failed: {e}"
            )

    def _process_text(self, path: Path) -> ProcessedFile | FileError:
        """Process a text file: attempt UTF-8 decode"""
        try:
            content = path.read_text(encoding="utf-8")

            # Check for empty or whitespace-only files
            if not content or not content.strip():
                return FileError(
                    path=str(path),
                    reason="File is empty or contains only whitespace",
                )

            return ProcessedFile(
                path=str(path),
                category=FileCategory.TEXT,
                content=content,
            )
        except UnicodeDecodeError as e:
            return FileError(
                path=str(path),
                reason=f"Not a valid UTF-8 text file: {e}",
            )
        except Exception as e:
            return FileError(path=str(path), reason=f"Failed to read file: {e}")


def validate_vision_support(model: str, has_images: bool) -> None:
    """
    Validate that the model supports vision if images are present.
    Exits with code 2 if validation fails.
    """
    if not has_images:
        return

    from litellm import supports_vision

    if not supports_vision(model=model):
        print(
            f"\nERROR: Model '{model}' does not support vision/images.\n",
            file=sys.stderr,
        )
        print(
            "Image files were provided but the selected model cannot process them.",
            file=sys.stderr,
        )
        print("\nSuggestions:", file=sys.stderr)
        print("  1. Use a vision-capable model:", file=sys.stderr)
        print("     - gpt-5.1, gpt-5-vision (OpenAI)", file=sys.stderr)
        print(
            "     - claude-sonnet-4-5, claude-opus-4 (Anthropic)",
            file=sys.stderr,
        )
        print(
            "     - gemini/gemini-2.5-flash, gemini/gemini-3-pro-preview (Google)", file=sys.stderr
        )
        print("  2. Remove image files from the request", file=sys.stderr)
        print("  3. Convert images to text descriptions first\n", file=sys.stderr)
        sys.exit(2)


def build_prompt_with_references(prompt: str, files: list[ProcessedFile]) -> str:
    """
    Build the text portion of the prompt with Reference Files section.
    Does NOT include images (those go in the multimodal array).

    Args:
        prompt: The user's original prompt
        files: List of successfully processed files

    Returns:
        The full prompt with reference files section appended
    """
    # Filter to text and office files only (images handled separately)
    text_content_files = [
        f for f in files if f.category in (FileCategory.TEXT, FileCategory.OFFICE)
    ]

    # Also get image files for the note
    image_files = [f for f in files if f.category == FileCategory.IMAGE]

    if not text_content_files and not image_files:
        return prompt

    parts = [prompt]

    # Add reference files section if there are text/office files
    if text_content_files:
        parts.append("\n\n" + "=" * 80)
        parts.append("\n\n## Reference Files\n")

        for file in text_content_files:
            parts.append(f"\n### {file.path}\n")
            parts.append(f"```\n{file.content}\n```\n")

    # Add note about images if present
    if image_files:
        parts.append("\n\n" + "-" * 40)
        parts.append(
            f"\n*Note: {len(image_files)} image(s) attached for visual analysis.*\n"
        )
        for img in image_files:
            parts.append(f"- {img.path}\n")

    return "".join(parts)


def build_multimodal_content(
    text_prompt: str, files: list[ProcessedFile]
) -> list[dict[str, Any]]:
    """
    Build multimodal content array for LLM APIs.

    Uses the standard OpenAI Chat Completions format which is widely supported.
    Response strategies will convert to API-specific formats as needed.

    Format:
    - Text: {"type": "text", "text": "..."}
    - Image: {"type": "image_url", "image_url": {"url": "data:...", "detail": "auto"}}

    Args:
        text_prompt: The text portion of the prompt (with reference files)
        files: List of successfully processed files

    Returns:
        Multimodal content array
    """
    content: list[dict[str, Any]] = []

    # Text content
    content.append({"type": "text", "text": text_prompt})

    # Images with base64 data URLs
    for f in files:
        if f.category == FileCategory.IMAGE:
            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{f.mime_type};base64,{f.base64_data}",
                        "detail": "auto",
                    },
                }
            )

    return content


def has_images(files: list[ProcessedFile]) -> bool:
    """Check if any processed files are images"""
    return any(f.category == FileCategory.IMAGE for f in files)
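A minimal sketch of how the pieces of file_handler.py fit together, using hypothetical file paths:

from file_handler import (
    FileHandler,
    build_multimodal_content,
    build_prompt_with_references,
    has_images,
)

# Hypothetical inputs, for illustration only.
processed, errors = FileHandler().process_files(["notes.docx", "diagram.png"])
for err in errors:
    print(f"skipped {err.path}: {err.reason}")

# Text/office content is inlined into the prompt text; images are carried
# separately in the multimodal array, and only when at least one image exists.
prompt = build_prompt_with_references("Summarize these inputs", processed)
if has_images(processed):
    content = build_multimodal_content(prompt, processed)  # list of text + image_url parts
else:
    content = prompt  # plain string prompt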
skills/consultant/scripts/litellm_client.py (new file, 241 lines)
@@ -0,0 +1,241 @@
"""
LiteLLM client wrapper with token counting and error handling
"""

import os
from pathlib import Path
from typing import Any

import requests
from litellm import (
    completion_cost,
    get_max_tokens,
    token_counter,
    validate_environment,
)
from litellm.utils import get_model_info

import config
from response_strategy import ResponseStrategyFactory


class LiteLLMClient:
    """Wrapper around LiteLLM with enhanced functionality"""

    def __init__(self, base_url: str | None = None, api_key: str | None = None) -> None:
        self.base_url = base_url
        self.api_key = api_key or config.get_api_key()

        # Configure litellm
        if self.api_key:
            # Set API key in environment for litellm to pick up
            if not os.environ.get("OPENAI_API_KEY"):
                os.environ["OPENAI_API_KEY"] = self.api_key

    def complete(
        self,
        model: str,
        prompt: str,
        session_dir: Path | None = None,
        reasoning_effort: str = "high",
        multimodal_content: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Make a request using the responses API with automatic retry/background job handling.

        Uses strategy pattern to:
        - Use background jobs for OpenAI/Azure (resumable after network failures)
        - Use sync with retries for other providers

        Args:
            model: Model identifier
            prompt: Full prompt text
            session_dir: Optional session directory for state persistence (enables resumability)
            reasoning_effort: Reasoning effort level (low, medium, high) - default high
            multimodal_content: Optional multimodal content array for images
            **kwargs: Additional args passed to litellm.responses()

        Returns:
            Dict with 'content' and optional 'usage'
        """

        # Add base_url if configured
        if self.base_url:
            kwargs["api_base"] = self.base_url

        # Add reasoning_effort parameter
        kwargs["reasoning_effort"] = reasoning_effort

        # Select appropriate strategy based on model
        strategy = ResponseStrategyFactory.get_strategy(model)

        if session_dir:
            api_type = ResponseStrategyFactory.get_api_type(model)
            print(
                f"Using {strategy.__class__.__name__} (resumable: {strategy.can_resume()})"
            )
            print(f"API: {api_type} | Reasoning effort: {reasoning_effort}")

        try:
            # Execute with strategy-specific retry/background logic
            result: dict[str, Any] = strategy.execute(
                model=model,
                prompt=prompt,
                session_dir=session_dir,
                multimodal_content=multimodal_content,
                **kwargs,
            )
            return result

        except Exception as e:
            # Map to standardized errors
            error_msg = str(e)

            if "context" in error_msg.lower() or "token" in error_msg.lower():
                raise ValueError(f"Context limit exceeded: {error_msg}") from e
            elif "auth" in error_msg.lower() or "key" in error_msg.lower():
                raise PermissionError(f"Authentication failed: {error_msg}") from e
            elif "not found" in error_msg.lower() or "404" in error_msg:
                raise ValueError(f"Model not found: {error_msg}") from e
            else:
                raise RuntimeError(f"LLM request failed: {error_msg}") from e

    def count_tokens(self, text: str, model: str) -> int:
        """
        Count tokens for given text and model.

        When base_url is set (proxy mode), uses the proxy's /utils/token_counter endpoint
        for accurate tokenization of custom models. Otherwise uses local token_counter.
        """

        # If using a proxy (base_url set), use the proxy's token counter endpoint
        if self.base_url:
            url = f"{self.base_url.rstrip('/')}/utils/token_counter"
            payload = {"model": model, "text": text}

            headers = {"Content-Type": "application/json"}
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"

            response = requests.post(url, json=payload, headers=headers, timeout=30)
            response.raise_for_status()

            # Response typically has format: {"token_count": 123}
            result = response.json()
            token_count = result.get("token_count") or result.get("tokens")
            if token_count is None:
                raise RuntimeError(
                    f"Proxy token counter returned invalid response: {result}"
                )
            return int(token_count)

        # Use local token counter (direct API mode)
        return int(token_counter(model=model, text=text))

    def get_max_tokens(self, model: str) -> int:
        """Get maximum context size for model"""

        try:
            return int(get_max_tokens(model))
        except Exception as e:
            # Try get_model_info as alternative method
            try:
                info = get_model_info(model=model)
                max_tokens = info.get("max_tokens")
                if max_tokens is None:
                    raise RuntimeError(
                        f"Could not determine max_tokens for model {model}"
                    )
                return int(max_tokens)
            except Exception as inner_e:
                raise RuntimeError(
                    f"Could not get max tokens for model {model}: {e}, {inner_e}"
                ) from inner_e

    def calculate_cost(
        self,
        model: str,
        response: Any = None,
        usage: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Calculate cost using LiteLLM's built-in completion_cost() function.

        Args:
            model: Model identifier
            response: Optional response object from litellm.responses()
            usage: Optional usage dict (fallback if response not available)

        Returns:
            Dict with input_tokens, output_tokens, costs, or None if unavailable
        """
        try:
            # Prefer using response object with built-in function
            if response:
                total_cost = completion_cost(completion_response=response)

                # Extract token counts from response.usage if available
                if hasattr(response, "usage"):
                    usage = response.usage

            # Calculate from usage dict if provided
            if usage:
                input_tokens = usage.get("prompt_tokens") or usage.get(
                    "input_tokens", 0
                )
                output_tokens = usage.get("completion_tokens") or usage.get(
                    "output_tokens", 0
                )

                # Get per-token costs from model info
                info = get_model_info(model=model)
                input_cost_per_token = info.get("input_cost_per_token", 0)
                output_cost_per_token = info.get("output_cost_per_token", 0)

                input_cost = input_tokens * input_cost_per_token
                output_cost = output_tokens * output_cost_per_token

                # Use total_cost from completion_cost if available, else calculate
                if not response:
                    total_cost = input_cost + output_cost

                return {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "input_cost": input_cost,
                    "output_cost": output_cost,
                    "total_cost": total_cost,
                    "currency": "USD",
                }

            return None

        except Exception:
            # If we can't get pricing info, return None
            return None

    def validate_environment(self, model: str) -> dict[str, Any]:
        """
        Check if required environment variables are set for the model.
        Returns dict with 'keys_in_environment' (bool) and 'missing_keys' (list).
        """
        try:
            result: dict[str, Any] = validate_environment(model=model)
            return result
        except Exception as e:
            # If validation fails, return a generic response
            return {
                "keys_in_environment": False,
                "missing_keys": ["API_KEY"],
                "error": str(e),
            }

    def test_connection(self, model: str) -> bool:
        """Test if we can connect to the model"""

        try:
            result = self.complete(model=model, prompt="Hello", max_tokens=5)
            return result.get("content") is not None
        except Exception:
            return False
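A rough sketch of the pre-flight checks LiteLLMClient gives the CLI; the model name and prompt are placeholders, and the 0.8 factor simply mirrors CONTEXT_RESERVE_RATIO = 0.2 from config.py:

from litellm_client import LiteLLMClient

client = LiteLLMClient(base_url=None, api_key=None)  # falls back to env vars
model = "gpt-5-pro"  # placeholder

env = client.validate_environment(model)
if not env.get("keys_in_environment", False):
    raise SystemExit(f"missing keys: {env.get('missing_keys')}")

prompt = "Review this design"
used = client.count_tokens(prompt, model)
budget = int(client.get_max_tokens(model) * 0.8)
print(f"{used} tokens used of a {budget}-token input budget")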
skills/consultant/scripts/model_selector.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""
Model discovery and selection logic
"""

from typing import Any

import requests
from litellm import model_cost


class ModelSelector:
    """Handles model discovery and automatic selection"""

    @staticmethod
    def list_models(base_url: str | None = None) -> list[dict[str, Any]]:
        """
        Query available models.

        Without base_url: Uses LiteLLM's model_cost dictionary for dynamic discovery
        With base_url: Calls proxy's /models or /v1/models endpoint
        """

        if not base_url:
            # Use LiteLLM's model_cost for dynamic discovery
            return ModelSelector._get_litellm_models()

        # Try LiteLLM proxy /models endpoint first, then OpenAI-compatible /v1/models
        last_error = None
        for endpoint in ["/models", "/v1/models"]:
            try:
                models_url = f"{base_url.rstrip('/')}{endpoint}"
                response = requests.get(models_url, timeout=10)
                response.raise_for_status()

                data = response.json()
                models = data.get("data", [])

                return [
                    {
                        "id": m.get("id"),
                        "created": m.get("created"),
                        "owned_by": m.get("owned_by"),
                    }
                    for m in models
                ]
            except Exception as e:
                last_error = e
                continue

        # If all endpoints fail, raise an error
        raise RuntimeError(f"Could not fetch models from {base_url}: {last_error}")

    @staticmethod
    def select_best_model(base_url: str | None = None) -> str:
        """
        Automatically select the best available model.
        Heuristic: Prefer models with "large", "pro", or higher version numbers
        """

        models = ModelSelector.list_models(base_url)

        if not models:
            raise RuntimeError("No models available - cannot auto-select model")

        # Score models based on name heuristics
        best_model = max(models, key=ModelSelector._score_model)
        model_id = best_model.get("id")
        if not model_id:
            raise RuntimeError("Best model has no id - cannot auto-select model")
        return str(model_id)

    @staticmethod
    def _score_model(model: dict[str, Any]) -> float:
        """Score a model based on capabilities (higher is better)"""

        model_id = model.get("id", "").lower()
        score = 0.0

        # Version number scoring
        if "gpt-5" in model_id or "o1" in model_id or "o3" in model_id:
            score += 50
        elif "gpt-4" in model_id:
            score += 40
        elif "gpt-3.5" in model_id:
            score += 30

        # Capability indicators
        if any(x in model_id for x in ["pro", "turbo", "large", "xl", "ultra"]):
            score += 20

        # Context size indicators
        if "128k" in model_id or "200k" in model_id:
            score += 15
        elif "32k" in model_id:
            score += 12
        elif "16k" in model_id:
            score += 10

        # Anthropic models
        if "claude" in model_id:
            if "opus" in model_id:
                score += 50
            elif "sonnet" in model_id:
                if "3.5" in model_id or "3-5" in model_id:
                    score += 48
                else:
                    score += 45
            elif "haiku" in model_id:
                score += 35

        # Google models
        if "gemini" in model_id:
            if "2.0" in model_id or "2-0" in model_id:
                score += 45
            elif "pro" in model_id:
                score += 40

        return score

    @staticmethod
    def _get_litellm_models() -> list[dict[str, Any]]:
        """
        Get models from LiteLLM's model_cost dictionary.
        This provides dynamic model discovery without hardcoded lists.
        """

        if not model_cost:
            raise RuntimeError("LiteLLM model_cost is empty - cannot discover models")

        # Convert model_cost dictionary to list format
        models = []
        for model_id, info in model_cost.items():
            models.append(
                {
                    "id": model_id,
                    "provider": info.get("litellm_provider", "unknown"),
                    "max_tokens": info.get("max_tokens"),
                    "input_cost_per_token": info.get("input_cost_per_token"),
                    "output_cost_per_token": info.get("output_cost_per_token"),
                }
            )

        return models
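A short sketch of model discovery in both modes; the proxy URL is a placeholder and the second half only succeeds against a running LiteLLM proxy:

from model_selector import ModelSelector

# Direct mode: enumerate models known to LiteLLM's model_cost registry.
local_models = ModelSelector.list_models()
print(len(local_models), "models known locally")

# Proxy mode: tries /models first, then /v1/models, on a hypothetical proxy.
proxy_models = ModelSelector.list_models("http://localhost:8000")
best = ModelSelector.select_best_model("http://localhost:8000")
print("auto-selected:", best)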
skills/consultant/scripts/response_strategy.py (new file, 646 lines)
@@ -0,0 +1,646 @@
|
||||
"""
|
||||
Response strategies for different LLM providers.
|
||||
Handles retries, background jobs, and provider-specific quirks.
|
||||
Automatically detects responses API vs completions API support.
|
||||
"""
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import _should_retry, completion, responses
|
||||
|
||||
import config
|
||||
|
||||
|
||||
def _is_responses_api_model(model_name: str) -> bool:
|
||||
"""
|
||||
Check if a model name indicates responses API support.
|
||||
|
||||
Uses general patterns that will work for future model versions:
|
||||
- GPT-4+ (gpt-4, gpt-5, gpt-6, etc.)
|
||||
- O-series reasoning models (o1, o2, o3, o4, etc.)
|
||||
- Codex models
|
||||
- Computer-use models
|
||||
|
||||
Args:
|
||||
model_name: Model name without provider prefix (lowercase)
|
||||
|
||||
Returns:
|
||||
True if model should use responses API
|
||||
"""
|
||||
import re
|
||||
|
||||
# GPT-4 and above (gpt-4, gpt-5, gpt-6, etc. but not gpt-3.5)
|
||||
# Matches: gpt-4, gpt4, gpt-4-turbo, gpt-5.1, gpt-6-turbo, etc.
|
||||
gpt_match = re.search(r"gpt-?(\d+)", model_name)
|
||||
if gpt_match:
|
||||
version = int(gpt_match.group(1))
|
||||
if version >= 4:
|
||||
return True
|
||||
|
||||
# O-series reasoning models (o1, o2, o3, o4, etc.)
|
||||
# Matches: o1, o1-pro, o3-mini, o4-preview, etc.
|
||||
if re.search(r"\bo\d+\b", model_name) or re.search(r"\bo\d+-", model_name):
|
||||
return True
|
||||
|
||||
# Codex models (use responses API)
|
||||
if "codex" in model_name:
|
||||
return True
|
||||
|
||||
# Computer-use models
|
||||
return "computer-use" in model_name
|
||||
|
||||
|
||||
def get_responses_api_models() -> set[str]:
|
||||
"""
|
||||
Determine which models support the native OpenAI Responses API.
|
||||
|
||||
Uses litellm.models_by_provider to get OpenAI models, then filters
|
||||
to those that support the responses API.
|
||||
|
||||
Returns:
|
||||
Set of model identifiers that support the responses API natively.
|
||||
"""
|
||||
responses_models: set[str] = set()
|
||||
|
||||
# Get OpenAI models from litellm
|
||||
openai_models = litellm.models_by_provider.get("openai", [])
|
||||
azure_models = litellm.models_by_provider.get("azure", [])
|
||||
|
||||
for model in openai_models + azure_models:
|
||||
if _is_responses_api_model(model.lower()):
|
||||
responses_models.add(model)
|
||||
responses_models.add(f"openai/{model}")
|
||||
responses_models.add(f"azure/{model}")
|
||||
|
||||
return responses_models
|
||||
|
||||
|
||||
def supports_responses_api(model: str) -> bool:
|
||||
"""
|
||||
Check if a model supports the native OpenAI Responses API.
|
||||
|
||||
Uses general patterns that work for current and future models:
|
||||
- GPT-4+ series (gpt-4, gpt-5, gpt-6, etc.)
|
||||
- O-series reasoning models (o1, o2, o3, etc.)
|
||||
- Codex models
|
||||
- Computer-use models
|
||||
|
||||
Args:
|
||||
model: Model identifier (e.g., "openai/gpt-4", "gpt-5-mini")
|
||||
|
||||
Returns:
|
||||
True if model supports responses API natively, False otherwise.
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
# Extract model name and provider
|
||||
if "/" in model_lower:
|
||||
provider, model_name = model_lower.split("/", 1)
|
||||
else:
|
||||
provider = "openai" # Default provider for bare model names
|
||||
model_name = model_lower
|
||||
|
||||
# Only OpenAI and Azure support the responses API natively
|
||||
if provider not in ("openai", "azure"):
|
||||
return False
|
||||
|
||||
# Use the generalized pattern matching
|
||||
return _is_responses_api_model(model_name)
|
||||
|
||||
|
||||
class ResponseStrategy(ABC):
|
||||
"""Base class for response strategies"""
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
session_dir: Path | None = None,
|
||||
multimodal_content: list[dict[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute LLM request with provider-specific strategy.
|
||||
Returns dict with 'content' and optional 'usage'.
|
||||
|
||||
Args:
|
||||
model: Model identifier
|
||||
prompt: Text prompt
|
||||
session_dir: Optional session directory for state persistence
|
||||
multimodal_content: Optional multimodal content array for images
|
||||
**kwargs: Additional provider-specific arguments
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def can_resume(self) -> bool:
|
||||
"""Whether this strategy supports resuming after failure"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _calculate_backoff_delay(
|
||||
self, attempt: int, base_delay: int, max_delay: int
|
||||
) -> float:
|
||||
"""Calculate exponential backoff delay with jitter"""
|
||||
import random
|
||||
|
||||
delay = min(base_delay * (2**attempt), max_delay)
|
||||
# Add 10% jitter to avoid thundering herd
|
||||
jitter = delay * 0.1 * random.random()
|
||||
return float(delay + jitter)
|
||||
|
||||
def _extract_content(self, response: Any) -> str:
|
||||
"""
|
||||
Extract text content from response.output structure.
|
||||
|
||||
Handles different output item types:
|
||||
- ResponseOutputMessage (type='message'): has content with text
|
||||
- ResponseReasoningItem (type='reasoning'): has summary, no content
|
||||
"""
|
||||
content = ""
|
||||
if hasattr(response, "output") and response.output:
|
||||
for item in response.output:
|
||||
# Check item type - only 'message' type has content
|
||||
item_type = getattr(item, "type", None)
|
||||
|
||||
if item_type == "message":
|
||||
# ResponseOutputMessage: extract text from content
|
||||
if hasattr(item, "content") and item.content:
|
||||
for content_item in item.content:
|
||||
if hasattr(content_item, "text"):
|
||||
content += content_item.text
|
||||
# Skip 'reasoning' items (ResponseReasoningItem) - they have summary, not content
|
||||
return content
|
||||
|
||||
def _serialize_usage(self, usage: Any) -> dict[str, Any] | None:
|
||||
"""
|
||||
Safely convert usage object to a JSON-serializable dict.
|
||||
Handles Pydantic models (OpenAI), dataclasses, and plain dicts.
|
||||
"""
|
||||
if usage is None:
|
||||
return None
|
||||
|
||||
# Already a dict - return as-is
|
||||
if isinstance(usage, dict):
|
||||
return dict(usage)
|
||||
|
||||
# Pydantic v2 model
|
||||
if hasattr(usage, "model_dump"):
|
||||
result: dict[str, Any] = usage.model_dump()
|
||||
return result
|
||||
|
||||
# Pydantic v1 model
|
||||
if hasattr(usage, "dict"):
|
||||
result = usage.dict()
|
||||
return dict(result)
|
||||
|
||||
# Dataclass or object with __dict__
|
||||
if hasattr(usage, "__dict__"):
|
||||
return dict(usage.__dict__)
|
||||
|
||||
# Last resort - try to convert directly
|
||||
try:
|
||||
return dict(usage)
|
||||
except (TypeError, ValueError):
|
||||
# If all else fails, return None rather than crash
|
||||
return None
|
||||
|
||||
|
||||
class BackgroundJobStrategy(ResponseStrategy):
|
||||
"""
|
||||
For OpenAI/Azure - uses background jobs with response_id polling.
|
||||
Supports resuming after network failures by persisting response_id.
|
||||
"""
|
||||
|
||||
def _convert_to_responses_api_format(
|
||||
self, multimodal_content: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert multimodal content from Completions API format to Responses API format.
|
||||
|
||||
Completions format: [{"type": "text/image_url", ...}]
|
||||
Responses format: [{"type": "input_text/input_image", ...}]
|
||||
"""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in multimodal_content:
|
||||
item_type = item.get("type", "")
|
||||
if item_type == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# Extract URL from nested object
|
||||
image_url = item.get("image_url", {})
|
||||
url = image_url.get("url", "") if isinstance(image_url, dict) else ""
|
||||
converted.append({"type": "input_image", "image_url": url})
|
||||
return converted
|
||||
|
||||
def execute(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
session_dir: Path | None = None,
|
||||
multimodal_content: list[dict[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute with background job and polling"""
|
||||
|
||||
response_id_file = session_dir / "response_id.txt" if session_dir else None
|
||||
|
||||
# Check if we're resuming an existing background job
|
||||
if response_id_file and response_id_file.exists():
|
||||
response_id = response_id_file.read_text().strip()
|
||||
print(f"Resuming background job: {response_id}")
|
||||
return self._poll_for_completion(response_id)
|
||||
|
||||
# Build input - convert multimodal to Responses API format if provided
|
||||
input_content: str | list[dict[str, Any]]
|
||||
if multimodal_content:
|
||||
input_content = self._convert_to_responses_api_format(multimodal_content)
|
||||
else:
|
||||
input_content = prompt
|
||||
|
||||
# Start new background job
|
||||
try:
|
||||
response = responses(
|
||||
model=model,
|
||||
input=input_content,
|
||||
background=True, # Returns immediately with response_id
|
||||
num_retries=config.MAX_RETRIES, # Use LiteLLM's built-in retries
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
response_id = response.id
|
||||
|
||||
# Persist response_id for resumability
|
||||
if response_id_file:
|
||||
response_id_file.write_text(response_id)
|
||||
print(f"Started background job: {response_id}")
|
||||
|
||||
# Poll until complete
|
||||
return self._poll_for_completion(response_id)
|
||||
|
||||
except Exception as e:
|
||||
# If background mode fails, maybe not supported - raise for fallback
|
||||
raise RuntimeError(f"Background job failed to start: {e}") from e
|
||||
|
||||
def _poll_for_completion(self, response_id: str) -> dict[str, Any]:
|
||||
"""Poll for completion with exponential backoff and retries"""
|
||||
|
||||
start_time = time.time()
|
||||
attempt = 0
|
||||
|
||||
while time.time() - start_time < config.POLL_TIMEOUT:
|
||||
try:
|
||||
# Retrieve the response by ID
|
||||
result = litellm.get_response(response_id=response_id)
|
||||
|
||||
if hasattr(result, "status"):
|
||||
if result.status == "completed":
|
||||
content = self._extract_content(result)
|
||||
if not content:
|
||||
raise RuntimeError("No content in completed response")
|
||||
return {
|
||||
"content": content,
|
||||
"usage": self._serialize_usage(
|
||||
getattr(result, "usage", None)
|
||||
),
|
||||
"response": result, # Include full response for cost calculation
|
||||
}
|
||||
elif result.status == "failed":
|
||||
error = getattr(result, "error", "Unknown error")
|
||||
raise RuntimeError(f"Background job failed: {error}")
|
||||
elif result.status in ["in_progress", "queued"]:
|
||||
# Still processing, wait and retry
|
||||
time.sleep(config.POLL_INTERVAL)
|
||||
attempt += 1
|
||||
continue
|
||||
else:
|
||||
# Unknown status, wait and retry
|
||||
time.sleep(config.POLL_INTERVAL)
|
||||
continue
|
||||
else:
|
||||
# No status field - might be complete already
|
||||
content = self._extract_content(result)
|
||||
if content:
|
||||
return {
|
||||
"content": content,
|
||||
"usage": self._serialize_usage(
|
||||
getattr(result, "usage", None)
|
||||
),
|
||||
"response": result, # Include full response for cost calculation
|
||||
}
|
||||
# No content, wait and retry
|
||||
time.sleep(config.POLL_INTERVAL)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e).lower()
|
||||
|
||||
# Network errors - retry with backoff
|
||||
if any(x in error_msg for x in ["network", "timeout", "connection"]):
|
||||
if attempt < config.MAX_RETRIES:
|
||||
delay = self._calculate_backoff_delay(
|
||||
attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
|
||||
)
|
||||
print(
|
||||
f"Network error polling job, retrying in {delay:.1f}s... (attempt {attempt + 1}/{config.MAX_RETRIES})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
attempt += 1
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Network errors exceeded max retries: {e}"
|
||||
) from e
|
||||
|
||||
# Other errors - raise immediately
|
||||
raise
|
||||
|
||||
raise TimeoutError(
|
||||
f"Background job {response_id} did not complete within {config.POLL_TIMEOUT}s"
|
||||
)
|
||||
|
||||
def can_resume(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SyncRetryStrategy(ResponseStrategy):
|
||||
"""
|
||||
For OpenAI/Azure models using responses API - direct sync calls with retry logic.
|
||||
Cannot resume - must retry from scratch if it fails.
|
||||
"""
|
||||
|
||||
def _convert_to_responses_api_format(
|
||||
self, multimodal_content: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert multimodal content from Completions API format to Responses API format.
|
||||
|
||||
Completions format: [{"type": "text/image_url", ...}]
|
||||
Responses format: [{"type": "input_text/input_image", ...}]
|
||||
"""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in multimodal_content:
|
||||
item_type = item.get("type", "")
|
||||
if item_type == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# Extract URL from nested object
|
||||
image_url = item.get("image_url", {})
|
||||
url = image_url.get("url", "") if isinstance(image_url, dict) else ""
|
||||
converted.append({"type": "input_image", "image_url": url})
|
||||
return converted

    def execute(
        self,
        model: str,
        prompt: str,
        session_dir: Path | None = None,
        multimodal_content: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Execute with synchronous retries using responses API"""

        # Build input - convert multimodal to Responses API format if provided
        input_content: str | list[dict[str, Any]]
        if multimodal_content:
            input_content = self._convert_to_responses_api_format(multimodal_content)
        else:
            input_content = prompt

        for attempt in range(config.MAX_RETRIES):
            try:
                response = responses(
                    model=model,
                    input=input_content,
                    stream=False,
                    num_retries=config.MAX_RETRIES,  # Use LiteLLM's built-in retries
                    **kwargs,
                )

                content = self._extract_content(response)

                if not content:
                    raise RuntimeError("No content in response from LLM")

                return {
                    "content": content,
                    "usage": self._serialize_usage(getattr(response, "usage", None)),
                    "response": response,  # Include full response for cost calculation
                }

            except Exception as e:
                # Classify the error: use LiteLLM's status-code helper for HTTP errors,
                # fall back to string matching otherwise. error_msg is computed up front
                # so the non-retryable check below always has it available.
                error_msg = str(e).lower()
                if _should_retry and hasattr(e, "status_code"):
                    retryable = _should_retry(e.status_code)
                else:
                    retryable = any(
                        x in error_msg
                        for x in [
                            "network",
                            "timeout",
                            "connection",
                            "429",
                            "rate limit",
                            "503",
                            "overloaded",
                        ]
                    )
                non_retryable = any(
                    x in error_msg
                    for x in [
                        "auth",
                        "key",
                        "context",
                        "token limit",
                        "not found",
                        "invalid",
                    ]
                )

                if non_retryable:
                    raise

                if retryable and attempt < config.MAX_RETRIES - 1:
                    delay = self._calculate_backoff_delay(
                        attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
                    )
                    print(
                        f"Retryable error, waiting {delay:.1f}s before retry {attempt + 2}/{config.MAX_RETRIES}..."
                    )
                    time.sleep(delay)
                    continue

                raise

        raise RuntimeError("Max retries exceeded")

    def can_resume(self) -> bool:
        return False


class CompletionsAPIStrategy(ResponseStrategy):
    """
    For Anthropic/Google/other providers - uses chat completions API directly.
    More efficient than bridging through responses API for non-OpenAI providers.
    """

    def execute(
        self,
        model: str,
        prompt: str,
        session_dir: Path | None = None,
        multimodal_content: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Execute with chat completions API"""

        # Remove responses-specific kwargs that don't apply to completions
        kwargs.pop("reasoning_effort", None)
        kwargs.pop("background", None)

        # Build message content - use multimodal content if provided, else plain prompt
        message_content: str | list[dict[str, Any]] = (
            multimodal_content if multimodal_content else prompt
        )

        for attempt in range(config.MAX_RETRIES):
            try:
                # Use chat completions API
                response = completion(
                    model=model,
                    messages=[{"role": "user", "content": message_content}],
                    stream=False,
                    num_retries=config.MAX_RETRIES,
                    **kwargs,
                )

                # Extract content from chat completion response
                content = self._extract_completion_content(response)

                if not content:
                    raise RuntimeError("No content in response from LLM")

                return {
                    "content": content,
                    "usage": self._serialize_usage(getattr(response, "usage", None)),
                    "response": response,
                }

            except Exception as e:
                # Same classification as SyncRetryStrategy: status-code helper for HTTP
                # errors, string matching otherwise; error_msg is always defined.
                error_msg = str(e).lower()
                if _should_retry and hasattr(e, "status_code"):
                    retryable = _should_retry(e.status_code)
                else:
                    retryable = any(
                        x in error_msg
                        for x in [
                            "network",
                            "timeout",
                            "connection",
                            "429",
                            "rate limit",
                            "503",
                            "overloaded",
                        ]
                    )
                non_retryable = any(
                    x in error_msg
                    for x in [
                        "auth",
                        "key",
                        "context",
                        "token limit",
                        "not found",
                        "invalid",
                    ]
                )

                if non_retryable:
                    raise

                if retryable and attempt < config.MAX_RETRIES - 1:
                    delay = self._calculate_backoff_delay(
                        attempt, config.INITIAL_RETRY_DELAY, config.MAX_RETRY_DELAY
                    )
                    print(
                        f"Retryable error, waiting {delay:.1f}s before retry {attempt + 2}/{config.MAX_RETRIES}..."
                    )
                    time.sleep(delay)
                    continue

                raise

        raise RuntimeError("Max retries exceeded")

    def _extract_completion_content(self, response: Any) -> str:
        """Extract text content from chat completions response"""
        if hasattr(response, "choices") and response.choices:
            choice = response.choices[0]
            if hasattr(choice, "message") and hasattr(choice.message, "content"):
                return choice.message.content or ""
        return ""

    def can_resume(self) -> bool:
        return False


class ResponseStrategyFactory:
    """Factory to select appropriate strategy based on model/provider and API support"""

    # Models/providers that support background jobs (OpenAI Responses API feature)
    BACKGROUND_SUPPORTED = {
        "openai/",
        "azure/",
    }

    @staticmethod
    def get_strategy(model: str) -> ResponseStrategy:
        """
        Select strategy based on model capabilities and API support.

        Decision tree:
        1. If model supports responses API AND background jobs -> BackgroundJobStrategy
        2. If model supports responses API (no background) -> SyncRetryStrategy
        3. If model doesn't support responses API -> CompletionsAPIStrategy

        Uses litellm.models_by_provider to determine support.
        """
        # Check if model supports native responses API
        if supports_responses_api(model):
            # Check if it also supports background jobs
            if ResponseStrategyFactory.supports_background(model):
                return BackgroundJobStrategy()
            return SyncRetryStrategy()

        # For all other providers (Anthropic, Google, Bedrock, etc.),
        # use the completions API directly - more efficient than bridging
        return CompletionsAPIStrategy()

    @staticmethod
    def supports_background(model: str) -> bool:
        """Check if model supports background job execution (OpenAI/Azure only)"""
        model_lower = model.lower()
        return any(
            model_lower.startswith(prefix)
            for prefix in ResponseStrategyFactory.BACKGROUND_SUPPORTED
        )

    @staticmethod
    def get_api_type(model: str) -> str:
        """
        Determine which API type will be used for a given model.

        Returns:
            'responses' for models using OpenAI Responses API
            'completions' for models using Chat Completions API
        """
        if supports_responses_api(model):
            return "responses"
        return "completions"
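
# Usage sketch (illustrative; these strategies are presumably driven by
# LiteLLMClient.complete() rather than called directly):
#
#     strategy = ResponseStrategyFactory.get_strategy(model)
#     api_type = ResponseStrategyFactory.get_api_type(model)  # "responses" or "completions"
#     result = strategy.execute(model=model, prompt=prompt, session_dir=session_dir)
#     text, usage = result["content"], result["usage"]
#
# Models prefixed "openai/" or "azure/" that support the Responses API get the resumable
# BackgroundJobStrategy; other Responses-API models get SyncRetryStrategy; everything
# else (Anthropic, Google, Bedrock, ...) falls back to CompletionsAPIStrategy.
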
274
skills/consultant/scripts/session_manager.py
Normal file
274
skills/consultant/scripts/session_manager.py
Normal file
@@ -0,0 +1,274 @@
"""
Session management for async consultant executions
Handles background processes, session persistence, and status tracking
"""

import builtins
import contextlib
import json
import multiprocessing
import time
from datetime import datetime
from pathlib import Path
from typing import Any

import config


class SessionManager:
    """Manages consultant sessions with async execution"""

    def __init__(self, sessions_dir: Path | None = None) -> None:
        self.sessions_dir = sessions_dir or config.DEFAULT_SESSIONS_DIR
        self.sessions_dir.mkdir(parents=True, exist_ok=True)

    def create_session(
        self,
        slug: str,
        prompt: str,
        model: str,
        base_url: str | None = None,
        api_key: str | None = None,
        reasoning_effort: str = "high",
        multimodal_content: list[dict[str, Any]] | None = None,
    ) -> str:
        """Create a new session and start background execution"""

        session_id = f"{slug}-{int(time.time())}"
        session_dir = self.sessions_dir / session_id
        session_dir.mkdir(exist_ok=True)

        # Save session metadata
        metadata = {
            "id": session_id,
            "slug": slug,
            "created_at": datetime.now().isoformat(),
            "status": "running",
            "model": model,
            "base_url": base_url,
            "reasoning_effort": reasoning_effort,
            "prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt,
            "has_images": multimodal_content is not None,
        }

        metadata_file = session_dir / "metadata.json"
        metadata_file.write_text(json.dumps(metadata, indent=2))

        # Save full prompt
        prompt_file = session_dir / "prompt.txt"
        prompt_file.write_text(prompt)

        # Start background process
        process = multiprocessing.Process(
            target=self._execute_session,
            args=(
                session_id,
                prompt,
                model,
                base_url,
                api_key,
                reasoning_effort,
                multimodal_content,
            ),
        )
        process.start()

        # Store PID for potential cleanup
        (session_dir / "pid").write_text(str(process.pid))

        return session_id
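
    # Files written under each session directory:
    #   metadata.json - status, model, timestamps, plus usage/cost once complete
    #   prompt.txt    - the full prompt sent to the model
    #   pid           - PID of the background worker process
    #   output.txt    - model response, written by _execute_session on success
    #   error.txt     - error details, written by _execute_session on failure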

    def _execute_session(
        self,
        session_id: str,
        prompt: str,
        model: str,
        base_url: str | None,
        api_key: str | None,
        reasoning_effort: str = "high",
        multimodal_content: list[dict[str, Any]] | None = None,
    ) -> None:
        """Background execution of LLM consultation"""

        session_dir = self.sessions_dir / session_id

        try:
            # Import here to avoid issues with multiprocessing
            from litellm_client import LiteLLMClient

            # Initialize client
            client = LiteLLMClient(base_url=base_url, api_key=api_key)

            # Make LLM call with the full prompt (already includes file contents)
            self._update_status(session_id, "calling_llm")

            # Get full response (pass session_dir for resumability support)
            result = client.complete(
                model=model,
                prompt=prompt,
                session_dir=session_dir,  # Enables background job resumption if supported
                reasoning_effort=reasoning_effort,
                multimodal_content=multimodal_content,
            )

            full_response = result.get("content", "")
            usage = result.get("usage")
            response_obj = result.get(
                "response"
            )  # Full response object for cost calculation

            # Save response to file
            output_file = session_dir / "output.txt"
            output_file.write_text(full_response)

            # Calculate cost using response object (preferred) or usage dict (fallback)
            cost_info = None
            if response_obj or usage:
                cost_info = client.calculate_cost(
                    model, response=response_obj, usage=usage
                )

            # Update metadata with usage and cost
            self._update_status(
                session_id,
                "completed",
                response=full_response,
                usage=usage,
                cost_info=cost_info,
                reasoning_effort=reasoning_effort,
            )

        except Exception as e:
            error_msg = f"Error: {str(e)}\n\nType: {type(e).__name__}"
            (session_dir / "error.txt").write_text(error_msg)
            self._update_status(session_id, "error", error=error_msg)

    def _update_status(
        self,
        session_id: str,
        status: str,
        response: str | None = None,
        error: str | None = None,
        usage: dict[str, Any] | None = None,
        cost_info: dict[str, Any] | None = None,
        reasoning_effort: str | None = None,
    ) -> None:
        """Update session status in metadata"""

        session_dir = self.sessions_dir / session_id
        metadata_file = session_dir / "metadata.json"

        if not metadata_file.exists():
            return

        metadata = json.loads(metadata_file.read_text())
        metadata["status"] = status
        metadata["updated_at"] = datetime.now().isoformat()

        if response:
            metadata["completed_at"] = datetime.now().isoformat()
            metadata["output_length"] = len(response)

        if error:
            metadata["error"] = error[:500]  # Truncate long errors

        if usage:
            metadata["usage"] = usage

        if cost_info:
            metadata["cost_info"] = cost_info

        if reasoning_effort:
            metadata["reasoning_effort"] = reasoning_effort

        metadata_file.write_text(json.dumps(metadata, indent=2))

    def get_session_status(self, slug: str) -> dict[str, Any]:
        """Get current status of a session by slug"""

        # Find most recent session with this slug
        matching_sessions = sorted(
            [
                d
                for d in self.sessions_dir.iterdir()
                if d.is_dir() and d.name.startswith(slug)
            ],
            key=lambda x: x.stat().st_mtime,
            reverse=True,
        )

        if not matching_sessions:
            return {"error": f"No session found with slug: {slug}"}

        session_dir = matching_sessions[0]
        metadata_file = session_dir / "metadata.json"

        if not metadata_file.exists():
            return {"error": f"Session metadata not found: {slug}"}

        metadata: dict[str, Any] = json.loads(metadata_file.read_text())

        # Add output if completed
        if metadata["status"] == "completed":
            output_file = session_dir / "output.txt"
            if output_file.exists():
                metadata["output"] = output_file.read_text()

        # Add error if failed
        if metadata["status"] == "error":
            error_file = session_dir / "error.txt"
            if error_file.exists():
                metadata["error_details"] = error_file.read_text()

        return metadata

    def wait_for_completion(
        self, session_id: str, timeout: int = 3600
    ) -> dict[str, Any]:
        """Block until session completes or timeout"""

        start_time = time.time()

        while time.time() - start_time < timeout:
            session_dir = self.sessions_dir / session_id
            metadata_file = session_dir / "metadata.json"

            if not metadata_file.exists():
                time.sleep(1)
                continue

            metadata: dict[str, Any] = json.loads(metadata_file.read_text())

            if metadata["status"] in ["completed", "error"]:
                # Add output if completed
                if metadata["status"] == "completed":
                    output_file = session_dir / "output.txt"
                    if output_file.exists():
                        metadata["output"] = output_file.read_text()

                # Add error if failed
                if metadata["status"] == "error":
                    error_file = session_dir / "error.txt"
                    if error_file.exists():
                        metadata["error_details"] = error_file.read_text()

                return metadata

            time.sleep(config.POLLING_INTERVAL_SECONDS)

        raise TimeoutError(f"Session {session_id} did not complete within {timeout}s")

    def list_sessions(self) -> list[dict[str, Any]]:
        """List all sessions"""

        sessions = []
        for session_dir in self.sessions_dir.iterdir():
            if not session_dir.is_dir():
                continue

            metadata_file = session_dir / "metadata.json"
            if metadata_file.exists():
                with contextlib.suppress(builtins.BaseException):
                    sessions.append(json.loads(metadata_file.read_text()))

        return sorted(sessions, key=lambda x: x.get("created_at", ""), reverse=True)
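
# Usage sketch (illustrative; consultant_cli.py is the intended entry point, and the
# slug/model values below are placeholders):
#
#     manager = SessionManager()
#     session_id = manager.create_session(
#         slug="api-review",
#         prompt="Review this API design...",
#         model="openai/gpt-4o",  # any LiteLLM model string
#     )
#     # Poll by slug from a later invocation...
#     status = manager.get_session_status("api-review")
#     # ...or block in-process until the background job finishes
#     result = manager.wait_for_completion(session_id, timeout=1800)
#     print(result.get("output") or result.get("error_details", ""))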