Drop dep bumps + black-26 reformat to clear fork CI policy
PR was blocked by .github/workflows/guard-fork-dependencies.yml: fork PRs cannot modify uv.lock. Reverting: - uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295 files of mechanical Black 26 reformat coupled to it - pyproject.toml diskcache extra change (kept the runtime mitigation in litellm/caching/disk_cache.py via JSONDisk) Kept: - Dockerfile cache narrowing (drops ~660 MB of uv build cache that surfaced cached setuptools as CVE findings) - litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872 - ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json: next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard) - tests/test_litellm/caching/test_disk_cache.py - tests/code_coverage_tests/liccheck.ini: harmless black authorization Black + gitpython + langchain dep upgrades will need a follow-up from a maintainer pushing a branch in the canonical BerriAI/litellm repo. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
63bda3f001
commit
5bafa8b3a2
@ -2,7 +2,6 @@ import asyncio
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
# Asynchronously fetch data from a given URL
|
# Asynchronously fetch data from a given URL
|
||||||
async def fetch_data(url):
|
async def fetch_data(url):
|
||||||
try:
|
try:
|
||||||
@ -16,24 +15,22 @@ async def fetch_data(url):
|
|||||||
resp_json = await resp.json()
|
resp_json = await resp.json()
|
||||||
print("Fetch the data from URL.")
|
print("Fetch the data from URL.")
|
||||||
# Return the 'data' field from the JSON response
|
# Return the 'data' field from the JSON response
|
||||||
return resp_json["data"]
|
return resp_json['data']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Print an error message if fetching data fails
|
# Print an error message if fetching data fails
|
||||||
print("Error fetching data from URL:", e)
|
print("Error fetching data from URL:", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Synchronize local data with remote data
|
# Synchronize local data with remote data
|
||||||
def sync_local_data_with_remote(local_data, remote_data):
|
def sync_local_data_with_remote(local_data, remote_data):
|
||||||
# Update existing keys in local_data with values from remote_data
|
# Update existing keys in local_data with values from remote_data
|
||||||
for key in set(local_data) & set(remote_data):
|
for key in (set(local_data) & set(remote_data)):
|
||||||
local_data[key].update(remote_data[key])
|
local_data[key].update(remote_data[key])
|
||||||
|
|
||||||
# Add new keys from remote_data to local_data
|
# Add new keys from remote_data to local_data
|
||||||
for key in set(remote_data) - set(local_data):
|
for key in (set(remote_data) - set(local_data)):
|
||||||
local_data[key] = remote_data[key]
|
local_data[key] = remote_data[key]
|
||||||
|
|
||||||
|
|
||||||
# Write data to the json file
|
# Write data to the json file
|
||||||
def write_to_file(file_path, data):
|
def write_to_file(file_path, data):
|
||||||
try:
|
try:
|
||||||
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
|
|||||||
# Print an error message if writing to file fails
|
# Print an error message if writing to file fails
|
||||||
print("Error updating JSON file:", e)
|
print("Error updating JSON file:", e)
|
||||||
|
|
||||||
|
|
||||||
# Update the existing models and add the missing models for OpenRouter
|
# Update the existing models and add the missing models for OpenRouter
|
||||||
def transform_openrouter_data(data):
|
def transform_openrouter_data(data):
|
||||||
transformed = {}
|
transformed = {}
|
||||||
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Add 'max_output_tokens' as a field if it is not None
|
# Add 'max_output_tokens' as a field if it is not None
|
||||||
if (
|
if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
|
||||||
"top_provider" in row
|
obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
|
||||||
and "max_completion_tokens" in row["top_provider"]
|
|
||||||
and row["top_provider"]["max_completion_tokens"] is not None
|
|
||||||
):
|
|
||||||
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
|
|
||||||
|
|
||||||
# Add the field 'output_cost_per_token'
|
# Add the field 'output_cost_per_token'
|
||||||
obj.update(
|
obj.update({
|
||||||
{
|
"output_cost_per_token": float(row["pricing"]["completion"]),
|
||||||
"output_cost_per_token": float(row["pricing"]["completion"]),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add field 'input_cost_per_image' if it exists and is non-zero
|
# Add field 'input_cost_per_image' if it exists and is non-zero
|
||||||
if (
|
if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
|
||||||
"pricing" in row
|
obj['input_cost_per_image'] = float(row["pricing"]["image"])
|
||||||
and "image" in row["pricing"]
|
|
||||||
and float(row["pricing"]["image"]) != 0.0
|
|
||||||
):
|
|
||||||
obj["input_cost_per_image"] = float(row["pricing"]["image"])
|
|
||||||
|
|
||||||
# Add the fields 'litellm_provider' and 'mode'
|
# Add the fields 'litellm_provider' and 'mode'
|
||||||
obj.update({"litellm_provider": "openrouter", "mode": "chat"})
|
obj.update({
|
||||||
|
"litellm_provider": "openrouter",
|
||||||
|
"mode": "chat"
|
||||||
|
})
|
||||||
|
|
||||||
# Add the 'supports_vision' field if the modality is 'multimodal'
|
# Add the 'supports_vision' field if the modality is 'multimodal'
|
||||||
if row.get("architecture", {}).get("modality") == "multimodal":
|
if row.get('architecture', {}).get('modality') == 'multimodal':
|
||||||
obj["supports_vision"] = True
|
obj['supports_vision'] = True
|
||||||
|
|
||||||
# Use a composite key to store the transformed object
|
# Use a composite key to store the transformed object
|
||||||
transformed[f'openrouter/{row["id"]}'] = obj
|
transformed[f'openrouter/{row["id"]}'] = obj
|
||||||
|
|
||||||
return transformed
|
return transformed
|
||||||
|
|
||||||
|
|
||||||
# Update the existing models and add the missing models for Vercel AI Gateway
|
# Update the existing models and add the missing models for Vercel AI Gateway
|
||||||
def transform_vercel_ai_gateway_data(data):
|
def transform_vercel_ai_gateway_data(data):
|
||||||
transformed = {}
|
transformed = {}
|
||||||
@ -101,30 +89,20 @@ def transform_vercel_ai_gateway_data(data):
|
|||||||
"max_tokens": row["context_window"],
|
"max_tokens": row["context_window"],
|
||||||
"input_cost_per_token": float(row["pricing"]["input"]),
|
"input_cost_per_token": float(row["pricing"]["input"]),
|
||||||
"output_cost_per_token": float(row["pricing"]["output"]),
|
"output_cost_per_token": float(row["pricing"]["output"]),
|
||||||
"max_output_tokens": row["max_tokens"],
|
'max_output_tokens': row['max_tokens'],
|
||||||
"max_input_tokens": row["context_window"],
|
'max_input_tokens': row["context_window"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle cache pricing if available
|
# Handle cache pricing if available
|
||||||
if "pricing" in row:
|
if "pricing" in row:
|
||||||
if (
|
if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
|
||||||
"input_cache_read" in row["pricing"]
|
obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
|
||||||
and row["pricing"]["input_cache_read"] is not None
|
|
||||||
):
|
if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
|
||||||
obj["cache_read_input_token_cost"] = float(
|
obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
|
||||||
f"{float(row['pricing']['input_cache_read']):e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
"input_cache_write" in row["pricing"]
|
|
||||||
and row["pricing"]["input_cache_write"] is not None
|
|
||||||
):
|
|
||||||
obj["cache_creation_input_token_cost"] = float(
|
|
||||||
f"{float(row['pricing']['input_cache_write']):e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
|
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
|
||||||
|
|
||||||
obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode})
|
obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode})
|
||||||
|
|
||||||
transformed[f'vercel_ai_gateway/{row["id"]}'] = obj
|
transformed[f'vercel_ai_gateway/{row["id"]}'] = obj
|
||||||
@ -148,31 +126,24 @@ def load_local_data(file_path):
|
|||||||
print("Error decoding JSON:", e)
|
print("Error decoding JSON:", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
local_file_path = (
|
local_file_path = "model_prices_and_context_window.json" # Path to the local data file
|
||||||
"model_prices_and_context_window.json" # Path to the local data file
|
openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
||||||
)
|
vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
||||||
openrouter_url = (
|
|
||||||
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
|
||||||
)
|
|
||||||
vercel_ai_gateway_url = (
|
|
||||||
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load local data from file
|
# Load local data from file
|
||||||
local_data = load_local_data(local_file_path)
|
local_data = load_local_data(local_file_path)
|
||||||
|
|
||||||
# Fetch OpenRouter data
|
# Fetch OpenRouter data
|
||||||
openrouter_data = asyncio.run(fetch_data(openrouter_url))
|
openrouter_data = asyncio.run(fetch_data(openrouter_url))
|
||||||
# Transform the fetched OpenRouter data
|
# Transform the fetched OpenRouter data
|
||||||
openrouter_data = transform_openrouter_data(openrouter_data)
|
openrouter_data = transform_openrouter_data(openrouter_data)
|
||||||
|
|
||||||
# Fetch Vercel AI Gateway data
|
# Fetch Vercel AI Gateway data
|
||||||
vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url))
|
vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url))
|
||||||
# Transform the fetched Vercel AI Gateway data
|
# Transform the fetched Vercel AI Gateway data
|
||||||
vercel_data = transform_vercel_ai_gateway_data(vercel_data)
|
vercel_data = transform_vercel_ai_gateway_data(vercel_data)
|
||||||
|
|
||||||
# Combine both datasets
|
# Combine both datasets
|
||||||
all_remote_data = {**openrouter_data, **vercel_data}
|
all_remote_data = {**openrouter_data, **vercel_data}
|
||||||
|
|
||||||
@ -183,7 +154,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
print("Failed to fetch model data from either local file or URL.")
|
print("Failed to fetch model data from either local file or URL.")
|
||||||
|
|
||||||
|
|
||||||
# Entry point of the script
|
# Entry point of the script
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
512
.github/workflows/run_llm_translation_tests.py
vendored
512
.github/workflows/run_llm_translation_tests.py
vendored
@ -16,75 +16,64 @@ from pathlib import Path
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List, Tuple, Optional
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
# ANSI color codes for terminal output
|
# ANSI color codes for terminal output
|
||||||
class Colors:
|
class Colors:
|
||||||
GREEN = "\033[92m"
|
GREEN = '\033[92m'
|
||||||
RED = "\033[91m"
|
RED = '\033[91m'
|
||||||
YELLOW = "\033[93m"
|
YELLOW = '\033[93m'
|
||||||
BLUE = "\033[94m"
|
BLUE = '\033[94m'
|
||||||
PURPLE = "\033[95m"
|
PURPLE = '\033[95m'
|
||||||
CYAN = "\033[96m"
|
CYAN = '\033[96m'
|
||||||
RESET = "\033[0m"
|
RESET = '\033[0m'
|
||||||
BOLD = "\033[1m"
|
BOLD = '\033[1m'
|
||||||
|
|
||||||
|
|
||||||
def print_colored(message: str, color: str = Colors.RESET):
|
def print_colored(message: str, color: str = Colors.RESET):
|
||||||
"""Print colored message to terminal"""
|
"""Print colored message to terminal"""
|
||||||
print(f"{color}{message}{Colors.RESET}")
|
print(f"{color}{message}{Colors.RESET}")
|
||||||
|
|
||||||
|
|
||||||
def get_provider_from_test_file(test_file: str) -> str:
|
def get_provider_from_test_file(test_file: str) -> str:
|
||||||
"""Map test file names to provider names"""
|
"""Map test file names to provider names"""
|
||||||
provider_mapping = {
|
provider_mapping = {
|
||||||
"test_anthropic": "Anthropic",
|
'test_anthropic': 'Anthropic',
|
||||||
"test_azure": "Azure",
|
'test_azure': 'Azure',
|
||||||
"test_bedrock": "AWS Bedrock",
|
'test_bedrock': 'AWS Bedrock',
|
||||||
"test_openai": "OpenAI",
|
'test_openai': 'OpenAI',
|
||||||
"test_vertex": "Google Vertex AI",
|
'test_vertex': 'Google Vertex AI',
|
||||||
"test_gemini": "Google Vertex AI",
|
'test_gemini': 'Google Vertex AI',
|
||||||
"test_cohere": "Cohere",
|
'test_cohere': 'Cohere',
|
||||||
"test_databricks": "Databricks",
|
'test_databricks': 'Databricks',
|
||||||
"test_groq": "Groq",
|
'test_groq': 'Groq',
|
||||||
"test_together": "Together AI",
|
'test_together': 'Together AI',
|
||||||
"test_mistral": "Mistral",
|
'test_mistral': 'Mistral',
|
||||||
"test_deepseek": "DeepSeek",
|
'test_deepseek': 'DeepSeek',
|
||||||
"test_replicate": "Replicate",
|
'test_replicate': 'Replicate',
|
||||||
"test_huggingface": "HuggingFace",
|
'test_huggingface': 'HuggingFace',
|
||||||
"test_fireworks": "Fireworks AI",
|
'test_fireworks': 'Fireworks AI',
|
||||||
"test_perplexity": "Perplexity",
|
'test_perplexity': 'Perplexity',
|
||||||
"test_cloudflare": "Cloudflare",
|
'test_cloudflare': 'Cloudflare',
|
||||||
"test_voyage": "Voyage AI",
|
'test_voyage': 'Voyage AI',
|
||||||
"test_xai": "xAI",
|
'test_xai': 'xAI',
|
||||||
"test_nvidia": "NVIDIA",
|
'test_nvidia': 'NVIDIA',
|
||||||
"test_watsonx": "IBM watsonx",
|
'test_watsonx': 'IBM watsonx',
|
||||||
"test_azure_ai": "Azure AI",
|
'test_azure_ai': 'Azure AI',
|
||||||
"test_snowflake": "Snowflake",
|
'test_snowflake': 'Snowflake',
|
||||||
"test_infinity": "Infinity",
|
'test_infinity': 'Infinity',
|
||||||
"test_jina": "Jina AI",
|
'test_jina': 'Jina AI',
|
||||||
"test_deepgram": "Deepgram",
|
'test_deepgram': 'Deepgram',
|
||||||
"test_clarifai": "Clarifai",
|
'test_clarifai': 'Clarifai',
|
||||||
"test_triton": "Triton",
|
'test_triton': 'Triton',
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, provider in provider_mapping.items():
|
for key, provider in provider_mapping.items():
|
||||||
if key in test_file:
|
if key in test_file:
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
# For cross-provider test files
|
# For cross-provider test files
|
||||||
if any(
|
if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory',
|
||||||
name in test_file
|
'test_router', 'test_text_completion']):
|
||||||
for name in [
|
return f'Cross-Provider Tests ({test_file})'
|
||||||
"test_optional_params",
|
|
||||||
"test_prompt_factory",
|
return 'Other Tests'
|
||||||
"test_router",
|
|
||||||
"test_text_completion",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return f"Cross-Provider Tests ({test_file})"
|
|
||||||
|
|
||||||
return "Other Tests"
|
|
||||||
|
|
||||||
|
|
||||||
def format_duration(seconds: float) -> str:
|
def format_duration(seconds: float) -> str:
|
||||||
"""Format duration in human-readable format"""
|
"""Format duration in human-readable format"""
|
||||||
@ -100,355 +89,290 @@ def format_duration(seconds: float) -> str:
|
|||||||
return f"{hours}h {minutes}m"
|
return f"{hours}h {minutes}m"
|
||||||
|
|
||||||
|
|
||||||
def generate_markdown_report(
|
def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None):
|
||||||
junit_xml_path: str, output_path: str, tag: str = None, commit: str = None
|
|
||||||
):
|
|
||||||
"""Generate a beautiful markdown report from JUnit XML"""
|
"""Generate a beautiful markdown report from JUnit XML"""
|
||||||
try:
|
try:
|
||||||
tree = ET.parse(junit_xml_path)
|
tree = ET.parse(junit_xml_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
# Handle both testsuite and testsuites root
|
# Handle both testsuite and testsuites root
|
||||||
if root.tag == "testsuites":
|
if root.tag == 'testsuites':
|
||||||
suites = root.findall("testsuite")
|
suites = root.findall('testsuite')
|
||||||
else:
|
else:
|
||||||
suites = [root]
|
suites = [root]
|
||||||
|
|
||||||
# Overall statistics
|
# Overall statistics
|
||||||
total_tests = 0
|
total_tests = 0
|
||||||
total_failures = 0
|
total_failures = 0
|
||||||
total_errors = 0
|
total_errors = 0
|
||||||
total_skipped = 0
|
total_skipped = 0
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
|
|
||||||
# Provider breakdown
|
# Provider breakdown
|
||||||
provider_stats = defaultdict(
|
provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0})
|
||||||
lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
|
|
||||||
)
|
|
||||||
provider_tests = defaultdict(list)
|
provider_tests = defaultdict(list)
|
||||||
|
|
||||||
for suite in suites:
|
for suite in suites:
|
||||||
total_tests += int(suite.get("tests", 0))
|
total_tests += int(suite.get('tests', 0))
|
||||||
total_failures += int(suite.get("failures", 0))
|
total_failures += int(suite.get('failures', 0))
|
||||||
total_errors += int(suite.get("errors", 0))
|
total_errors += int(suite.get('errors', 0))
|
||||||
total_skipped += int(suite.get("skipped", 0))
|
total_skipped += int(suite.get('skipped', 0))
|
||||||
total_time += float(suite.get("time", 0))
|
total_time += float(suite.get('time', 0))
|
||||||
|
|
||||||
for testcase in suite.findall("testcase"):
|
for testcase in suite.findall('testcase'):
|
||||||
classname = testcase.get("classname", "")
|
classname = testcase.get('classname', '')
|
||||||
test_name = testcase.get("name", "")
|
test_name = testcase.get('name', '')
|
||||||
test_time = float(testcase.get("time", 0))
|
test_time = float(testcase.get('time', 0))
|
||||||
|
|
||||||
# Extract test file name from classname
|
# Extract test file name from classname
|
||||||
if "." in classname:
|
if '.' in classname:
|
||||||
parts = classname.split(".")
|
parts = classname.split('.')
|
||||||
test_file = parts[-2] if len(parts) > 1 else "unknown"
|
test_file = parts[-2] if len(parts) > 1 else 'unknown'
|
||||||
else:
|
else:
|
||||||
test_file = "unknown"
|
test_file = 'unknown'
|
||||||
|
|
||||||
provider = get_provider_from_test_file(test_file)
|
provider = get_provider_from_test_file(test_file)
|
||||||
provider_stats[provider]["time"] += test_time
|
provider_stats[provider]['time'] += test_time
|
||||||
|
|
||||||
# Check test status
|
# Check test status
|
||||||
if testcase.find("failure") is not None:
|
if testcase.find('failure') is not None:
|
||||||
provider_stats[provider]["failed"] += 1
|
provider_stats[provider]['failed'] += 1
|
||||||
failure = testcase.find("failure")
|
failure = testcase.find('failure')
|
||||||
failure_msg = (
|
failure_msg = failure.get('message', '') if failure is not None else ''
|
||||||
failure.get("message", "") if failure is not None else ""
|
provider_tests[provider].append({
|
||||||
)
|
'name': test_name,
|
||||||
provider_tests[provider].append(
|
'status': 'FAILED',
|
||||||
{
|
'time': test_time,
|
||||||
"name": test_name,
|
'message': failure_msg
|
||||||
"status": "FAILED",
|
})
|
||||||
"time": test_time,
|
elif testcase.find('error') is not None:
|
||||||
"message": failure_msg,
|
provider_stats[provider]['errors'] += 1
|
||||||
}
|
error = testcase.find('error')
|
||||||
)
|
error_msg = error.get('message', '') if error is not None else ''
|
||||||
elif testcase.find("error") is not None:
|
provider_tests[provider].append({
|
||||||
provider_stats[provider]["errors"] += 1
|
'name': test_name,
|
||||||
error = testcase.find("error")
|
'status': 'ERROR',
|
||||||
error_msg = error.get("message", "") if error is not None else ""
|
'time': test_time,
|
||||||
provider_tests[provider].append(
|
'message': error_msg
|
||||||
{
|
})
|
||||||
"name": test_name,
|
elif testcase.find('skipped') is not None:
|
||||||
"status": "ERROR",
|
provider_stats[provider]['skipped'] += 1
|
||||||
"time": test_time,
|
skip = testcase.find('skipped')
|
||||||
"message": error_msg,
|
skip_msg = skip.get('message', '') if skip is not None else ''
|
||||||
}
|
provider_tests[provider].append({
|
||||||
)
|
'name': test_name,
|
||||||
elif testcase.find("skipped") is not None:
|
'status': 'SKIPPED',
|
||||||
provider_stats[provider]["skipped"] += 1
|
'time': test_time,
|
||||||
skip = testcase.find("skipped")
|
'message': skip_msg
|
||||||
skip_msg = skip.get("message", "") if skip is not None else ""
|
})
|
||||||
provider_tests[provider].append(
|
|
||||||
{
|
|
||||||
"name": test_name,
|
|
||||||
"status": "SKIPPED",
|
|
||||||
"time": test_time,
|
|
||||||
"message": skip_msg,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
provider_stats[provider]["passed"] += 1
|
provider_stats[provider]['passed'] += 1
|
||||||
provider_tests[provider].append(
|
provider_tests[provider].append({
|
||||||
{
|
'name': test_name,
|
||||||
"name": test_name,
|
'status': 'PASSED',
|
||||||
"status": "PASSED",
|
'time': test_time,
|
||||||
"time": test_time,
|
'message': ''
|
||||||
"message": "",
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
passed = total_tests - total_failures - total_errors - total_skipped
|
passed = total_tests - total_failures - total_errors - total_skipped
|
||||||
|
|
||||||
# Generate the markdown report
|
# Generate the markdown report
|
||||||
with open(output_path, "w") as f:
|
with open(output_path, 'w') as f:
|
||||||
# Header
|
# Header
|
||||||
f.write("# LLM Translation Test Results\n\n")
|
f.write("# LLM Translation Test Results\n\n")
|
||||||
|
|
||||||
# Metadata table
|
# Metadata table
|
||||||
f.write("## Test Run Information\n\n")
|
f.write("## Test Run Information\n\n")
|
||||||
f.write("| Field | Value |\n")
|
f.write("| Field | Value |\n")
|
||||||
f.write("|-------|-------|\n")
|
f.write("|-------|-------|\n")
|
||||||
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
|
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
|
||||||
f.write(
|
f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n")
|
||||||
f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n"
|
|
||||||
)
|
|
||||||
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
|
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
|
||||||
f.write(f"| **Duration** | {format_duration(total_time)} |\n")
|
f.write(f"| **Duration** | {format_duration(total_time)} |\n")
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
# Overall statistics with visual elements
|
# Overall statistics with visual elements
|
||||||
f.write("## Overall Statistics\n\n")
|
f.write("## Overall Statistics\n\n")
|
||||||
|
|
||||||
# Summary box
|
# Summary box
|
||||||
f.write("```\n")
|
f.write("```\n")
|
||||||
f.write(f"Total Tests: {total_tests}\n")
|
f.write(f"Total Tests: {total_tests}\n")
|
||||||
f.write(
|
f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||||
f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||||
)
|
f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||||
f.write(
|
f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||||
f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
|
||||||
)
|
|
||||||
f.write("```\n\n")
|
f.write("```\n\n")
|
||||||
|
|
||||||
|
|
||||||
# Provider summary table
|
# Provider summary table
|
||||||
f.write("## Results by Provider\n\n")
|
f.write("## Results by Provider\n\n")
|
||||||
f.write(
|
f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
|
||||||
"| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n"
|
f.write("|----------|-------|------|------|-------|------|-----------|----------|")
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
"|----------|-------|------|------|-------|------|-----------|----------|"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sort providers: specific providers first, then cross-provider tests
|
# Sort providers: specific providers first, then cross-provider tests
|
||||||
sorted_providers = []
|
sorted_providers = []
|
||||||
cross_provider = []
|
cross_provider = []
|
||||||
for p in sorted(provider_stats.keys()):
|
for p in sorted(provider_stats.keys()):
|
||||||
if "Cross-Provider" in p or p == "Other Tests":
|
if 'Cross-Provider' in p or p == 'Other Tests':
|
||||||
cross_provider.append(p)
|
cross_provider.append(p)
|
||||||
else:
|
else:
|
||||||
sorted_providers.append(p)
|
sorted_providers.append(p)
|
||||||
|
|
||||||
all_providers = sorted_providers + cross_provider
|
all_providers = sorted_providers + cross_provider
|
||||||
|
|
||||||
for provider in all_providers:
|
for provider in all_providers:
|
||||||
stats = provider_stats[provider]
|
stats = provider_stats[provider]
|
||||||
total = (
|
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||||
stats["passed"]
|
pass_rate = (stats['passed'] / total * 100) if total > 0 else 0
|
||||||
+ stats["failed"]
|
|
||||||
+ stats["errors"]
|
f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
|
||||||
+ stats["skipped"]
|
|
||||||
)
|
|
||||||
pass_rate = (stats["passed"] / total * 100) if total > 0 else 0
|
|
||||||
|
|
||||||
f.write(
|
|
||||||
f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | "
|
|
||||||
)
|
|
||||||
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
|
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
|
||||||
f.write(f"{format_duration(stats['time'])} |")
|
f.write(f"{format_duration(stats['time'])} |")
|
||||||
|
|
||||||
# Detailed test results by provider
|
# Detailed test results by provider
|
||||||
f.write("\n\n## Detailed Test Results\n\n")
|
f.write("\n\n## Detailed Test Results\n\n")
|
||||||
|
|
||||||
for provider in sorted_providers:
|
for provider in sorted_providers:
|
||||||
if provider_tests[provider]:
|
if provider_tests[provider]:
|
||||||
stats = provider_stats[provider]
|
stats = provider_stats[provider]
|
||||||
total = (
|
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||||
stats["passed"]
|
|
||||||
+ stats["failed"]
|
|
||||||
+ stats["errors"]
|
|
||||||
+ stats["skipped"]
|
|
||||||
)
|
|
||||||
|
|
||||||
f.write(f"### {provider}\n\n")
|
f.write(f"### {provider}\n\n")
|
||||||
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
||||||
f.write(
|
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ")
|
||||||
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) "
|
|
||||||
)
|
|
||||||
f.write(f"in {format_duration(stats['time'])}\n\n")
|
f.write(f"in {format_duration(stats['time'])}\n\n")
|
||||||
|
|
||||||
# Group tests by status
|
# Group tests by status
|
||||||
tests_by_status = defaultdict(list)
|
tests_by_status = defaultdict(list)
|
||||||
for test in provider_tests[provider]:
|
for test in provider_tests[provider]:
|
||||||
tests_by_status[test["status"]].append(test)
|
tests_by_status[test['status']].append(test)
|
||||||
|
|
||||||
# Show failed tests first (if any)
|
# Show failed tests first (if any)
|
||||||
if tests_by_status["FAILED"]:
|
if tests_by_status['FAILED']:
|
||||||
f.write("<details>\n<summary>Failed Tests</summary>\n\n")
|
f.write("<details>\n<summary>Failed Tests</summary>\n\n")
|
||||||
for test in tests_by_status["FAILED"]:
|
for test in tests_by_status['FAILED']:
|
||||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||||
if test["message"]:
|
if test['message']:
|
||||||
# Truncate long error messages
|
# Truncate long error messages
|
||||||
msg = (
|
msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message']
|
||||||
test["message"][:200] + "..."
|
|
||||||
if len(test["message"]) > 200
|
|
||||||
else test["message"]
|
|
||||||
)
|
|
||||||
f.write(f" > {msg}\n")
|
f.write(f" > {msg}\n")
|
||||||
f.write("\n</details>\n\n")
|
f.write("\n</details>\n\n")
|
||||||
|
|
||||||
# Show errors (if any)
|
# Show errors (if any)
|
||||||
if tests_by_status["ERROR"]:
|
if tests_by_status['ERROR']:
|
||||||
f.write("<details>\n<summary>Error Tests</summary>\n\n")
|
f.write("<details>\n<summary>Error Tests</summary>\n\n")
|
||||||
for test in tests_by_status["ERROR"]:
|
for test in tests_by_status['ERROR']:
|
||||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||||
f.write("\n</details>\n\n")
|
f.write("\n</details>\n\n")
|
||||||
|
|
||||||
# Show passed tests in collapsible section
|
# Show passed tests in collapsible section
|
||||||
if tests_by_status["PASSED"]:
|
if tests_by_status['PASSED']:
|
||||||
f.write("<details>\n<summary>Passed Tests</summary>\n\n")
|
f.write("<details>\n<summary>Passed Tests</summary>\n\n")
|
||||||
for test in tests_by_status["PASSED"]:
|
for test in tests_by_status['PASSED']:
|
||||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||||
f.write("\n</details>\n\n")
|
f.write("\n</details>\n\n")
|
||||||
|
|
||||||
# Show skipped tests (if any)
|
# Show skipped tests (if any)
|
||||||
if tests_by_status["SKIPPED"]:
|
if tests_by_status['SKIPPED']:
|
||||||
f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
|
f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
|
||||||
for test in tests_by_status["SKIPPED"]:
|
for test in tests_by_status['SKIPPED']:
|
||||||
f.write(f"- `{test['name']}`\n")
|
f.write(f"- `{test['name']}`\n")
|
||||||
f.write("\n</details>\n\n")
|
f.write("\n</details>\n\n")
|
||||||
|
|
||||||
# Cross-provider tests in a separate section
|
# Cross-provider tests in a separate section
|
||||||
if cross_provider:
|
if cross_provider:
|
||||||
f.write("### Cross-Provider Tests\n\n")
|
f.write("### Cross-Provider Tests\n\n")
|
||||||
for provider in cross_provider:
|
for provider in cross_provider:
|
||||||
if provider_tests[provider]:
|
if provider_tests[provider]:
|
||||||
stats = provider_stats[provider]
|
stats = provider_stats[provider]
|
||||||
total = (
|
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||||
stats["passed"]
|
|
||||||
+ stats["failed"]
|
|
||||||
+ stats["errors"]
|
|
||||||
+ stats["skipped"]
|
|
||||||
)
|
|
||||||
|
|
||||||
f.write(f"#### {provider}\n\n")
|
f.write(f"#### {provider}\n\n")
|
||||||
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
||||||
f.write(
|
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n")
|
||||||
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# For cross-provider tests, just show counts
|
# For cross-provider tests, just show counts
|
||||||
f.write(f"- Passed: {stats['passed']}\n")
|
f.write(f"- Passed: {stats['passed']}\n")
|
||||||
if stats["failed"] > 0:
|
if stats['failed'] > 0:
|
||||||
f.write(f"- Failed: {stats['failed']}\n")
|
f.write(f"- Failed: {stats['failed']}\n")
|
||||||
if stats["errors"] > 0:
|
if stats['errors'] > 0:
|
||||||
f.write(f"- Errors: {stats['errors']}\n")
|
f.write(f"- Errors: {stats['errors']}\n")
|
||||||
if stats["skipped"] > 0:
|
if stats['skipped'] > 0:
|
||||||
f.write(f"- Skipped: {stats['skipped']}\n")
|
f.write(f"- Skipped: {stats['skipped']}\n")
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
print_colored(f"Report generated: {output_path}", Colors.GREEN)
|
print_colored(f"Report generated: {output_path}", Colors.GREEN)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_colored(f"Error generating report: {e}", Colors.RED)
|
print_colored(f"Error generating report: {e}", Colors.RED)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def run_tests(test_path: str = "tests/llm_translation/",
|
||||||
def run_tests(
|
junit_xml: str = "test-results/junit.xml",
|
||||||
test_path: str = "tests/llm_translation/",
|
report_path: str = "test-results/llm_translation_report.md",
|
||||||
junit_xml: str = "test-results/junit.xml",
|
tag: str = None,
|
||||||
report_path: str = "test-results/llm_translation_report.md",
|
commit: str = None) -> int:
|
||||||
tag: str = None,
|
|
||||||
commit: str = None,
|
|
||||||
) -> int:
|
|
||||||
"""Run the LLM translation tests and generate report"""
|
"""Run the LLM translation tests and generate report"""
|
||||||
|
|
||||||
# Create test results directory
|
# Create test results directory
|
||||||
os.makedirs(os.path.dirname(junit_xml), exist_ok=True)
|
os.makedirs(os.path.dirname(junit_xml), exist_ok=True)
|
||||||
|
|
||||||
print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE)
|
print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE)
|
||||||
print_colored(f"Test directory: {test_path}", Colors.CYAN)
|
print_colored(f"Test directory: {test_path}", Colors.CYAN)
|
||||||
print_colored(f"Output: {junit_xml}", Colors.CYAN)
|
print_colored(f"Output: {junit_xml}", Colors.CYAN)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Run pytest
|
# Run pytest
|
||||||
cmd = [
|
cmd = [
|
||||||
"uv",
|
"uv", "run", "--no-sync", "pytest", test_path,
|
||||||
"run",
|
|
||||||
"--no-sync",
|
|
||||||
"pytest",
|
|
||||||
test_path,
|
|
||||||
f"--junitxml={junit_xml}",
|
f"--junitxml={junit_xml}",
|
||||||
"-v",
|
"-v",
|
||||||
"--tb=short",
|
"--tb=short",
|
||||||
"--maxfail=500",
|
"--maxfail=500",
|
||||||
"-n",
|
"-n", "auto"
|
||||||
"auto",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add timeout if pytest-timeout is installed
|
# Add timeout if pytest-timeout is installed
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
||||||
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
capture_output=True, check=True)
|
||||||
capture_output=True,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
cmd.extend(["--timeout=300"])
|
cmd.extend(["--timeout=300"])
|
||||||
except:
|
except:
|
||||||
print_colored(
|
print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
|
||||||
"Warning: pytest-timeout not installed, skipping timeout option",
|
|
||||||
Colors.YELLOW,
|
|
||||||
)
|
|
||||||
|
|
||||||
print_colored("Running pytest with command:", Colors.YELLOW)
|
print_colored("Running pytest with command:", Colors.YELLOW)
|
||||||
print(f" {' '.join(cmd)}")
|
print(f" {' '.join(cmd)}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Run the tests
|
# Run the tests
|
||||||
result = subprocess.run(cmd, capture_output=False)
|
result = subprocess.run(cmd, capture_output=False)
|
||||||
|
|
||||||
# Generate the report regardless of test outcome
|
# Generate the report regardless of test outcome
|
||||||
if os.path.exists(junit_xml):
|
if os.path.exists(junit_xml):
|
||||||
print()
|
print()
|
||||||
print_colored("Generating test report...", Colors.BLUE)
|
print_colored("Generating test report...", Colors.BLUE)
|
||||||
generate_markdown_report(junit_xml, report_path, tag, commit)
|
generate_markdown_report(junit_xml, report_path, tag, commit)
|
||||||
|
|
||||||
# Print summary to console
|
# Print summary to console
|
||||||
print()
|
print()
|
||||||
print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE)
|
print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE)
|
||||||
|
|
||||||
# Parse XML for quick summary
|
# Parse XML for quick summary
|
||||||
tree = ET.parse(junit_xml)
|
tree = ET.parse(junit_xml)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
if root.tag == "testsuites":
|
if root.tag == 'testsuites':
|
||||||
suites = root.findall("testsuite")
|
suites = root.findall('testsuite')
|
||||||
else:
|
else:
|
||||||
suites = [root]
|
suites = [root]
|
||||||
|
|
||||||
total = sum(int(s.get("tests", 0)) for s in suites)
|
total = sum(int(s.get('tests', 0)) for s in suites)
|
||||||
failures = sum(int(s.get("failures", 0)) for s in suites)
|
failures = sum(int(s.get('failures', 0)) for s in suites)
|
||||||
errors = sum(int(s.get("errors", 0)) for s in suites)
|
errors = sum(int(s.get('errors', 0)) for s in suites)
|
||||||
skipped = sum(int(s.get("skipped", 0)) for s in suites)
|
skipped = sum(int(s.get('skipped', 0)) for s in suites)
|
||||||
passed = total - failures - errors - skipped
|
passed = total - failures - errors - skipped
|
||||||
|
|
||||||
print(f" Total: {total}")
|
print(f" Total: {total}")
|
||||||
print_colored(f" Passed: {passed}", Colors.GREEN)
|
print_colored(f" Passed: {passed}", Colors.GREEN)
|
||||||
if failures > 0:
|
if failures > 0:
|
||||||
@ -457,75 +381,59 @@ def run_tests(
|
|||||||
print_colored(f" Errors: {errors}", Colors.RED)
|
print_colored(f" Errors: {errors}", Colors.RED)
|
||||||
if skipped > 0:
|
if skipped > 0:
|
||||||
print_colored(f" Skipped: {skipped}", Colors.YELLOW)
|
print_colored(f" Skipped: {skipped}", Colors.YELLOW)
|
||||||
|
|
||||||
if total > 0:
|
if total > 0:
|
||||||
pass_rate = (passed / total) * 100
|
pass_rate = (passed / total) * 100
|
||||||
color = (
|
color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
||||||
Colors.GREEN
|
|
||||||
if pass_rate >= 80
|
|
||||||
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
|
||||||
)
|
|
||||||
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
|
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
|
||||||
else:
|
else:
|
||||||
print_colored("No test results found!", Colors.RED)
|
print_colored("No test results found!", Colors.RED)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print_colored("Test run complete!", Colors.BOLD + Colors.GREEN)
|
print_colored("Test run complete!", Colors.BOLD + Colors.GREEN)
|
||||||
|
|
||||||
return result.returncode
|
return result.returncode
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
|
parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
|
||||||
parser.add_argument(
|
parser.add_argument("--test-path", default="tests/llm_translation/",
|
||||||
"--test-path", default="tests/llm_translation/", help="Path to test directory"
|
help="Path to test directory")
|
||||||
)
|
parser.add_argument("--junit-xml", default="test-results/junit.xml",
|
||||||
parser.add_argument(
|
help="Path for JUnit XML output")
|
||||||
"--junit-xml",
|
parser.add_argument("--report", default="test-results/llm_translation_report.md",
|
||||||
default="test-results/junit.xml",
|
help="Path for markdown report")
|
||||||
help="Path for JUnit XML output",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--report",
|
|
||||||
default="test-results/llm_translation_report.md",
|
|
||||||
help="Path for markdown report",
|
|
||||||
)
|
|
||||||
parser.add_argument("--tag", help="Git tag or version")
|
parser.add_argument("--tag", help="Git tag or version")
|
||||||
parser.add_argument("--commit", help="Git commit SHA")
|
parser.add_argument("--commit", help="Git commit SHA")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Get git info if not provided
|
# Get git info if not provided
|
||||||
if not args.commit:
|
if not args.commit:
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(["git", "rev-parse", "HEAD"],
|
||||||
["git", "rev-parse", "HEAD"], capture_output=True, text=True
|
capture_output=True, text=True)
|
||||||
)
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
args.commit = result.stdout.strip()
|
args.commit = result.stdout.strip()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not args.tag:
|
if not args.tag:
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"],
|
||||||
["git", "describe", "--tags", "--abbrev=0"],
|
capture_output=True, text=True)
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
args.tag = result.stdout.strip()
|
args.tag = result.stdout.strip()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
exit_code = run_tests(
|
exit_code = run_tests(
|
||||||
test_path=args.test_path,
|
test_path=args.test_path,
|
||||||
junit_xml=args.junit_xml,
|
junit_xml=args.junit_xml,
|
||||||
report_path=args.report,
|
report_path=args.report,
|
||||||
tag=args.tag,
|
tag=args.tag,
|
||||||
commit=args.commit,
|
commit=args.commit
|
||||||
)
|
)
|
||||||
|
|
||||||
sys.exit(exit_code)
|
sys.exit(exit_code)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import testing.postgresql
|
import testing.postgresql
|
||||||
|
|
||||||
|
|
||||||
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
|
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
|
||||||
DEFAULT_BASE_BRANCH = "litellm_internal_staging"
|
DEFAULT_BASE_BRANCH = "litellm_internal_staging"
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from tabulate import tabulate
|
|||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
# Define the list of models to benchmark
|
# Define the list of models to benchmark
|
||||||
# select any LLM listed here: https://docs.litellm.ai/docs/providers
|
# select any LLM listed here: https://docs.litellm.ai/docs/providers
|
||||||
models = ["gpt-3.5-turbo", "claude-2"]
|
models = ["gpt-3.5-turbo", "claude-2"]
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Braintrust Prompt Wrapper",
|
title="Braintrust Prompt Wrapper",
|
||||||
description="Wrapper server for Braintrust prompts to work with LiteLLM",
|
description="Wrapper server for Braintrust prompts to work with LiteLLM",
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway
|
Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway
|
||||||
|
|
||||||
This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy.
|
This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy.
|
||||||
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
|
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
|
||||||
and Azure realtime APIs without changing your agent code.
|
and Azure realtime APIs without changing your agent code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
LiteLLM Migration Script!
|
LiteLLM Migration Script!
|
||||||
|
|
||||||
Takes a config.yaml and calls /model/new
|
Takes a config.yaml and calls /model/new
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- File path to config.yaml
|
- File path to config.yaml
|
||||||
- Proxy base url to your hosted proxy
|
- Proxy base url to your hosted proxy
|
||||||
|
|
||||||
Step 1: Reads your config.yaml
|
Step 1: Reads your config.yaml
|
||||||
Step 2: reads `model_list` and loops through all models
|
Step 2: reads `model_list` and loops through all models
|
||||||
Step 3: calls `<proxy-base-url>/model/new` for each model
|
Step 3: calls `<proxy-base-url>/model/new` for each model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -518,7 +518,8 @@ if __name__ == "__main__":
|
|||||||
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
|
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("\nExample curl command:")
|
print("\nExample curl command:")
|
||||||
print(f"""
|
print(
|
||||||
|
f"""
|
||||||
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
||||||
-H "Authorization: Bearer {bearer_token}" \\
|
-H "Authorization: Bearer {bearer_token}" \\
|
||||||
-H "Content-Type: application/json" \\
|
-H "Content-Type: application/json" \\
|
||||||
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
|||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}'
|
}}'
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
uvicorn.run(app, host=host, port=port)
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|||||||
@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
|
|||||||
print("LiteLLM_VerificationTokenView Exists!") # noqa
|
print("LiteLLM_VerificationTokenView Exists!") # noqa
|
||||||
except Exception:
|
except Exception:
|
||||||
# If an error occurs, the view does not exist, so create it
|
# If an error occurs, the view does not exist, so create it
|
||||||
await db.execute_raw("""
|
await db.execute_raw(
|
||||||
|
"""
|
||||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||||
SELECT
|
SELECT
|
||||||
v.*,
|
v.*,
|
||||||
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
|
|||||||
t.rpm_limit AS team_rpm_limit
|
t.rpm_limit AS team_rpm_limit
|
||||||
FROM "LiteLLM_VerificationToken" v
|
FROM "LiteLLM_VerificationToken" v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
print("LiteLLM_VerificationTokenView Created!") # noqa
|
print("LiteLLM_VerificationTokenView Created!") # noqa
|
||||||
|
|
||||||
|
|||||||
@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
|
|||||||
class EnterpriseCallbackControls:
|
class EnterpriseCallbackControls:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_callback_disabled_dynamically(
|
def is_callback_disabled_dynamically(
|
||||||
callback: litellm.CALLBACK_TYPES,
|
callback: litellm.CALLBACK_TYPES,
|
||||||
litellm_params: dict,
|
litellm_params: dict,
|
||||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
standard_callback_dynamic_params: StandardCallbackDynamicParams
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||||
litellm_params: Parameters containing proxy server request info
|
litellm_params: Parameters containing proxy server request info
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the callback should be disabled, False otherwise
|
bool: True if the callback should be disabled, False otherwise
|
||||||
"""
|
"""
|
||||||
from litellm.litellm_core_utils.custom_logger_registry import (
|
from litellm.litellm_core_utils.custom_logger_registry import (
|
||||||
CustomLoggerRegistry,
|
CustomLoggerRegistry,
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
|
|
||||||
litellm_params, standard_callback_dynamic_params
|
|
||||||
)
|
)
|
||||||
verbose_logger.debug(
|
|
||||||
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
|
|
||||||
)
|
|
||||||
verbose_logger.debug(
|
|
||||||
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
|
|
||||||
)
|
|
||||||
if disabled_callbacks is not None:
|
|
||||||
#########################################################
|
|
||||||
# premium user check
|
|
||||||
#########################################################
|
|
||||||
if (
|
|
||||||
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
#########################################################
|
|
||||||
if isinstance(callback, str):
|
|
||||||
if callback.lower() in disabled_callbacks:
|
|
||||||
verbose_logger.debug(
|
|
||||||
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
elif isinstance(callback, CustomLogger):
|
|
||||||
# get the string name of the callback
|
|
||||||
callback_str = (
|
|
||||||
CustomLoggerRegistry.get_callback_str_from_class_type(
|
|
||||||
callback.__class__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
callback_str is not None
|
|
||||||
and callback_str.lower() in disabled_callbacks
|
|
||||||
):
|
|
||||||
verbose_logger.debug(
|
|
||||||
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
|
||||||
|
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
|
||||||
|
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
|
||||||
|
if disabled_callbacks is not None:
|
||||||
|
#########################################################
|
||||||
|
# premium user check
|
||||||
|
#########################################################
|
||||||
|
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
|
||||||
|
return False
|
||||||
|
#########################################################
|
||||||
|
if isinstance(callback, str):
|
||||||
|
if callback.lower() in disabled_callbacks:
|
||||||
|
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||||
|
return True
|
||||||
|
elif isinstance(callback, CustomLogger):
|
||||||
|
# get the string name of the callback
|
||||||
|
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
|
||||||
|
if callback_str is not None and callback_str.lower() in disabled_callbacks:
|
||||||
|
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Error checking disabled callbacks header: {str(e)}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_disabled_callbacks(
|
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
|
||||||
litellm_params: dict,
|
|
||||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
|
||||||
) -> Optional[List[str]]:
|
|
||||||
"""
|
"""
|
||||||
Get the disabled callbacks from the standard callback dynamic params.
|
Get the disabled callbacks from the standard callback dynamic params.
|
||||||
"""
|
"""
|
||||||
@ -92,24 +71,18 @@ class EnterpriseCallbackControls:
|
|||||||
request_headers = get_proxy_server_request_headers(litellm_params)
|
request_headers = get_proxy_server_request_headers(litellm_params)
|
||||||
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
||||||
if disabled_callbacks is not None:
|
if disabled_callbacks is not None:
|
||||||
disabled_callbacks = set(
|
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
|
||||||
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
|
|
||||||
)
|
|
||||||
return list(disabled_callbacks)
|
return list(disabled_callbacks)
|
||||||
|
|
||||||
|
|
||||||
#########################################################
|
#########################################################
|
||||||
# check if disabled via request body
|
# check if disabled via request body
|
||||||
#########################################################
|
#########################################################
|
||||||
if (
|
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
|
||||||
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||||
is not None
|
|
||||||
):
|
|
||||||
return standard_callback_dynamic_params.get(
|
|
||||||
"litellm_disabled_callbacks", None
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_allow_dynamic_callback_disabling():
|
def _should_allow_dynamic_callback_disabling():
|
||||||
import litellm
|
import litellm
|
||||||
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
|
|||||||
|
|
||||||
# Check if admin has disabled this feature
|
# Check if admin has disabled this feature
|
||||||
if litellm.allow_dynamic_callback_disabling is not True:
|
if litellm.allow_dynamic_callback_disabling is not True:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
|
||||||
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if premium_user:
|
if premium_user:
|
||||||
return True
|
return True
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
|
||||||
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
|
return False
|
||||||
)
|
|
||||||
return False
|
|
||||||
@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate percentage and alert threshold
|
# Calculate percentage and alert threshold
|
||||||
percentage = (
|
percentage = threshold_pct if threshold_pct is not None else int(
|
||||||
threshold_pct
|
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
|
||||||
if threshold_pct is not None
|
|
||||||
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
|
|
||||||
)
|
)
|
||||||
threshold_fraction = percentage / 100.0
|
threshold_fraction = percentage / 100.0
|
||||||
alert_threshold_str = (
|
alert_threshold_str = (
|
||||||
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
_id = user_info.token or user_info.user_id or "default_id"
|
_id = user_info.token or user_info.user_id or "default_id"
|
||||||
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
_cache_key = (
|
||||||
|
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
||||||
|
)
|
||||||
|
|
||||||
result = await _cache.async_get_cache(key=_cache_key)
|
result = await _cache.async_get_cache(key=_cache_key)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
|
|||||||
continue
|
continue
|
||||||
recipient_emails = list(set(emails))
|
recipient_emails = list(set(emails))
|
||||||
|
|
||||||
event_message = (
|
event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
||||||
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
|
||||||
)
|
|
||||||
webhook_event = WebhookEvent(
|
webhook_event = WebhookEvent(
|
||||||
event="max_budget_alert",
|
event="max_budget_alert",
|
||||||
event_message=event_message,
|
event_message=event_message,
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||||||
|
|
||||||
from .base_email import BaseEmailLogger
|
from .base_email import BaseEmailLogger
|
||||||
|
|
||||||
|
|
||||||
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
|
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
|
||||||
|
|
||||||
|
|
||||||
@ -78,4 +79,4 @@ class SendGridEmailLogger(BaseEmailLogger):
|
|||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"SendGrid response status={response.status_code}, body={response.text}"
|
f"SendGrid response status={response.status_code}, body={response.text}"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
This is the litellm SMTP email integration
|
This is the litellm SMTP email integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Enterprise specific logging utils
|
Enterprise specific logging utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata
|
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -153,11 +153,11 @@ async def get_audit_logs(
|
|||||||
|
|
||||||
# Return paginated response
|
# Return paginated response
|
||||||
return PaginatedAuditLogResponse(
|
return PaginatedAuditLogResponse(
|
||||||
audit_logs=(
|
audit_logs=[
|
||||||
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs]
|
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
|
||||||
if audit_logs
|
]
|
||||||
else []
|
if audit_logs
|
||||||
),
|
else [],
|
||||||
total=total_count,
|
total=total_count,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
|
|||||||
@ -7,4 +7,4 @@ including custom SSO handlers and advanced authentication features.
|
|||||||
|
|
||||||
from .custom_sso_handler import EnterpriseCustomSSOHandler
|
from .custom_sso_handler import EnterpriseCustomSSOHandler
|
||||||
|
|
||||||
__all__ = ["EnterpriseCustomSSOHandler"]
|
__all__ = ["EnterpriseCustomSSOHandler"]
|
||||||
@ -53,9 +53,7 @@ class CheckBatchCost:
|
|||||||
"user_api_key_alias": getattr(user_row, "user_alias", None),
|
"user_api_key_alias": getattr(user_row, "user_alias", None),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
|
||||||
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
|
|
||||||
)
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _cleanup_stale_managed_objects(self) -> None:
|
async def _cleanup_stale_managed_objects(self) -> None:
|
||||||
@ -64,22 +62,11 @@ class CheckBatchCost:
|
|||||||
in non-terminal states as 'stale_expired'. These will never complete and
|
in non-terminal states as 'stale_expired'. These will never complete and
|
||||||
should not be polled.
|
should not be polled.
|
||||||
"""
|
"""
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
|
||||||
)
|
|
||||||
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
|
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
|
||||||
where={
|
where={
|
||||||
"file_purpose": "batch",
|
"file_purpose": "batch",
|
||||||
"status": {
|
"status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
|
||||||
"not_in": [
|
|
||||||
"completed",
|
|
||||||
"complete",
|
|
||||||
"failed",
|
|
||||||
"expired",
|
|
||||||
"cancelled",
|
|
||||||
"stale_expired",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"created_at": {"lt": cutoff},
|
"created_at": {"lt": cutoff},
|
||||||
},
|
},
|
||||||
data={"status": "stale_expired"},
|
data={"status": "stale_expired"},
|
||||||
@ -133,12 +120,9 @@ class CheckBatchCost:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm.integrations.prometheus import PrometheusLogger
|
from litellm.integrations.prometheus import PrometheusLogger
|
||||||
|
|
||||||
prom_logger = PrometheusLogger.get_instance()
|
prom_logger = PrometheusLogger.get_instance()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
|
||||||
f"CheckBatchCost: could not get Prometheus logger: {e}"
|
|
||||||
)
|
|
||||||
prom_logger = None
|
prom_logger = None
|
||||||
|
|
||||||
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
|
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
|
||||||
@ -177,11 +161,7 @@ class CheckBatchCost:
|
|||||||
order={"created_at": "asc"},
|
order={"created_at": "asc"},
|
||||||
)
|
)
|
||||||
except Exception as query_err:
|
except Exception as query_err:
|
||||||
if (
|
if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
|
||||||
"batch_processed" not in str(query_err).lower()
|
|
||||||
and "unknown column" not in str(query_err).lower()
|
|
||||||
and "does not exist" not in str(query_err).lower()
|
|
||||||
):
|
|
||||||
raise
|
raise
|
||||||
# Permanent schema gap — cache the result so future cycles skip straight to fallback
|
# Permanent schema gap — cache the result so future cycles skip straight to fallback
|
||||||
self._has_batch_processed_column = False
|
self._has_batch_processed_column = False
|
||||||
@ -236,13 +216,14 @@ class CheckBatchCost:
|
|||||||
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
|
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
|
||||||
)
|
)
|
||||||
if prom_logger:
|
if prom_logger:
|
||||||
prom_logger.record_check_batch_cost_error(
|
prom_logger.record_check_batch_cost_error("provider_retrieval_error")
|
||||||
"provider_retrieval_error"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
## RETRIEVE THE BATCH JOB OUTPUT FILE
|
## RETRIEVE THE BATCH JOB OUTPUT FILE
|
||||||
if response.status == "completed" and response.output_file_id is not None:
|
if (
|
||||||
|
response.status == "completed"
|
||||||
|
and response.output_file_id is not None
|
||||||
|
):
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Batch ID: {batch_id} is complete, tracking cost and usage"
|
f"Batch ID: {batch_id} is complete, tracking cost and usage"
|
||||||
)
|
)
|
||||||
@ -269,25 +250,20 @@ class CheckBatchCost:
|
|||||||
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
|
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
|
||||||
if decoded:
|
if decoded:
|
||||||
try:
|
try:
|
||||||
raw_output_file_id = decoded.split("llm_output_file_id,")[
|
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
|
||||||
1
|
|
||||||
].split(";")[0]
|
|
||||||
except (IndexError, AttributeError):
|
except (IndexError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
credentials = (
|
credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
|
||||||
self.llm_router.get_deployment_credentials_with_provider(model_id)
|
|
||||||
or {}
|
|
||||||
)
|
|
||||||
_file_content = await afile_content(
|
_file_content = await afile_content(
|
||||||
file_id=raw_output_file_id,
|
file_id=raw_output_file_id,
|
||||||
**credentials,
|
**credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Access content - handle both direct attribute and method call
|
# Access content - handle both direct attribute and method call
|
||||||
if hasattr(_file_content, "content"):
|
if hasattr(_file_content, 'content'):
|
||||||
content_bytes = _file_content.content # type: ignore[union-attr]
|
content_bytes = _file_content.content # type: ignore[union-attr]
|
||||||
elif hasattr(_file_content, "read"):
|
elif hasattr(_file_content, 'read'):
|
||||||
content_bytes = await _file_content.read() # type: ignore[misc]
|
content_bytes = await _file_content.read() # type: ignore[misc]
|
||||||
else:
|
else:
|
||||||
content_bytes = _file_content # type: ignore[assignment]
|
content_bytes = _file_content # type: ignore[assignment]
|
||||||
@ -314,9 +290,7 @@ class CheckBatchCost:
|
|||||||
f"Skipping job {unified_object_id} because it is not a valid deployment info"
|
f"Skipping job {unified_object_id} because it is not a valid deployment info"
|
||||||
)
|
)
|
||||||
if prom_logger:
|
if prom_logger:
|
||||||
prom_logger.record_check_batch_cost_error(
|
prom_logger.record_check_batch_cost_error("deployment_not_found")
|
||||||
"deployment_not_found"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
|
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
|
||||||
litellm_model_name = deployment_info.litellm_params.model
|
litellm_model_name = deployment_info.litellm_params.model
|
||||||
@ -328,11 +302,7 @@ class CheckBatchCost:
|
|||||||
|
|
||||||
# Pass deployment model_info so custom batch pricing
|
# Pass deployment model_info so custom batch pricing
|
||||||
# (input_cost_per_token_batches etc.) is used for cost calc
|
# (input_cost_per_token_batches etc.) is used for cost calc
|
||||||
deployment_model_info = (
|
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
|
||||||
deployment_info.model_info.model_dump()
|
|
||||||
if deployment_info.model_info
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
batch_cost, batch_usage, batch_models = (
|
batch_cost, batch_usage, batch_models = (
|
||||||
await calculate_batch_cost_and_usage(
|
await calculate_batch_cost_and_usage(
|
||||||
file_content_dictionary=file_content_as_dict,
|
file_content_dictionary=file_content_as_dict,
|
||||||
@ -379,9 +349,7 @@ class CheckBatchCost:
|
|||||||
|
|
||||||
# Record batch duration (completed_at - created_at)
|
# Record batch duration (completed_at - created_at)
|
||||||
if prom_logger and response.completed_at and response.created_at:
|
if prom_logger and response.completed_at and response.created_at:
|
||||||
duration_seconds = float(
|
duration_seconds = float(response.completed_at - response.created_at)
|
||||||
response.completed_at - response.created_at
|
|
||||||
)
|
|
||||||
if duration_seconds >= 0:
|
if duration_seconds >= 0:
|
||||||
prom_logger.record_managed_batch_duration(
|
prom_logger.record_managed_batch_duration(
|
||||||
duration_seconds=duration_seconds,
|
duration_seconds=duration_seconds,
|
||||||
@ -390,9 +358,7 @@ class CheckBatchCost:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Track this job for the final metrics summary
|
# Track this job for the final metrics summary
|
||||||
processed_models.append(
|
processed_models.append((model_name, str(llm_provider) if llm_provider else None))
|
||||||
(model_name, str(llm_provider) if llm_provider else None)
|
|
||||||
)
|
|
||||||
|
|
||||||
# mark the job as complete
|
# mark the job as complete
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -33,7 +33,9 @@ class CheckResponsesCost:
|
|||||||
self.prisma_client: PrismaClient = prisma_client
|
self.prisma_client: PrismaClient = prisma_client
|
||||||
self.llm_router: Router = llm_router
|
self.llm_router: Router = llm_router
|
||||||
|
|
||||||
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int:
|
async def _expire_stale_rows(
|
||||||
|
self, cutoff: datetime, batch_size: int
|
||||||
|
) -> int:
|
||||||
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
|
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
|
||||||
|
|
||||||
Isolated so it can be swapped / mocked in tests without touching the
|
Isolated so it can be swapped / mocked in tests without touching the
|
||||||
@ -72,9 +74,7 @@ class CheckResponsesCost:
|
|||||||
rows per invocation to avoid overwhelming the DB when there is a large
|
rows per invocation to avoid overwhelming the DB when there is a large
|
||||||
backlog.
|
backlog.
|
||||||
"""
|
"""
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
|
||||||
)
|
|
||||||
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
|
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
|
||||||
if result > 0:
|
if result > 0:
|
||||||
verbose_proxy_logger.warning(
|
verbose_proxy_logger.warning(
|
||||||
@ -105,7 +105,7 @@ class CheckResponsesCost:
|
|||||||
take=MAX_OBJECTS_PER_POLL_CYCLE,
|
take=MAX_OBJECTS_PER_POLL_CYCLE,
|
||||||
order={"created_at": "asc"},
|
order={"created_at": "asc"},
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check")
|
verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check")
|
||||||
completed_jobs = []
|
completed_jobs = []
|
||||||
|
|
||||||
@ -120,33 +120,29 @@ class CheckResponsesCost:
|
|||||||
# Get the stored response object to extract model information
|
# Get the stored response object to extract model information
|
||||||
stored_response = job.file_object
|
stored_response = job.file_object
|
||||||
model_name = stored_response.get("model", None)
|
model_name = stored_response.get("model", None)
|
||||||
|
|
||||||
# Decrypt the response ID
|
# Decrypt the response ID
|
||||||
responses_id_security, _, _ = (
|
responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
||||||
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare metadata with model information for cost tracking
|
# Prepare metadata with model information for cost tracking
|
||||||
litellm_metadata = {
|
litellm_metadata = {
|
||||||
"user_api_key_user_id": job.created_by or "default-user-id",
|
"user_api_key_user_id": job.created_by or "default-user-id",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add model information if available
|
# Add model information if available
|
||||||
if model_name:
|
if model_name:
|
||||||
litellm_metadata["model"] = model_name
|
litellm_metadata["model"] = model_name
|
||||||
litellm_metadata["model_group"] = (
|
litellm_metadata["model_group"] = model_name # Use same value for model_group
|
||||||
model_name # Use same value for model_group
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await litellm.aget_responses(
|
response = await litellm.aget_responses(
|
||||||
response_id=responses_id_security,
|
response_id=responses_id_security,
|
||||||
litellm_metadata=litellm_metadata,
|
litellm_metadata=litellm_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Response {unified_object_id} status: {response.status}, model: {model_name}"
|
f"Response {unified_object_id} status: {response.status}, model: {model_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Skipping job {unified_object_id} due to error: {e}"
|
f"Skipping job {unified_object_id} due to error: {e}"
|
||||||
@ -159,7 +155,7 @@ class CheckResponsesCost:
|
|||||||
f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses."
|
f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses."
|
||||||
)
|
)
|
||||||
completed_jobs.append(job)
|
completed_jobs.append(job)
|
||||||
|
|
||||||
elif response.status in ["failed", "cancelled"]:
|
elif response.status in ["failed", "cancelled"]:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Response {unified_object_id} has status {response.status}, marking as complete"
|
f"Response {unified_object_id} has status {response.status}, marking as complete"
|
||||||
@ -175,3 +171,4 @@ class CheckResponsesCost:
|
|||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Marked {len(completed_jobs)} response jobs as completed"
|
f"Marked {len(completed_jobs)} response jobs as completed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Managed vector stores with target_model_names support.
|
Managed vector stores with target_model_names support.
|
||||||
|
|
||||||
This class provides functionality to:
|
This class provides functionality to:
|
||||||
- Create vector stores across multiple models
|
- Create vector stores across multiple models
|
||||||
- Retrieve vector stores by unified ID
|
- Retrieve vector stores by unified ID
|
||||||
@ -77,14 +77,14 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate the format string for the unified vector store ID.
|
Generate the format string for the unified vector store ID.
|
||||||
|
|
||||||
Format:
|
Format:
|
||||||
litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id>
|
litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id>
|
||||||
"""
|
"""
|
||||||
# VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary
|
# VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary
|
||||||
# Extract provider resource ID from the response
|
# Extract provider resource ID from the response
|
||||||
provider_resource_id = resource_object.get("id", "")
|
provider_resource_id = resource_object.get("id", "")
|
||||||
|
|
||||||
# Model ID is stored in hidden params if the response object supports it
|
# Model ID is stored in hidden params if the response object supports it
|
||||||
# For TypedDict responses, we need to check if _hidden_params was added
|
# For TypedDict responses, we need to check if _hidden_params was added
|
||||||
hidden_params: Dict[str, Any] = {}
|
hidden_params: Dict[str, Any] = {}
|
||||||
@ -109,18 +109,20 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> VectorStoreCreateResponse:
|
) -> VectorStoreCreateResponse:
|
||||||
"""
|
"""
|
||||||
Create a vector store for a specific model.
|
Create a vector store for a specific model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm_router: LiteLLM router instance
|
llm_router: LiteLLM router instance
|
||||||
model: Model name to create vector store for
|
model: Model name to create vector store for
|
||||||
request_data: Request data for vector store creation
|
request_data: Request data for vector store creation
|
||||||
litellm_parent_otel_span: OpenTelemetry span for tracing
|
litellm_parent_otel_span: OpenTelemetry span for tracing
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
VectorStoreCreateResponse from the provider
|
VectorStoreCreateResponse from the provider
|
||||||
"""
|
"""
|
||||||
# Use the router to create the vector store
|
# Use the router to create the vector store
|
||||||
response = await llm_router.avector_store_create(model=model, **request_data)
|
response = await llm_router.avector_store_create(
|
||||||
|
model=model, **request_data
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@ -137,14 +139,14 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> VectorStoreCreateResponse:
|
) -> VectorStoreCreateResponse:
|
||||||
"""
|
"""
|
||||||
Create a vector store across multiple models.
|
Create a vector store across multiple models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
create_request: Vector store creation request parameters
|
create_request: Vector store creation request parameters
|
||||||
llm_router: LiteLLM router instance
|
llm_router: LiteLLM router instance
|
||||||
target_model_names_list: List of target model names
|
target_model_names_list: List of target model names
|
||||||
litellm_parent_otel_span: OpenTelemetry span for tracing
|
litellm_parent_otel_span: OpenTelemetry span for tracing
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
VectorStoreCreateResponse with unified ID
|
VectorStoreCreateResponse with unified ID
|
||||||
"""
|
"""
|
||||||
@ -194,7 +196,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
# VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID
|
# VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID
|
||||||
response = responses[0].copy()
|
response = responses[0].copy()
|
||||||
response["id"] = unified_id
|
response["id"] = unified_id
|
||||||
|
|
||||||
verbose_logger.info(
|
verbose_logger.info(
|
||||||
f"Successfully created managed vector store with unified ID: {unified_id}"
|
f"Successfully created managed vector store with unified ID: {unified_id}"
|
||||||
)
|
)
|
||||||
@ -210,13 +212,13 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
List vector stores created by a user.
|
List vector stores created by a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
limit: Maximum number of vector stores to return
|
limit: Maximum number of vector stores to return
|
||||||
after: Cursor for pagination
|
after: Cursor for pagination
|
||||||
order: Sort order ('asc' or 'desc')
|
order: Sort order ('asc' or 'desc')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with list of vector stores and pagination info
|
Dictionary with list of vector stores and pagination info
|
||||||
"""
|
"""
|
||||||
@ -236,23 +238,23 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user has access to a vector store.
|
Check if user has access to a vector store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_store_id: The unified vector store ID
|
vector_store_id: The unified vector store ID
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if user has access, False otherwise
|
True if user has access, False otherwise
|
||||||
"""
|
"""
|
||||||
is_unified_id = is_base64_encoded_unified_id(vector_store_id)
|
is_unified_id = is_base64_encoded_unified_id(vector_store_id)
|
||||||
|
|
||||||
if is_unified_id:
|
if is_unified_id:
|
||||||
# Check access for managed vector store
|
# Check access for managed vector store
|
||||||
return await self.can_user_access_unified_resource_id(
|
return await self.can_user_access_unified_resource_id(
|
||||||
vector_store_id,
|
vector_store_id,
|
||||||
user_api_key_dict,
|
user_api_key_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Not a managed vector store, allow access
|
# Not a managed vector store, allow access
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -261,22 +263,24 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user has access to a managed vector store in request data.
|
Check if user has access to a managed vector store in request data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Request data containing vector_store_id
|
data: Request data containing vector_store_id
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if this is a managed vector store and user has access
|
True if this is a managed vector store and user has access
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If user doesn't have access
|
HTTPException: If user doesn't have access
|
||||||
"""
|
"""
|
||||||
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
|
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
|
||||||
is_unified_id = (
|
is_unified_id = (
|
||||||
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False
|
is_base64_encoded_unified_id(vector_store_id)
|
||||||
|
if vector_store_id
|
||||||
|
else False
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_unified_id and vector_store_id:
|
if is_unified_id and vector_store_id:
|
||||||
if await self.can_user_access_unified_resource_id(
|
if await self.can_user_access_unified_resource_id(
|
||||||
vector_store_id, user_api_key_dict
|
vector_store_id, user_api_key_dict
|
||||||
@ -287,7 +291,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@ -303,18 +307,18 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> Union[Exception, str, Dict, None]:
|
) -> Union[Exception, str, Dict, None]:
|
||||||
"""
|
"""
|
||||||
Pre-call hook to handle vector store operations.
|
Pre-call hook to handle vector store operations.
|
||||||
|
|
||||||
This hook intercepts vector store requests and:
|
This hook intercepts vector store requests and:
|
||||||
- Validates access for managed vector stores
|
- Validates access for managed vector stores
|
||||||
- Transforms unified IDs to provider-specific IDs
|
- Transforms unified IDs to provider-specific IDs
|
||||||
- Adds model routing information
|
- Adds model routing information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
cache: Cache instance
|
cache: Cache instance
|
||||||
data: Request data
|
data: Request data
|
||||||
call_type: Type of call being made
|
call_type: Type of call being made
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Modified request data or None
|
Modified request data or None
|
||||||
"""
|
"""
|
||||||
@ -326,40 +330,40 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
# Handle vector store search operations
|
# Handle vector store search operations
|
||||||
if call_type == "avector_store_search":
|
if call_type == "avector_store_search":
|
||||||
vector_store_id = data.get("vector_store_id")
|
vector_store_id = data.get("vector_store_id")
|
||||||
|
|
||||||
if vector_store_id:
|
if vector_store_id:
|
||||||
# Check if it's a managed vector store ID
|
# Check if it's a managed vector store ID
|
||||||
decoded_id = is_base64_encoded_unified_id(vector_store_id)
|
decoded_id = is_base64_encoded_unified_id(vector_store_id)
|
||||||
|
|
||||||
if decoded_id:
|
if decoded_id:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Processing managed vector store search: {vector_store_id}"
|
f"Processing managed vector store search: {vector_store_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check access
|
# Check access
|
||||||
has_access = await self.can_user_access_unified_resource_id(
|
has_access = await self.can_user_access_unified_resource_id(
|
||||||
vector_store_id, user_api_key_dict
|
vector_store_id, user_api_key_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
if not has_access:
|
if not has_access:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the unified ID to extract components
|
# Parse the unified ID to extract components
|
||||||
parsed_id = parse_unified_id(vector_store_id)
|
parsed_id = parse_unified_id(vector_store_id)
|
||||||
|
|
||||||
if parsed_id:
|
if parsed_id:
|
||||||
# Extract the model ID and provider resource ID
|
# Extract the model ID and provider resource ID
|
||||||
model_id = parsed_id.get("model_id")
|
model_id = parsed_id.get("model_id")
|
||||||
provider_resource_id = parsed_id.get("provider_resource_id")
|
provider_resource_id = parsed_id.get("provider_resource_id")
|
||||||
target_model_names = parsed_id.get("target_model_names", [])
|
target_model_names = parsed_id.get("target_model_names", [])
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}"
|
f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine which model to use for routing
|
# Determine which model to use for routing
|
||||||
# Priority: model_id (deployment ID) > first target_model_name
|
# Priority: model_id (deployment ID) > first target_model_name
|
||||||
routing_model = None
|
routing_model = None
|
||||||
@ -367,28 +371,28 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
routing_model = model_id
|
routing_model = model_id
|
||||||
elif target_model_names and len(target_model_names) > 0:
|
elif target_model_names and len(target_model_names) > 0:
|
||||||
routing_model = target_model_names[0]
|
routing_model = target_model_names[0]
|
||||||
|
|
||||||
# Set the model for routing
|
# Set the model for routing
|
||||||
if routing_model:
|
if routing_model:
|
||||||
data["model"] = routing_model
|
data["model"] = routing_model
|
||||||
verbose_logger.info(
|
verbose_logger.info(
|
||||||
f"Routing vector store search to model: {routing_model}"
|
f"Routing vector store search to model: {routing_model}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace the unified ID with the provider-specific ID
|
# Replace the unified ID with the provider-specific ID
|
||||||
if provider_resource_id:
|
if provider_resource_id:
|
||||||
data["vector_store_id"] = provider_resource_id
|
data["vector_store_id"] = provider_resource_id
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Replaced unified ID with provider resource ID: {provider_resource_id}"
|
f"Replaced unified ID with provider resource ID: {provider_resource_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle vector store retrieve/delete operations
|
# Handle vector store retrieve/delete operations
|
||||||
elif call_type in ("avector_store_retrieve", "avector_store_delete"):
|
elif call_type in ("avector_store_retrieve", "avector_store_delete"):
|
||||||
await self.check_managed_vector_store_access(data, user_api_key_dict)
|
await self.check_managed_vector_store_access(data, user_api_key_dict)
|
||||||
|
|
||||||
# If it's a managed vector store, we'll handle it in the endpoint
|
# If it's a managed vector store, we'll handle it in the endpoint
|
||||||
# No need to transform here as the endpoint will route to the hook
|
# No need to transform here as the endpoint will route to the hook
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@ -403,15 +407,15 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Post-call hook to transform responses.
|
Post-call hook to transform responses.
|
||||||
|
|
||||||
This hook can be used to transform responses if needed.
|
This hook can be used to transform responses if needed.
|
||||||
For now, it just passes through the response.
|
For now, it just passes through the response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Request data
|
data: Request data
|
||||||
user_api_key_dict: User API key authentication details
|
user_api_key_dict: User API key authentication details
|
||||||
response: Response from the provider
|
response: Response from the provider
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Potentially modified response
|
Potentially modified response
|
||||||
"""
|
"""
|
||||||
@ -432,21 +436,21 @@ class _PROXY_LiteLLMManagedVectorStores(
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Filter deployments based on vector store availability.
|
Filter deployments based on vector store availability.
|
||||||
|
|
||||||
This is used by the router to select only deployments that have
|
This is used by the router to select only deployments that have
|
||||||
the vector store available.
|
the vector store available.
|
||||||
|
|
||||||
Note: This method signature is a compromise between CustomLogger and BaseManagedResource
|
Note: This method signature is a compromise between CustomLogger and BaseManagedResource
|
||||||
parent classes which have incompatible signatures. The type: ignore[override] is necessary
|
parent classes which have incompatible signatures. The type: ignore[override] is necessary
|
||||||
due to this multiple inheritance conflict.
|
due to this multiple inheritance conflict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model name
|
model: Model name
|
||||||
healthy_deployments: List of healthy deployments
|
healthy_deployments: List of healthy deployments
|
||||||
messages: Messages (unused for vector stores, required by CustomLogger interface)
|
messages: Messages (unused for vector stores, required by CustomLogger interface)
|
||||||
request_kwargs: Request kwargs containing vector_store_id and mappings
|
request_kwargs: Request kwargs containing vector_store_id and mappings
|
||||||
parent_otel_span: OpenTelemetry span for tracing
|
parent_otel_span: OpenTelemetry span for tracing
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Filtered list of deployments
|
Filtered list of deployments
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
Enterprise internal user management endpoints
|
Enterprise internal user management endpoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
|||||||
@ -147,12 +147,12 @@ async def list_vector_stores(
|
|||||||
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
|
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
|
||||||
prisma_client=prisma_client
|
prisma_client=prisma_client
|
||||||
)
|
)
|
||||||
|
|
||||||
# Also clean up in-memory registry to remove any deleted vector stores
|
# Also clean up in-memory registry to remove any deleted vector stores
|
||||||
if litellm.vector_store_registry is not None:
|
if litellm.vector_store_registry is not None:
|
||||||
db_vector_store_ids = {
|
db_vector_store_ids = {
|
||||||
vs.get("vector_store_id")
|
vs.get("vector_store_id")
|
||||||
for vs in vector_stores_from_db
|
for vs in vector_stores_from_db
|
||||||
if vs.get("vector_store_id")
|
if vs.get("vector_store_id")
|
||||||
}
|
}
|
||||||
# Remove any in-memory vector stores that no longer exist in database
|
# Remove any in-memory vector stores that no longer exist in database
|
||||||
|
|||||||
@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
|
|||||||
soft_budget_crossed = "Soft Budget Crossed"
|
soft_budget_crossed = "Soft Budget Crossed"
|
||||||
max_budget_alert = "Max Budget Alert"
|
max_budget_alert = "Max Budget Alert"
|
||||||
|
|
||||||
|
|
||||||
class EmailEventSettings(BaseModel):
|
class EmailEventSettings(BaseModel):
|
||||||
event: EmailEvent
|
event: EmailEvent
|
||||||
enabled: bool
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
class EmailEventSettingsUpdateRequest(BaseModel):
|
class EmailEventSettingsUpdateRequest(BaseModel):
|
||||||
settings: List[EmailEventSettings]
|
settings: List[EmailEventSettings]
|
||||||
|
|
||||||
|
|
||||||
class EmailEventSettingsResponse(BaseModel):
|
class EmailEventSettingsResponse(BaseModel):
|
||||||
settings: List[EmailEventSettings]
|
settings: List[EmailEventSettings]
|
||||||
|
|
||||||
|
|
||||||
class DefaultEmailSettings(BaseModel):
|
class DefaultEmailSettings(BaseModel):
|
||||||
"""Default settings for email events"""
|
"""Default settings for email events"""
|
||||||
|
|
||||||
settings: Dict[EmailEvent, bool] = Field(
|
settings: Dict[EmailEvent, bool] = Field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
EmailEvent.virtual_key_created: True, # On by default
|
EmailEvent.virtual_key_created: True, # On by default
|
||||||
@ -65,12 +57,10 @@ class DefaultEmailSettings(BaseModel):
|
|||||||
EmailEvent.max_budget_alert: True, # On by default
|
EmailEvent.max_budget_alert: True, # On by default
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, bool]:
|
def to_dict(self) -> Dict[str, bool]:
|
||||||
"""Convert to dictionary with string keys for storage"""
|
"""Convert to dictionary with string keys for storage"""
|
||||||
return {event.value: enabled for event, enabled in self.settings.items()}
|
return {event.value: enabled for event, enabled in self.settings.items()}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_defaults(cls) -> Dict[str, bool]:
|
def get_defaults(cls) -> Dict[str, bool]:
|
||||||
"""Get the default settings as a dictionary with string keys"""
|
"""Get the default settings as a dictionary with string keys"""
|
||||||
return cls().to_dict()
|
return cls().to_dict()
|
||||||
@ -6,6 +6,7 @@ Always uses fastuuid for performance.
|
|||||||
|
|
||||||
import fastuuid as _uuid # type: ignore
|
import fastuuid as _uuid # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# Expose a module-like alias so callers can use: uuid.uuid4()
|
# Expose a module-like alias so callers can use: uuid.uuid4()
|
||||||
uuid = _uuid
|
uuid = _uuid
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
||||||
|
|
||||||
|
|
||||||
# HTTP status code -> Anthropic error type
|
# HTTP status code -> Anthropic error type
|
||||||
# Source: https://docs.anthropic.com/en/api/errors
|
# Source: https://docs.anthropic.com/en/api/errors
|
||||||
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
|
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
||||||
|
|
||||||
# Known Anthropic error types
|
# Known Anthropic error types
|
||||||
# Source: https://docs.anthropic.com/en/api/errors
|
# Source: https://docs.anthropic.com/en/api/errors
|
||||||
AnthropicErrorType = Literal[
|
AnthropicErrorType = Literal[
|
||||||
|
|||||||
@ -97,7 +97,7 @@ def _build_reasoning_item(
|
|||||||
|
|
||||||
|
|
||||||
def _reasoning_item_to_response_input(
|
def _reasoning_item_to_response_input(
|
||||||
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]],
|
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
|
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
|
||||||
r_input: Dict[str, Any] = {
|
r_input: Dict[str, Any] = {
|
||||||
|
|||||||
@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
_CODE_KEYWORDS = re.compile(
|
_CODE_KEYWORDS = re.compile(
|
||||||
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
|
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
|
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
|
||||||
|
|
||||||
|
|
||||||
FileContentProvider = Literal[
|
FileContentProvider = Literal[
|
||||||
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
|
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Google GenAI Adapters for LiteLLM
|
Google GenAI Adapters for LiteLLM
|
||||||
|
|
||||||
This module provides adapters for transforming Google GenAI generate_content requests
|
This module provides adapters for transforming Google GenAI generate_content requests
|
||||||
to/from LiteLLM completion format with full support for:
|
to/from LiteLLM completion format with full support for:
|
||||||
- Text content transformation
|
- Text content transformation
|
||||||
- Tool calling (function declarations, function calls, function responses)
|
- Tool calling (function declarations, function calls, function responses)
|
||||||
- Streaming (both regular and tool calling)
|
- Streaming (both regular and tool calling)
|
||||||
- Mixed content (text + tool calls)
|
- Mixed content (text + tool calls)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Handles Batching + sending Httpx Post requests to slack
|
Handles Batching + sending Httpx Post requests to slack
|
||||||
|
|
||||||
Slack alerts are sent every 10s or when events are greater than X events
|
Slack alerts are sent every 10s or when events are greater than X events
|
||||||
|
|
||||||
see custom_batch_logger.py for more details / defaults
|
see custom_batch_logger.py for more details / defaults
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|||||||
@ -18,7 +18,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def process_slack_alerting_variables(
|
def process_slack_alerting_variables(
|
||||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]],
|
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
|
||||||
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
||||||
"""
|
"""
|
||||||
process alert_to_webhook_url
|
process alert_to_webhook_url
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Base class for Additional Logging Utils for CustomLoggers
|
Base class for Additional Logging Utils for CustomLoggers
|
||||||
|
|
||||||
- Health Check for the logging util
|
- Health Check for the logging util
|
||||||
- Get Request / Response Payload for the logging util
|
- Get Request / Response Payload for the logging util
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Custom Logger that handles batching logic
|
Custom Logger that handles batching logic
|
||||||
|
|
||||||
Use this if you want your logs to be stored in memory and flushed periodically.
|
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import polars as pl
|
|||||||
|
|
||||||
from .schema import FOCUS_NORMALIZED_SCHEMA
|
from .schema import FOCUS_NORMALIZED_SCHEMA
|
||||||
|
|
||||||
|
|
||||||
_TAG_KEYS = (
|
_TAG_KEYS = (
|
||||||
"team_id",
|
"team_id",
|
||||||
"team_alias",
|
"team_alias",
|
||||||
|
|||||||
@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def get_traces_and_spans_from_payload(
|
def get_traces_and_spans_from_payload(
|
||||||
payload: List[Dict[str, Any]],
|
payload: List[Dict[str, Any]]
|
||||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Separate traces and spans from payload.
|
Separate traces and spans from payload.
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
s3 Bucket Logging Integration
|
s3 Bucket Logging Integration
|
||||||
|
|
||||||
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||||
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||||
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
|
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -5,28 +5,28 @@ This module provides SDK methods for Google's Interactions API.
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
# Create an interaction with a model
|
# Create an interaction with a model
|
||||||
response = litellm.interactions.create(
|
response = litellm.interactions.create(
|
||||||
model="gemini-2.5-flash",
|
model="gemini-2.5-flash",
|
||||||
input="Hello, how are you?"
|
input="Hello, how are you?"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create an interaction with an agent
|
# Create an interaction with an agent
|
||||||
response = litellm.interactions.create(
|
response = litellm.interactions.create(
|
||||||
agent="deep-research-pro-preview-12-2025",
|
agent="deep-research-pro-preview-12-2025",
|
||||||
input="Research the current state of cancer research"
|
input="Research the current state of cancer research"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Async version
|
# Async version
|
||||||
response = await litellm.interactions.acreate(...)
|
response = await litellm.interactions.acreate(...)
|
||||||
|
|
||||||
# Get an interaction
|
# Get an interaction
|
||||||
response = litellm.interactions.get(interaction_id="...")
|
response = litellm.interactions.get(interaction_id="...")
|
||||||
|
|
||||||
# Delete an interaction
|
# Delete an interaction
|
||||||
result = litellm.interactions.delete(interaction_id="...")
|
result = litellm.interactions.delete(interaction_id="...")
|
||||||
|
|
||||||
# Cancel an interaction
|
# Cancel an interaction
|
||||||
result = litellm.interactions.cancel(interaction_id="...")
|
result = litellm.interactions.cancel(interaction_id="...")
|
||||||
|
|
||||||
|
|||||||
@ -8,25 +8,25 @@ Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
# Create an interaction with a model
|
# Create an interaction with a model
|
||||||
response = litellm.interactions.create(
|
response = litellm.interactions.create(
|
||||||
model="gemini-2.5-flash",
|
model="gemini-2.5-flash",
|
||||||
input="Hello, how are you?"
|
input="Hello, how are you?"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create an interaction with an agent
|
# Create an interaction with an agent
|
||||||
response = litellm.interactions.create(
|
response = litellm.interactions.create(
|
||||||
agent="deep-research-pro-preview-12-2025",
|
agent="deep-research-pro-preview-12-2025",
|
||||||
input="Research the current state of cancer research"
|
input="Research the current state of cancer research"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Async version
|
# Async version
|
||||||
response = await litellm.interactions.acreate(...)
|
response = await litellm.interactions.acreate(...)
|
||||||
|
|
||||||
# Get an interaction
|
# Get an interaction
|
||||||
response = litellm.interactions.get(interaction_id="...")
|
response = litellm.interactions.get(interaction_id="...")
|
||||||
|
|
||||||
# Delete an interaction
|
# Delete an interaction
|
||||||
result = litellm.interactions.delete(interaction_id="...")
|
result = litellm.interactions.delete(interaction_id="...")
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||||||
try:
|
try:
|
||||||
# [Non-blocking Extra Debug Information in metadata]
|
# [Non-blocking Extra Debug Information in metadata]
|
||||||
if turn_off_message_logging is True:
|
if turn_off_message_logging is True:
|
||||||
_metadata["raw_request"] = "redacted by litellm. \
|
_metadata["raw_request"] = (
|
||||||
|
"redacted by litellm. \
|
||||||
'litellm.turn_off_message_logging=True'"
|
'litellm.turn_off_message_logging=True'"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
curl_command = self._get_request_curl_command(
|
curl_command = self._get_request_curl_command(
|
||||||
api_base=additional_args.get("api_base", ""),
|
api_base=additional_args.get("api_base", ""),
|
||||||
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_metadata["raw_request"] = "Unable to Log \
|
_metadata["raw_request"] = (
|
||||||
raw request: {}".format(str(e))
|
"Unable to Log \
|
||||||
|
raw request: {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
||||||
try:
|
try:
|
||||||
self.logger_fn(
|
self.logger_fn(
|
||||||
|
|||||||
@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
|
|||||||
prompt_str = """Use this JSON schema:
|
prompt_str = """Use this JSON schema:
|
||||||
```json
|
```json
|
||||||
{}
|
{}
|
||||||
```""".format(response_schema)
|
```""".format(
|
||||||
|
response_schema
|
||||||
|
)
|
||||||
return prompt_str
|
return prompt_str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
This is a cache for LangfuseLoggers.
|
This is a cache for LangfuseLoggers.
|
||||||
|
|
||||||
Langfuse Python SDK initializes a thread for each client.
|
Langfuse Python SDK initializes a thread for each client.
|
||||||
|
|
||||||
This ensures we do
|
This ensures we do
|
||||||
1. Proper cleanup of Langfuse initialized clients.
|
1. Proper cleanup of Langfuse initialized clients.
|
||||||
2. Re-use created langfuse clients.
|
2. Re-use created langfuse clients.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
|||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# SSE parsing helpers (module-level to keep the class lean)
|
# SSE parsing helpers (module-level to keep the class lean)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -4,10 +4,10 @@ Support for o1 and o3 model families
|
|||||||
https://platform.openai.com/docs/guides/reasoning
|
https://platform.openai.com/docs/guides/reasoning
|
||||||
|
|
||||||
Translations handled by LiteLLM:
|
Translations handled by LiteLLM:
|
||||||
- modalities: image => drop param (if user opts in to dropping param)
|
- modalities: image => drop param (if user opts in to dropping param)
|
||||||
- role: system ==> translate to role 'user'
|
- role: system ==> translate to role 'user'
|
||||||
- streaming => faked by LiteLLM
|
- streaming => faked by LiteLLM
|
||||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||||
- Logprobs => drop param (if user opts in to dropping param)
|
- Logprobs => drop param (if user opts in to dropping param)
|
||||||
- Temperature => drop param (if user opts in to dropping param)
|
- Temperature => drop param (if user opts in to dropping param)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
|
|||||||
|
|
||||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||||
|
|
||||||
|
|
||||||
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"
|
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Legacy /v1/embedding handler for Bedrock Cohere.
|
Legacy /v1/embedding handler for Bedrock Cohere.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pre-built response templates
|
# Pre-built response templates
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Cost calculator for Dashscope Chat models.
|
Cost calculator for Dashscope Chat models.
|
||||||
|
|
||||||
Handles tiered pricing and prompt caching scenarios.
|
Handles tiered pricing and prompt caching scenarios.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||||
|
|
||||||
Calls done in OpenAI/openai.py as DataRobot is openai-compatible.
|
Calls done in OpenAI/openai.py as DataRobot is openai-compatible.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
|
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Cost calculator for DeepSeek Chat models.
|
Cost calculator for DeepSeek Chat models.
|
||||||
|
|
||||||
Handles prompt caching scenario.
|
Handles prompt caching scenario.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
|
|||||||
|
|
||||||
from ..common_utils import ElevenLabsException
|
from ..common_utils import ElevenLabsException
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||||
|
|||||||
@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
def _usage_video_resolution_from_parameters(
|
def _usage_video_resolution_from_parameters(
|
||||||
parameters: Dict[str, Any],
|
parameters: Dict[str, Any]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
|
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
|
||||||
res = parameters.get("resolution")
|
res = parameters.get("resolution")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
|
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
|
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
|
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||||
|
|
||||||
Calls done in OpenAI/openai.py as Novita AI is openai-compatible.
|
Calls done in OpenAI/openai.py as Novita AI is openai-compatible.
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
|
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
|
||||||
|
|
||||||
This is OpenAI compatible
|
This is OpenAI compatible
|
||||||
|
|
||||||
This file only contains param mapping logic
|
This file only contains param mapping logic
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer
|
Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer
|
||||||
|
|
||||||
This is OpenAI compatible
|
This is OpenAI compatible
|
||||||
|
|
||||||
This file only contains param mapping logic
|
This file only contains param mapping logic
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Support for o1/o3 model family
|
Support for o1/o3 model family
|
||||||
|
|
||||||
https://platform.openai.com/docs/guides/reasoning
|
https://platform.openai.com/docs/guides/reasoning
|
||||||
|
|
||||||
Translations handled by LiteLLM:
|
Translations handled by LiteLLM:
|
||||||
- modalities: image => drop param (if user opts in to dropping param)
|
- modalities: image => drop param (if user opts in to dropping param)
|
||||||
- role: system ==> translate to role 'user'
|
- role: system ==> translate to role 'user'
|
||||||
- streaming => faked by LiteLLM
|
- streaming => faked by LiteLLM
|
||||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||||
- Logprobs => drop param (if user opts in to dropping param)
|
- Logprobs => drop param (if user opts in to dropping param)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
|
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
|
||||||
|
|||||||
@ -201,7 +201,7 @@ class BaseOpenAILLM:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_openai_client_initialization_param_fields(
|
def get_openai_client_initialization_param_fields(
|
||||||
client_type: Literal["openai", "azure"],
|
client_type: Literal["openai", "azure"]
|
||||||
) -> Tuple[str, ...]:
|
) -> Tuple[str, ...]:
|
||||||
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
|
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
|
||||||
if client_type == "openai":
|
if client_type == "openai":
|
||||||
|
|||||||
@ -49,6 +49,7 @@ from litellm.types.utils import (
|
|||||||
)
|
)
|
||||||
from litellm.llms.openrouter.common_utils import OpenRouterException
|
from litellm.llms.openrouter.common_utils import OpenRouterException
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
|
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
|
||||||
|
|
||||||
In the Huggingface TGI format.
|
In the Huggingface TGI format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
|
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
|
||||||
|
|
||||||
In the Huggingface TGI format.
|
In the Huggingface TGI format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||||
|
|||||||
@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_service_key_once(
|
def _parse_service_key_once(
|
||||||
service_key: Optional[Union[str, dict]],
|
service_key: Optional[Union[str, dict]]
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Pre-parse service_key if it's a string to avoid repeated JSON parsing.
|
Pre-parse service_key if it's a string to avoid repeated JSON parsing.
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
|
|||||||
|
|
||||||
from ..utils import SnowflakeBaseConfig
|
from ..utils import SnowflakeBaseConfig
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||||
|
|
||||||
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Support for OpenAI's `/v1/embeddings` endpoint.
|
Support for OpenAI's `/v1/embeddings` endpoint.
|
||||||
|
|
||||||
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
|
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic for context caching.
|
Transformation logic for context caching.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
"""
|
"""
|
||||||
@ -19,7 +19,7 @@ from ..gemini.transformation import (
|
|||||||
|
|
||||||
|
|
||||||
def get_first_continuous_block_idx(
|
def get_first_continuous_block_idx(
|
||||||
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message)
|
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Find the array index that ends the first continuous sequence of message blocks.
|
Find the array index that ends the first continuous sequence of message blocks.
|
||||||
|
|||||||
@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
|||||||
contents.append(ContentType(role="user", parts=tool_call_responses))
|
contents.append(ContentType(role="user", parts=tool_call_responses))
|
||||||
|
|
||||||
if len(contents) == 0:
|
if len(contents) == 0:
|
||||||
verbose_logger.warning("""
|
verbose_logger.warning(
|
||||||
|
"""
|
||||||
No contents in messages. Contents are required. See
|
No contents in messages. Contents are required. See
|
||||||
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
|
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
|
||||||
If the original request did not comply to OpenAI API requirements it should have failed by now,
|
If the original request did not comply to OpenAI API requirements it should have failed by now,
|
||||||
but LiteLLM does not check for missing messages.
|
but LiteLLM does not check for missing messages.
|
||||||
Setting an empty content to prevent an 400 error.
|
Setting an empty content to prevent an 400 error.
|
||||||
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
|
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
|
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
|
||||||
return contents
|
return contents
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
|
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
|
||||||
|
|
||||||
Why separate file? Make it easy to see how transformation works
|
Why separate file? Make it easy to see how transformation works
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
|
|||||||
########## End of logging ############
|
########## End of logging ############
|
||||||
####### Send the request ###################
|
####### Send the request ###################
|
||||||
if _is_async is True:
|
if _is_async is True:
|
||||||
return self.async_audio_speech( # type: ignore
|
return self.async_audio_speech( # type:ignore
|
||||||
logging_obj=logging_obj, url=url, headers=headers, request=request
|
logging_obj=logging_obj, url=url, headers=headers, request=request
|
||||||
)
|
)
|
||||||
sync_handler = _get_httpx_client()
|
sync_handler = _get_httpx_client()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
|
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
|
||||||
|
|
||||||
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
|
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
This module is used to transform the request and response for the Voyage contextualized embeddings API.
|
This module is used to transform the request and response for the Voyage contextualized embeddings API.
|
||||||
This would be used for all the contextualized embeddings models in Voyage.
|
This would be used for all the contextualized embeddings models in Voyage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|||||||
@ -324,7 +324,7 @@ class CustomOpenAPISpec:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_chat_completion_request_schema(
|
def add_chat_completion_request_schema(
|
||||||
openapi_schema: Dict[str, Any],
|
openapi_schema: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
|
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
|
||||||
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_responses_api_request_schema(
|
def add_responses_api_request_schema(
|
||||||
openapi_schema: Dict[str, Any],
|
openapi_schema: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
|
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
|
||||||
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_llm_api_request_schema_body(
|
def add_llm_api_request_schema_body(
|
||||||
openapi_schema: Dict[str, Any],
|
openapi_schema: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Add LLM API request schema bodies to OpenAPI specification for documentation.
|
Add LLM API request schema bodies to OpenAPI specification for documentation.
|
||||||
|
|||||||
@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def convert_upload_files_to_file_data(
|
async def convert_upload_files_to_file_data(
|
||||||
form_data: Dict[str, Any],
|
form_data: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert FastAPI UploadFile objects to file data tuples for litellm.
|
Convert FastAPI UploadFile objects to file data tuples for litellm.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Contains utils used by OpenAI compatible endpoints
|
Contains utils used by OpenAI compatible endpoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Set
|
from typing import Optional, Set
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
What is this?
|
What is this?
|
||||||
|
|
||||||
CRUD endpoints for managing pass-through endpoints
|
CRUD endpoints for managing pass-through endpoints
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
|||||||
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
|
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
|
||||||
raise
|
raise
|
||||||
# If an error occurs, the view does not exist, so create it
|
# If an error occurs, the view does not exist, so create it
|
||||||
await db.execute_raw("""
|
await db.execute_raw(
|
||||||
|
"""
|
||||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||||
SELECT
|
SELECT
|
||||||
v.*,
|
v.*,
|
||||||
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
|||||||
FROM "LiteLLM_VerificationToken" v
|
FROM "LiteLLM_VerificationToken" v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
|
||||||
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
|
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")
|
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ every text fragment.
|
|||||||
|
|
||||||
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
|
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
|
||||||
|
|
||||||
|
|
||||||
# Call types whose body carries free-form chat / prompt text that
|
# Call types whose body carries free-form chat / prompt text that
|
||||||
# text-content guardrails (banned keywords, content moderation, secret
|
# text-content guardrails (banned keywords, content moderation, secret
|
||||||
# detection, …) should inspect. The proxy ingress passes ``route_type``
|
# detection, …) should inspect. The proxy ingress passes ``route_type``
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
|
|||||||
|
|
||||||
from .akto import AktoGuardrail
|
from .akto import AktoGuardrail
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ The actual skill logic is in litellm/llms/litellm_proxy/skills/.
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
|
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
|
||||||
|
|
||||||
# Register hook in proxy
|
# Register hook in proxy
|
||||||
litellm.callbacks.append(SkillsInjectionHook())
|
litellm.callbacks.append(SkillsInjectionHook())
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
BUDGET MANAGEMENT
|
BUDGET MANAGEMENT
|
||||||
|
|
||||||
All /budget management endpoints
|
All /budget management endpoints
|
||||||
|
|
||||||
/budget/new
|
/budget/new
|
||||||
/budget/info
|
/budget/info
|
||||||
/budget/update
|
/budget/update
|
||||||
/budget/delete
|
/budget/delete
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
CUSTOMER MANAGEMENT
|
CUSTOMER MANAGEMENT
|
||||||
|
|
||||||
All /customer management endpoints
|
All /customer management endpoints
|
||||||
|
|
||||||
/customer/new
|
/customer/new
|
||||||
/customer/info
|
/customer/info
|
||||||
/customer/update
|
/customer/update
|
||||||
/customer/delete
|
/customer/delete
|
||||||
|
|||||||
@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_team_public_model_name(
|
def _get_team_public_model_name(
|
||||||
model_info: Optional[Union[dict, str]],
|
model_info: Optional[Union[dict, str]]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
if isinstance(model_info, dict):
|
if isinstance(model_info, dict):
|
||||||
value = model_info.get("team_public_model_name")
|
value = model_info.get("team_public_model_name")
|
||||||
|
|||||||
@ -7,7 +7,7 @@ variables.
|
|||||||
|
|
||||||
Environment Variables:
|
Environment Variables:
|
||||||
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
|
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
|
||||||
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
|
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
|
||||||
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
|
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
|
||||||
|
|
||||||
If these are not set, the default Microsoft endpoints are used.
|
If these are not set, the default Microsoft endpoints are used.
|
||||||
|
|||||||
@ -4347,7 +4347,9 @@ async def list_team(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
team_exception = """Invalid team object for team_id: {}. team_object={}.
|
team_exception = """Invalid team object for team_id: {}. team_object={}.
|
||||||
Error: {}
|
Error: {}
|
||||||
""".format(team.team_id, team.model_dump(), str(e))
|
""".format(
|
||||||
|
team.team_id, team.model_dump(), str(e)
|
||||||
|
)
|
||||||
verbose_proxy_logger.exception(team_exception)
|
verbose_proxy_logger.exception(team_exception)
|
||||||
continue
|
continue
|
||||||
# Sort the responses by team_alias
|
# Sort the responses by team_alias
|
||||||
|
|||||||
@ -3,7 +3,7 @@ User Agent Analytics Endpoints
|
|||||||
|
|
||||||
This module provides optimized endpoints for tracking user agent activity metrics including:
|
This module provides optimized endpoints for tracking user agent activity metrics including:
|
||||||
- Daily Active Users (DAU) by tags for configurable number of days
|
- Daily Active Users (DAU) by tags for configurable number of days
|
||||||
- Weekly Active Users (WAU) by tags for configurable number of weeks
|
- Weekly Active Users (WAU) by tags for configurable number of weeks
|
||||||
- Monthly Active Users (MAU) by tags for configurable number of months
|
- Monthly Active Users (MAU) by tags for configurable number of months
|
||||||
- Summary analytics by tags
|
- Summary analytics by tags
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
|
|||||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||||
from litellm.types.utils import StandardPassThroughResponseObject
|
from litellm.types.utils import StandardPassThroughResponseObject
|
||||||
|
|
||||||
|
|
||||||
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
|
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
|
||||||
"POST /v0/agents": "cursor:agent:create",
|
"POST /v0/agents": "cursor:agent:create",
|
||||||
"GET /v0/agents": "cursor:agent:list",
|
"GET /v0/agents": "cursor:agent:list",
|
||||||
|
|||||||
@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
|
|||||||
_endpoint_str = (
|
_endpoint_str = (
|
||||||
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
|
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
|
||||||
)
|
)
|
||||||
curl_command = _endpoint_str + """
|
curl_command = (
|
||||||
|
_endpoint_str
|
||||||
|
+ """
|
||||||
--header 'Content-Type: application/json' \\
|
--header 'Content-Type: application/json' \\
|
||||||
--data ' {
|
--data ' {
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
|
|||||||
}'
|
}'
|
||||||
\n
|
\n
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
print() # noqa
|
print() # noqa
|
||||||
print( # noqa
|
print( # noqa
|
||||||
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
|
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
|
||||||
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
|
|||||||
with open(os.devnull, "w") as devnull:
|
with open(os.devnull, "w") as devnull:
|
||||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"""
|
print( # noqa
|
||||||
|
f"""
|
||||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||||
""") # noqa # noqa
|
"""
|
||||||
|
) # noqa
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_port_in_use(port):
|
def _is_port_in_use(port):
|
||||||
|
|||||||
@ -2688,9 +2688,11 @@ def run_ollama_serve():
|
|||||||
with open(os.devnull, "w") as devnull:
|
with open(os.devnull, "w") as devnull:
|
||||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(f"""
|
verbose_proxy_logger.debug(
|
||||||
|
f"""
|
||||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_process_rss_mb() -> Optional[float]:
|
def _get_process_rss_mb() -> Optional[float]:
|
||||||
|
|||||||
@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
|
|||||||
async def get_spend_by_tags(
|
async def get_spend_by_tags(
|
||||||
prisma_client: PrismaClient, start_date=None, end_date=None
|
prisma_client: PrismaClient, start_date=None, end_date=None
|
||||||
):
|
):
|
||||||
response = await prisma_client.db.query_raw("""
|
response = await prisma_client.db.query_raw(
|
||||||
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||||
COUNT(*) AS log_count,
|
COUNT(*) AS log_count,
|
||||||
SUM(spend) AS total_spend
|
SUM(spend) AS total_spend
|
||||||
FROM "LiteLLM_SpendLogs"
|
FROM "LiteLLM_SpendLogs"
|
||||||
GROUP BY individual_request_tag;
|
GROUP BY individual_request_tag;
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@ -2712,7 +2712,8 @@ class PrismaClient:
|
|||||||
required_view = "LiteLLM_VerificationTokenView"
|
required_view = "LiteLLM_VerificationTokenView"
|
||||||
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
|
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
|
||||||
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
|
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
|
||||||
ret = await self.db.query_raw(f"""
|
ret = await self.db.query_raw(
|
||||||
|
f"""
|
||||||
WITH existing_views AS (
|
WITH existing_views AS (
|
||||||
SELECT viewname
|
SELECT viewname
|
||||||
FROM pg_views
|
FROM pg_views
|
||||||
@ -2724,7 +2725,8 @@ class PrismaClient:
|
|||||||
(SELECT COUNT(*) FROM existing_views) AS view_count,
|
(SELECT COUNT(*) FROM existing_views) AS view_count,
|
||||||
ARRAY_AGG(viewname) AS view_names
|
ARRAY_AGG(viewname) AS view_names
|
||||||
FROM existing_views
|
FROM existing_views
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
expected_total_views = len(expected_views)
|
expected_total_views = len(expected_views)
|
||||||
if ret[0]["view_count"] == expected_total_views:
|
if ret[0]["view_count"] == expected_total_views:
|
||||||
verbose_proxy_logger.info("All necessary views exist!")
|
verbose_proxy_logger.info("All necessary views exist!")
|
||||||
@ -2733,7 +2735,8 @@ class PrismaClient:
|
|||||||
## check if required view exists ##
|
## check if required view exists ##
|
||||||
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
|
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
|
||||||
await self.health_check() # make sure we can connect to db
|
await self.health_check() # make sure we can connect to db
|
||||||
await self.db.execute_raw("""
|
await self.db.execute_raw(
|
||||||
|
"""
|
||||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||||
SELECT
|
SELECT
|
||||||
v.*,
|
v.*,
|
||||||
@ -2743,7 +2746,8 @@ class PrismaClient:
|
|||||||
t.rpm_limit AS team_rpm_limit
|
t.rpm_limit AS team_rpm_limit
|
||||||
FROM "LiteLLM_VerificationToken" v
|
FROM "LiteLLM_VerificationToken" v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"LiteLLM_VerificationTokenView Created in DB!"
|
"LiteLLM_VerificationTokenView Created in DB!"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
What is this?
|
What is this?
|
||||||
|
|
||||||
Logging Pass-Through Endpoints
|
Logging Pass-Through Endpoints
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -826,7 +826,7 @@ class Router:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_strategy(
|
def _normalize_strategy(
|
||||||
strategy: Union[RoutingStrategy, str, None],
|
strategy: Union[RoutingStrategy, str, None]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
|
|||||||
|
|
||||||
|
|
||||||
def _recent_tool_results(
|
def _recent_tool_results(
|
||||||
messages: Optional[List[Dict[str, Any]]],
|
messages: Optional[List[Dict[str, Any]]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Extract the current turn's tool result payloads from the request messages.
|
"""Extract the current turn's tool result payloads from the request messages.
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user