Drop dep bumps + black-26 reformat to clear fork CI policy

This PR was blocked by the .github/workflows/guard-fork-dependencies.yml
check: fork PRs are not allowed to modify uv.lock. This commit reverts:

- uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295
  files of mechanical Black 26 reformat coupled to it
- pyproject.toml diskcache extra change (kept the runtime mitigation in
  litellm/caching/disk_cache.py via JSONDisk)

Kept:
- Dockerfile cache narrowing (drops ~660 MB of uv build cache whose
  cached setuptools copy was being reported as CVE findings)
- litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872
- ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json:
  next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard)
- tests/test_litellm/caching/test_disk_cache.py
- tests/code_coverage_tests/liccheck.ini: the black authorization entry (harmless to keep)

The Black, gitpython, and langchain dependency upgrades will need a follow-up
PR from a maintainer who can push a branch to the canonical BerriAI/litellm repo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
user 2026-05-07 23:04:52 +00:00
parent 63bda3f001
commit 5bafa8b3a2
No known key found for this signature in database
292 changed files with 1814 additions and 2305 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import aiohttp
import json
# Asynchronously fetch data from a given URL
async def fetch_data(url):
try:
@ -16,24 +15,22 @@ async def fetch_data(url):
resp_json = await resp.json()
print("Fetch the data from URL.")
# Return the 'data' field from the JSON response
return resp_json["data"]
return resp_json['data']
except Exception as e:
# Print an error message if fetching data fails
print("Error fetching data from URL:", e)
return None
# Synchronize local data with remote data
def sync_local_data_with_remote(local_data, remote_data):
# Update existing keys in local_data with values from remote_data
for key in set(local_data) & set(remote_data):
for key in (set(local_data) & set(remote_data)):
local_data[key].update(remote_data[key])
# Add new keys from remote_data to local_data
for key in set(remote_data) - set(local_data):
for key in (set(remote_data) - set(local_data)):
local_data[key] = remote_data[key]
# Write data to the json file
def write_to_file(file_path, data):
try:
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
# Print an error message if writing to file fails
print("Error updating JSON file:", e)
# Update the existing models and add the missing models for OpenRouter
def transform_openrouter_data(data):
transformed = {}
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
}
# Add 'max_output_tokens' as a field if it is not None
if (
"top_provider" in row
and "max_completion_tokens" in row["top_provider"]
and row["top_provider"]["max_completion_tokens"] is not None
):
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
# Add the field 'output_cost_per_token'
obj.update(
{
"output_cost_per_token": float(row["pricing"]["completion"]),
}
)
obj.update({
"output_cost_per_token": float(row["pricing"]["completion"]),
})
# Add field 'input_cost_per_image' if it exists and is non-zero
if (
"pricing" in row
and "image" in row["pricing"]
and float(row["pricing"]["image"]) != 0.0
):
obj["input_cost_per_image"] = float(row["pricing"]["image"])
if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
obj['input_cost_per_image'] = float(row["pricing"]["image"])
# Add the fields 'litellm_provider' and 'mode'
obj.update({"litellm_provider": "openrouter", "mode": "chat"})
obj.update({
"litellm_provider": "openrouter",
"mode": "chat"
})
# Add the 'supports_vision' field if the modality is 'multimodal'
if row.get("architecture", {}).get("modality") == "multimodal":
obj["supports_vision"] = True
if row.get('architecture', {}).get('modality') == 'multimodal':
obj['supports_vision'] = True
# Use a composite key to store the transformed object
transformed[f'openrouter/{row["id"]}'] = obj
return transformed
# Update the existing models and add the missing models for Vercel AI Gateway
def transform_vercel_ai_gateway_data(data):
transformed = {}
@ -101,30 +89,20 @@ def transform_vercel_ai_gateway_data(data):
"max_tokens": row["context_window"],
"input_cost_per_token": float(row["pricing"]["input"]),
"output_cost_per_token": float(row["pricing"]["output"]),
"max_output_tokens": row["max_tokens"],
"max_input_tokens": row["context_window"],
'max_output_tokens': row['max_tokens'],
'max_input_tokens': row["context_window"],
}
# Handle cache pricing if available
if "pricing" in row:
if (
"input_cache_read" in row["pricing"]
and row["pricing"]["input_cache_read"] is not None
):
obj["cache_read_input_token_cost"] = float(
f"{float(row['pricing']['input_cache_read']):e}"
)
if (
"input_cache_write" in row["pricing"]
and row["pricing"]["input_cache_write"] is not None
):
obj["cache_creation_input_token_cost"] = float(
f"{float(row['pricing']['input_cache_write']):e}"
)
if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode})
transformed[f'vercel_ai_gateway/{row["id"]}'] = obj
@ -148,31 +126,24 @@ def load_local_data(file_path):
print("Error decoding JSON:", e)
return None
def main():
local_file_path = (
"model_prices_and_context_window.json" # Path to the local data file
)
openrouter_url = (
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
)
vercel_ai_gateway_url = (
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
)
local_file_path = "model_prices_and_context_window.json" # Path to the local data file
openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
# Load local data from file
local_data = load_local_data(local_file_path)
# Fetch OpenRouter data
openrouter_data = asyncio.run(fetch_data(openrouter_url))
# Transform the fetched OpenRouter data
openrouter_data = transform_openrouter_data(openrouter_data)
# Fetch Vercel AI Gateway data
vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url))
# Transform the fetched Vercel AI Gateway data
vercel_data = transform_vercel_ai_gateway_data(vercel_data)
# Combine both datasets
all_remote_data = {**openrouter_data, **vercel_data}
@ -183,7 +154,6 @@ def main():
else:
print("Failed to fetch model data from either local file or URL.")
# Entry point of the script
if __name__ == "__main__":
main()

View File

@ -16,75 +16,64 @@ from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional
# ANSI color codes for terminal output
class Colors:
GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
PURPLE = "\033[95m"
CYAN = "\033[96m"
RESET = "\033[0m"
BOLD = "\033[1m"
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
PURPLE = '\033[95m'
CYAN = '\033[96m'
RESET = '\033[0m'
BOLD = '\033[1m'
def print_colored(message: str, color: str = Colors.RESET):
"""Print colored message to terminal"""
print(f"{color}{message}{Colors.RESET}")
def get_provider_from_test_file(test_file: str) -> str:
"""Map test file names to provider names"""
provider_mapping = {
"test_anthropic": "Anthropic",
"test_azure": "Azure",
"test_bedrock": "AWS Bedrock",
"test_openai": "OpenAI",
"test_vertex": "Google Vertex AI",
"test_gemini": "Google Vertex AI",
"test_cohere": "Cohere",
"test_databricks": "Databricks",
"test_groq": "Groq",
"test_together": "Together AI",
"test_mistral": "Mistral",
"test_deepseek": "DeepSeek",
"test_replicate": "Replicate",
"test_huggingface": "HuggingFace",
"test_fireworks": "Fireworks AI",
"test_perplexity": "Perplexity",
"test_cloudflare": "Cloudflare",
"test_voyage": "Voyage AI",
"test_xai": "xAI",
"test_nvidia": "NVIDIA",
"test_watsonx": "IBM watsonx",
"test_azure_ai": "Azure AI",
"test_snowflake": "Snowflake",
"test_infinity": "Infinity",
"test_jina": "Jina AI",
"test_deepgram": "Deepgram",
"test_clarifai": "Clarifai",
"test_triton": "Triton",
'test_anthropic': 'Anthropic',
'test_azure': 'Azure',
'test_bedrock': 'AWS Bedrock',
'test_openai': 'OpenAI',
'test_vertex': 'Google Vertex AI',
'test_gemini': 'Google Vertex AI',
'test_cohere': 'Cohere',
'test_databricks': 'Databricks',
'test_groq': 'Groq',
'test_together': 'Together AI',
'test_mistral': 'Mistral',
'test_deepseek': 'DeepSeek',
'test_replicate': 'Replicate',
'test_huggingface': 'HuggingFace',
'test_fireworks': 'Fireworks AI',
'test_perplexity': 'Perplexity',
'test_cloudflare': 'Cloudflare',
'test_voyage': 'Voyage AI',
'test_xai': 'xAI',
'test_nvidia': 'NVIDIA',
'test_watsonx': 'IBM watsonx',
'test_azure_ai': 'Azure AI',
'test_snowflake': 'Snowflake',
'test_infinity': 'Infinity',
'test_jina': 'Jina AI',
'test_deepgram': 'Deepgram',
'test_clarifai': 'Clarifai',
'test_triton': 'Triton',
}
for key, provider in provider_mapping.items():
if key in test_file:
return provider
# For cross-provider test files
if any(
name in test_file
for name in [
"test_optional_params",
"test_prompt_factory",
"test_router",
"test_text_completion",
]
):
return f"Cross-Provider Tests ({test_file})"
return "Other Tests"
if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory',
'test_router', 'test_text_completion']):
return f'Cross-Provider Tests ({test_file})'
return 'Other Tests'
def format_duration(seconds: float) -> str:
"""Format duration in human-readable format"""
@ -100,355 +89,290 @@ def format_duration(seconds: float) -> str:
return f"{hours}h {minutes}m"
def generate_markdown_report(
junit_xml_path: str, output_path: str, tag: str = None, commit: str = None
):
def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None):
"""Generate a beautiful markdown report from JUnit XML"""
try:
tree = ET.parse(junit_xml_path)
root = tree.getroot()
# Handle both testsuite and testsuites root
if root.tag == "testsuites":
suites = root.findall("testsuite")
if root.tag == 'testsuites':
suites = root.findall('testsuite')
else:
suites = [root]
# Overall statistics
total_tests = 0
total_failures = 0
total_errors = 0
total_skipped = 0
total_time = 0.0
# Provider breakdown
provider_stats = defaultdict(
lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
)
provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0})
provider_tests = defaultdict(list)
for suite in suites:
total_tests += int(suite.get("tests", 0))
total_failures += int(suite.get("failures", 0))
total_errors += int(suite.get("errors", 0))
total_skipped += int(suite.get("skipped", 0))
total_time += float(suite.get("time", 0))
for testcase in suite.findall("testcase"):
classname = testcase.get("classname", "")
test_name = testcase.get("name", "")
test_time = float(testcase.get("time", 0))
total_tests += int(suite.get('tests', 0))
total_failures += int(suite.get('failures', 0))
total_errors += int(suite.get('errors', 0))
total_skipped += int(suite.get('skipped', 0))
total_time += float(suite.get('time', 0))
for testcase in suite.findall('testcase'):
classname = testcase.get('classname', '')
test_name = testcase.get('name', '')
test_time = float(testcase.get('time', 0))
# Extract test file name from classname
if "." in classname:
parts = classname.split(".")
test_file = parts[-2] if len(parts) > 1 else "unknown"
if '.' in classname:
parts = classname.split('.')
test_file = parts[-2] if len(parts) > 1 else 'unknown'
else:
test_file = "unknown"
test_file = 'unknown'
provider = get_provider_from_test_file(test_file)
provider_stats[provider]["time"] += test_time
provider_stats[provider]['time'] += test_time
# Check test status
if testcase.find("failure") is not None:
provider_stats[provider]["failed"] += 1
failure = testcase.find("failure")
failure_msg = (
failure.get("message", "") if failure is not None else ""
)
provider_tests[provider].append(
{
"name": test_name,
"status": "FAILED",
"time": test_time,
"message": failure_msg,
}
)
elif testcase.find("error") is not None:
provider_stats[provider]["errors"] += 1
error = testcase.find("error")
error_msg = error.get("message", "") if error is not None else ""
provider_tests[provider].append(
{
"name": test_name,
"status": "ERROR",
"time": test_time,
"message": error_msg,
}
)
elif testcase.find("skipped") is not None:
provider_stats[provider]["skipped"] += 1
skip = testcase.find("skipped")
skip_msg = skip.get("message", "") if skip is not None else ""
provider_tests[provider].append(
{
"name": test_name,
"status": "SKIPPED",
"time": test_time,
"message": skip_msg,
}
)
if testcase.find('failure') is not None:
provider_stats[provider]['failed'] += 1
failure = testcase.find('failure')
failure_msg = failure.get('message', '') if failure is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'FAILED',
'time': test_time,
'message': failure_msg
})
elif testcase.find('error') is not None:
provider_stats[provider]['errors'] += 1
error = testcase.find('error')
error_msg = error.get('message', '') if error is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'ERROR',
'time': test_time,
'message': error_msg
})
elif testcase.find('skipped') is not None:
provider_stats[provider]['skipped'] += 1
skip = testcase.find('skipped')
skip_msg = skip.get('message', '') if skip is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'SKIPPED',
'time': test_time,
'message': skip_msg
})
else:
provider_stats[provider]["passed"] += 1
provider_tests[provider].append(
{
"name": test_name,
"status": "PASSED",
"time": test_time,
"message": "",
}
)
provider_stats[provider]['passed'] += 1
provider_tests[provider].append({
'name': test_name,
'status': 'PASSED',
'time': test_time,
'message': ''
})
passed = total_tests - total_failures - total_errors - total_skipped
# Generate the markdown report
with open(output_path, "w") as f:
with open(output_path, 'w') as f:
# Header
f.write("# LLM Translation Test Results\n\n")
# Metadata table
f.write("## Test Run Information\n\n")
f.write("| Field | Value |\n")
f.write("|-------|-------|\n")
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
f.write(
f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n"
)
f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n")
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
f.write(f"| **Duration** | {format_duration(total_time)} |\n")
f.write("\n")
# Overall statistics with visual elements
f.write("## Overall Statistics\n\n")
# Summary box
f.write("```\n")
f.write(f"Total Tests: {total_tests}\n")
f.write(
f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write("```\n\n")
# Provider summary table
f.write("## Results by Provider\n\n")
f.write(
"| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n"
)
f.write(
"|----------|-------|------|------|-------|------|-----------|----------|"
)
f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
f.write("|----------|-------|------|------|-------|------|-----------|----------|")
# Sort providers: specific providers first, then cross-provider tests
sorted_providers = []
cross_provider = []
for p in sorted(provider_stats.keys()):
if "Cross-Provider" in p or p == "Other Tests":
if 'Cross-Provider' in p or p == 'Other Tests':
cross_provider.append(p)
else:
sorted_providers.append(p)
all_providers = sorted_providers + cross_provider
for provider in all_providers:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
pass_rate = (stats["passed"] / total * 100) if total > 0 else 0
f.write(
f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | "
)
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
pass_rate = (stats['passed'] / total * 100) if total > 0 else 0
f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
f.write(f"{format_duration(stats['time'])} |")
# Detailed test results by provider
f.write("\n\n## Detailed Test Results\n\n")
for provider in sorted_providers:
if provider_tests[provider]:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
f.write(f"### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write(
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) "
)
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ")
f.write(f"in {format_duration(stats['time'])}\n\n")
# Group tests by status
tests_by_status = defaultdict(list)
for test in provider_tests[provider]:
tests_by_status[test["status"]].append(test)
tests_by_status[test['status']].append(test)
# Show failed tests first (if any)
if tests_by_status["FAILED"]:
if tests_by_status['FAILED']:
f.write("<details>\n<summary>Failed Tests</summary>\n\n")
for test in tests_by_status["FAILED"]:
for test in tests_by_status['FAILED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
if test["message"]:
if test['message']:
# Truncate long error messages
msg = (
test["message"][:200] + "..."
if len(test["message"]) > 200
else test["message"]
)
msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message']
f.write(f" > {msg}\n")
f.write("\n</details>\n\n")
# Show errors (if any)
if tests_by_status["ERROR"]:
if tests_by_status['ERROR']:
f.write("<details>\n<summary>Error Tests</summary>\n\n")
for test in tests_by_status["ERROR"]:
for test in tests_by_status['ERROR']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n")
# Show passed tests in collapsible section
if tests_by_status["PASSED"]:
if tests_by_status['PASSED']:
f.write("<details>\n<summary>Passed Tests</summary>\n\n")
for test in tests_by_status["PASSED"]:
for test in tests_by_status['PASSED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n")
# Show skipped tests (if any)
if tests_by_status["SKIPPED"]:
if tests_by_status['SKIPPED']:
f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
for test in tests_by_status["SKIPPED"]:
for test in tests_by_status['SKIPPED']:
f.write(f"- `{test['name']}`\n")
f.write("\n</details>\n\n")
# Cross-provider tests in a separate section
if cross_provider:
f.write("### Cross-Provider Tests\n\n")
for provider in cross_provider:
if provider_tests[provider]:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
f.write(f"#### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write(
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n"
)
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n")
# For cross-provider tests, just show counts
f.write(f"- Passed: {stats['passed']}\n")
if stats["failed"] > 0:
if stats['failed'] > 0:
f.write(f"- Failed: {stats['failed']}\n")
if stats["errors"] > 0:
if stats['errors'] > 0:
f.write(f"- Errors: {stats['errors']}\n")
if stats["skipped"] > 0:
if stats['skipped'] > 0:
f.write(f"- Skipped: {stats['skipped']}\n")
f.write("\n")
print_colored(f"Report generated: {output_path}", Colors.GREEN)
except Exception as e:
print_colored(f"Error generating report: {e}", Colors.RED)
raise
def run_tests(
test_path: str = "tests/llm_translation/",
junit_xml: str = "test-results/junit.xml",
report_path: str = "test-results/llm_translation_report.md",
tag: str = None,
commit: str = None,
) -> int:
def run_tests(test_path: str = "tests/llm_translation/",
junit_xml: str = "test-results/junit.xml",
report_path: str = "test-results/llm_translation_report.md",
tag: str = None,
commit: str = None) -> int:
"""Run the LLM translation tests and generate report"""
# Create test results directory
os.makedirs(os.path.dirname(junit_xml), exist_ok=True)
print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE)
print_colored(f"Test directory: {test_path}", Colors.CYAN)
print_colored(f"Output: {junit_xml}", Colors.CYAN)
print()
# Run pytest
cmd = [
"uv",
"run",
"--no-sync",
"pytest",
test_path,
"uv", "run", "--no-sync", "pytest", test_path,
f"--junitxml={junit_xml}",
"-v",
"--tb=short",
"--maxfail=500",
"-n",
"auto",
"-n", "auto"
]
# Add timeout if pytest-timeout is installed
try:
subprocess.run(
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
capture_output=True,
check=True,
)
subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
capture_output=True, check=True)
cmd.extend(["--timeout=300"])
except:
print_colored(
"Warning: pytest-timeout not installed, skipping timeout option",
Colors.YELLOW,
)
print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
print_colored("Running pytest with command:", Colors.YELLOW)
print(f" {' '.join(cmd)}")
print()
# Run the tests
result = subprocess.run(cmd, capture_output=False)
# Generate the report regardless of test outcome
if os.path.exists(junit_xml):
print()
print_colored("Generating test report...", Colors.BLUE)
generate_markdown_report(junit_xml, report_path, tag, commit)
# Print summary to console
print()
print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE)
# Parse XML for quick summary
tree = ET.parse(junit_xml)
root = tree.getroot()
if root.tag == "testsuites":
suites = root.findall("testsuite")
if root.tag == 'testsuites':
suites = root.findall('testsuite')
else:
suites = [root]
total = sum(int(s.get("tests", 0)) for s in suites)
failures = sum(int(s.get("failures", 0)) for s in suites)
errors = sum(int(s.get("errors", 0)) for s in suites)
skipped = sum(int(s.get("skipped", 0)) for s in suites)
total = sum(int(s.get('tests', 0)) for s in suites)
failures = sum(int(s.get('failures', 0)) for s in suites)
errors = sum(int(s.get('errors', 0)) for s in suites)
skipped = sum(int(s.get('skipped', 0)) for s in suites)
passed = total - failures - errors - skipped
print(f" Total: {total}")
print_colored(f" Passed: {passed}", Colors.GREEN)
if failures > 0:
@ -457,75 +381,59 @@ def run_tests(
print_colored(f" Errors: {errors}", Colors.RED)
if skipped > 0:
print_colored(f" Skipped: {skipped}", Colors.YELLOW)
if total > 0:
pass_rate = (passed / total) * 100
color = (
Colors.GREEN
if pass_rate >= 80
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
)
color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
else:
print_colored("No test results found!", Colors.RED)
print()
print_colored("Test run complete!", Colors.BOLD + Colors.GREEN)
return result.returncode
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
parser.add_argument(
"--test-path", default="tests/llm_translation/", help="Path to test directory"
)
parser.add_argument(
"--junit-xml",
default="test-results/junit.xml",
help="Path for JUnit XML output",
)
parser.add_argument(
"--report",
default="test-results/llm_translation_report.md",
help="Path for markdown report",
)
parser.add_argument("--test-path", default="tests/llm_translation/",
help="Path to test directory")
parser.add_argument("--junit-xml", default="test-results/junit.xml",
help="Path for JUnit XML output")
parser.add_argument("--report", default="test-results/llm_translation_report.md",
help="Path for markdown report")
parser.add_argument("--tag", help="Git tag or version")
parser.add_argument("--commit", help="Git commit SHA")
args = parser.parse_args()
# Get git info if not provided
if not args.commit:
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"], capture_output=True, text=True
)
result = subprocess.run(["git", "rev-parse", "HEAD"],
capture_output=True, text=True)
if result.returncode == 0:
args.commit = result.stdout.strip()
except:
pass
if not args.tag:
try:
result = subprocess.run(
["git", "describe", "--tags", "--abbrev=0"],
capture_output=True,
text=True,
)
result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"],
capture_output=True, text=True)
if result.returncode == 0:
args.tag = result.stdout.strip()
except:
pass
exit_code = run_tests(
test_path=args.test_path,
junit_xml=args.junit_xml,
report_path=args.report,
tag=args.tag,
commit=args.commit,
commit=args.commit
)
sys.exit(exit_code)

View File

@ -9,6 +9,7 @@ from pathlib import Path
import testing.postgresql
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
DEFAULT_BASE_BRANCH = "litellm_internal_staging"

View File

@ -6,6 +6,7 @@ from tabulate import tabulate
from termcolor import colored
import os
# Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ["gpt-3.5-turbo", "claude-2"]

View File

@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
from fastapi.responses import JSONResponse
import uvicorn
app = FastAPI(
title="Braintrust Prompt Wrapper",
description="Wrapper server for Braintrust prompts to work with LiteLLM",

View File

@ -2,7 +2,7 @@
Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway
This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy.
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
and Azure realtime APIs without changing your agent code.
"""

View File

@ -1,14 +1,14 @@
"""
LiteLLM Migration Script!
Takes a config.yaml and calls /model/new
Takes a config.yaml and calls /model/new
Inputs:
- File path to config.yaml
- Proxy base url to your hosted proxy
Step 1: Reads your config.yaml
Step 2: reads `model_list` and loops through all models
Step 2: reads `model_list` and loops through all models
Step 3: calls `<proxy-base-url>/model/new` for each model
"""

View File

@ -518,7 +518,8 @@ if __name__ == "__main__":
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
print("=" * 80)
print("\nExample curl command:")
print(f"""
print(
f"""
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
-H "Authorization: Bearer {bearer_token}" \\
-H "Content-Type: application/json" \\
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
}}
]
}}'
""")
"""
)
print("=" * 80)
uvicorn.run(app, host=host, port=port)

View File

@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception:
# If an error occurs, the view does not exist, so create it
await db.execute_raw("""
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""")
"""
)
print("LiteLLM_VerificationTokenView Created!") # noqa

View File

@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
class EnterpriseCallbackControls:
@staticmethod
def is_callback_disabled_dynamically(
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info
Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
)
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
litellm_params, standard_callback_dynamic_params
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info
Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
)
verbose_logger.debug(
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
)
verbose_logger.debug(
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
)
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if (
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
):
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = (
CustomLoggerRegistry.get_callback_str_from_class_type(
callback.__class__
)
)
if (
callback_str is not None
and callback_str.lower() in disabled_callbacks
):
verbose_logger.debug(
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
return False
except Exception as e:
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
return False
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
if callback_str is not None and callback_str.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
return False
except Exception as e:
verbose_logger.debug(
f"Error checking disabled callbacks header: {str(e)}"
)
return False
@staticmethod
def get_disabled_callbacks(
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> Optional[List[str]]:
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
"""
Get the disabled callbacks from the standard callback dynamic params.
"""
@ -92,24 +71,18 @@ class EnterpriseCallbackControls:
request_headers = get_proxy_server_request_headers(litellm_params)
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
if disabled_callbacks is not None:
disabled_callbacks = set(
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
)
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
return list(disabled_callbacks)
#########################################################
# check if disabled via request body
#########################################################
if (
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
is not None
):
return standard_callback_dynamic_params.get(
"litellm_disabled_callbacks", None
)
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
return None
@staticmethod
def _should_allow_dynamic_callback_disabling():
import litellm
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
# Check if admin has disabled this feature
if litellm.allow_dynamic_callback_disabling is not True:
verbose_logger.debug(
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
)
verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
return False
if premium_user:
return True
verbose_logger.warning(
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
)
return False
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
return False

View File

@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
)
# Calculate percentage and alert threshold
percentage = (
threshold_pct
if threshold_pct is not None
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
percentage = threshold_pct if threshold_pct is not None else int(
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
)
threshold_fraction = percentage / 100.0
alert_threshold_str = (
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
continue
_id = user_info.token or user_info.user_id or "default_id"
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
_cache_key = (
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
)
result = await _cache.async_get_cache(key=_cache_key)
if result is not None:
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
continue
recipient_emails = list(set(emails))
event_message = (
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
)
event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
webhook_event = WebhookEvent(
event="max_budget_alert",
event_message=event_message,

View File

@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
from .base_email import BaseEmailLogger
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
@ -78,4 +79,4 @@ class SendGridEmailLogger(BaseEmailLogger):
verbose_logger.debug(
f"SendGrid response status={response.status_code}, body={response.text}"
)
return
return

View File

@ -1,7 +1,6 @@
"""
This is the litellm SMTP email integration
"""
import asyncio
from typing import List

View File

@ -1,7 +1,6 @@
"""
Enterprise specific logging utils
"""
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata

View File

@ -153,11 +153,11 @@ async def get_audit_logs(
# Return paginated response
return PaginatedAuditLogResponse(
audit_logs=(
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs]
if audit_logs
else []
),
audit_logs=[
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
]
if audit_logs
else [],
total=total_count,
page=page,
page_size=page_size,

View File

@ -7,4 +7,4 @@ including custom SSO handlers and advanced authentication features.
from .custom_sso_handler import EnterpriseCustomSSOHandler
__all__ = ["EnterpriseCustomSSOHandler"]
__all__ = ["EnterpriseCustomSSOHandler"]

View File

@ -53,9 +53,7 @@ class CheckBatchCost:
"user_api_key_alias": getattr(user_row, "user_alias", None),
}
except Exception as e:
verbose_proxy_logger.error(
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
)
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
return {}
async def _cleanup_stale_managed_objects(self) -> None:
@ -64,22 +62,11 @@ class CheckBatchCost:
in non-terminal states as 'stale_expired'. These will never complete and
should not be polled.
"""
cutoff = datetime.now(timezone.utc) - timedelta(
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={
"file_purpose": "batch",
"status": {
"not_in": [
"completed",
"complete",
"failed",
"expired",
"cancelled",
"stale_expired",
]
},
"status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
"created_at": {"lt": cutoff},
},
data={"status": "stale_expired"},
@ -133,12 +120,9 @@ class CheckBatchCost:
try:
from litellm.integrations.prometheus import PrometheusLogger
prom_logger = PrometheusLogger.get_instance()
except Exception as e:
verbose_proxy_logger.error(
f"CheckBatchCost: could not get Prometheus logger: {e}"
)
verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
prom_logger = None
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
@ -177,11 +161,7 @@ class CheckBatchCost:
order={"created_at": "asc"},
)
except Exception as query_err:
if (
"batch_processed" not in str(query_err).lower()
and "unknown column" not in str(query_err).lower()
and "does not exist" not in str(query_err).lower()
):
if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
raise
# Permanent schema gap — cache the result so future cycles skip straight to fallback
self._has_batch_processed_column = False
@ -236,13 +216,14 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
)
if prom_logger:
prom_logger.record_check_batch_cost_error(
"provider_retrieval_error"
)
prom_logger.record_check_batch_cost_error("provider_retrieval_error")
continue
## RETRIEVE THE BATCH JOB OUTPUT FILE
if response.status == "completed" and response.output_file_id is not None:
if (
response.status == "completed"
and response.output_file_id is not None
):
verbose_proxy_logger.info(
f"Batch ID: {batch_id} is complete, tracking cost and usage"
)
@ -269,25 +250,20 @@ class CheckBatchCost:
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
if decoded:
try:
raw_output_file_id = decoded.split("llm_output_file_id,")[
1
].split(";")[0]
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
except (IndexError, AttributeError):
pass
credentials = (
self.llm_router.get_deployment_credentials_with_provider(model_id)
or {}
)
credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
_file_content = await afile_content(
file_id=raw_output_file_id,
**credentials,
)
# Access content - handle both direct attribute and method call
if hasattr(_file_content, "content"):
if hasattr(_file_content, 'content'):
content_bytes = _file_content.content # type: ignore[union-attr]
elif hasattr(_file_content, "read"):
elif hasattr(_file_content, 'read'):
content_bytes = await _file_content.read() # type: ignore[misc]
else:
content_bytes = _file_content # type: ignore[assignment]
@ -314,9 +290,7 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because it is not a valid deployment info"
)
if prom_logger:
prom_logger.record_check_batch_cost_error(
"deployment_not_found"
)
prom_logger.record_check_batch_cost_error("deployment_not_found")
continue
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
litellm_model_name = deployment_info.litellm_params.model
@ -328,11 +302,7 @@ class CheckBatchCost:
# Pass deployment model_info so custom batch pricing
# (input_cost_per_token_batches etc.) is used for cost calc
deployment_model_info = (
deployment_info.model_info.model_dump()
if deployment_info.model_info
else {}
)
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
batch_cost, batch_usage, batch_models = (
await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict,
@ -379,9 +349,7 @@ class CheckBatchCost:
# Record batch duration (completed_at - created_at)
if prom_logger and response.completed_at and response.created_at:
duration_seconds = float(
response.completed_at - response.created_at
)
duration_seconds = float(response.completed_at - response.created_at)
if duration_seconds >= 0:
prom_logger.record_managed_batch_duration(
duration_seconds=duration_seconds,
@ -390,9 +358,7 @@ class CheckBatchCost:
)
# Track this job for the final metrics summary
processed_models.append(
(model_name, str(llm_provider) if llm_provider else None)
)
processed_models.append((model_name, str(llm_provider) if llm_provider else None))
# mark the job as complete
try:

View File

@ -33,7 +33,9 @@ class CheckResponsesCost:
self.prisma_client: PrismaClient = prisma_client
self.llm_router: Router = llm_router
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int:
async def _expire_stale_rows(
self, cutoff: datetime, batch_size: int
) -> int:
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
Isolated so it can be swapped / mocked in tests without touching the
@ -72,9 +74,7 @@ class CheckResponsesCost:
rows per invocation to avoid overwhelming the DB when there is a large
backlog.
"""
cutoff = datetime.now(timezone.utc) - timedelta(
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
if result > 0:
verbose_proxy_logger.warning(
@ -105,7 +105,7 @@ class CheckResponsesCost:
take=MAX_OBJECTS_PER_POLL_CYCLE,
order={"created_at": "asc"},
)
verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check")
completed_jobs = []
@ -120,33 +120,29 @@ class CheckResponsesCost:
# Get the stored response object to extract model information
stored_response = job.file_object
model_name = stored_response.get("model", None)
# Decrypt the response ID
responses_id_security, _, _ = (
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
)
responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
# Prepare metadata with model information for cost tracking
litellm_metadata = {
"user_api_key_user_id": job.created_by or "default-user-id",
}
# Add model information if available
if model_name:
litellm_metadata["model"] = model_name
litellm_metadata["model_group"] = (
model_name # Use same value for model_group
)
litellm_metadata["model_group"] = model_name # Use same value for model_group
response = await litellm.aget_responses(
response_id=responses_id_security,
litellm_metadata=litellm_metadata,
)
verbose_proxy_logger.debug(
f"Response {unified_object_id} status: {response.status}, model: {model_name}"
)
except Exception as e:
verbose_proxy_logger.info(
f"Skipping job {unified_object_id} due to error: {e}"
@ -159,7 +155,7 @@ class CheckResponsesCost:
f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses."
)
completed_jobs.append(job)
elif response.status in ["failed", "cancelled"]:
verbose_proxy_logger.info(
f"Response {unified_object_id} has status {response.status}, marking as complete"
@ -175,3 +171,4 @@ class CheckResponsesCost:
verbose_proxy_logger.info(
f"Marked {len(completed_jobs)} response jobs as completed"
)

View File

@ -41,7 +41,7 @@ class _PROXY_LiteLLMManagedVectorStores(
):
"""
Managed vector stores with target_model_names support.
This class provides functionality to:
- Create vector stores across multiple models
- Retrieve vector stores by unified ID
@ -77,14 +77,14 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> str:
"""
Generate the format string for the unified vector store ID.
Format:
litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id>
"""
# VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary
# Extract provider resource ID from the response
provider_resource_id = resource_object.get("id", "")
# Model ID is stored in hidden params if the response object supports it
# For TypedDict responses, we need to check if _hidden_params was added
hidden_params: Dict[str, Any] = {}
@ -109,18 +109,20 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> VectorStoreCreateResponse:
"""
Create a vector store for a specific model.
Args:
llm_router: LiteLLM router instance
model: Model name to create vector store for
request_data: Request data for vector store creation
litellm_parent_otel_span: OpenTelemetry span for tracing
Returns:
VectorStoreCreateResponse from the provider
"""
# Use the router to create the vector store
response = await llm_router.avector_store_create(model=model, **request_data)
response = await llm_router.avector_store_create(
model=model, **request_data
)
return response
# ============================================================================
@ -137,14 +139,14 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> VectorStoreCreateResponse:
"""
Create a vector store across multiple models.
Args:
create_request: Vector store creation request parameters
llm_router: LiteLLM router instance
target_model_names_list: List of target model names
litellm_parent_otel_span: OpenTelemetry span for tracing
user_api_key_dict: User API key authentication details
Returns:
VectorStoreCreateResponse with unified ID
"""
@ -194,7 +196,7 @@ class _PROXY_LiteLLMManagedVectorStores(
# VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID
response = responses[0].copy()
response["id"] = unified_id
verbose_logger.info(
f"Successfully created managed vector store with unified ID: {unified_id}"
)
@ -210,13 +212,13 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Dict[str, Any]:
"""
List vector stores created by a user.
Args:
user_api_key_dict: User API key authentication details
limit: Maximum number of vector stores to return
after: Cursor for pagination
order: Sort order ('asc' or 'desc')
Returns:
Dictionary with list of vector stores and pagination info
"""
@ -236,23 +238,23 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> bool:
"""
Check if user has access to a vector store.
Args:
vector_store_id: The unified vector store ID
user_api_key_dict: User API key authentication details
Returns:
True if user has access, False otherwise
"""
is_unified_id = is_base64_encoded_unified_id(vector_store_id)
if is_unified_id:
# Check access for managed vector store
return await self.can_user_access_unified_resource_id(
vector_store_id,
user_api_key_dict,
)
# Not a managed vector store, allow access
return True
@ -261,22 +263,24 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> bool:
"""
Check if user has access to a managed vector store in request data.
Args:
data: Request data containing vector_store_id
user_api_key_dict: User API key authentication details
Returns:
True if this is a managed vector store and user has access
Raises:
HTTPException: If user doesn't have access
"""
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
is_unified_id = (
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False
is_base64_encoded_unified_id(vector_store_id)
if vector_store_id
else False
)
if is_unified_id and vector_store_id:
if await self.can_user_access_unified_resource_id(
vector_store_id, user_api_key_dict
@ -287,7 +291,7 @@ class _PROXY_LiteLLMManagedVectorStores(
status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
)
return False
# ============================================================================
@ -303,18 +307,18 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Union[Exception, str, Dict, None]:
"""
Pre-call hook to handle vector store operations.
This hook intercepts vector store requests and:
- Validates access for managed vector stores
- Transforms unified IDs to provider-specific IDs
- Adds model routing information
Args:
user_api_key_dict: User API key authentication details
cache: Cache instance
data: Request data
call_type: Type of call being made
Returns:
Modified request data or None
"""
@ -326,40 +330,40 @@ class _PROXY_LiteLLMManagedVectorStores(
# Handle vector store search operations
if call_type == "avector_store_search":
vector_store_id = data.get("vector_store_id")
if vector_store_id:
# Check if it's a managed vector store ID
decoded_id = is_base64_encoded_unified_id(vector_store_id)
if decoded_id:
verbose_logger.debug(
f"Processing managed vector store search: {vector_store_id}"
)
# Check access
has_access = await self.can_user_access_unified_resource_id(
vector_store_id, user_api_key_dict
)
if not has_access:
raise HTTPException(
status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
)
# Parse the unified ID to extract components
parsed_id = parse_unified_id(vector_store_id)
if parsed_id:
# Extract the model ID and provider resource ID
model_id = parsed_id.get("model_id")
provider_resource_id = parsed_id.get("provider_resource_id")
target_model_names = parsed_id.get("target_model_names", [])
verbose_logger.debug(
f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}"
)
# Determine which model to use for routing
# Priority: model_id (deployment ID) > first target_model_name
routing_model = None
@ -367,28 +371,28 @@ class _PROXY_LiteLLMManagedVectorStores(
routing_model = model_id
elif target_model_names and len(target_model_names) > 0:
routing_model = target_model_names[0]
# Set the model for routing
if routing_model:
data["model"] = routing_model
verbose_logger.info(
f"Routing vector store search to model: {routing_model}"
)
# Replace the unified ID with the provider-specific ID
if provider_resource_id:
data["vector_store_id"] = provider_resource_id
verbose_logger.debug(
f"Replaced unified ID with provider resource ID: {provider_resource_id}"
)
# Handle vector store retrieve/delete operations
elif call_type in ("avector_store_retrieve", "avector_store_delete"):
await self.check_managed_vector_store_access(data, user_api_key_dict)
# If it's a managed vector store, we'll handle it in the endpoint
# No need to transform here as the endpoint will route to the hook
return data
# ============================================================================
@ -403,15 +407,15 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Any:
"""
Post-call hook to transform responses.
This hook can be used to transform responses if needed.
For now, it just passes through the response.
Args:
data: Request data
user_api_key_dict: User API key authentication details
response: Response from the provider
Returns:
Potentially modified response
"""
@ -432,21 +436,21 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> List[Dict]:
"""
Filter deployments based on vector store availability.
This is used by the router to select only deployments that have
the vector store available.
Note: This method signature is a compromise between CustomLogger and BaseManagedResource
parent classes which have incompatible signatures. The type: ignore[override] is necessary
due to this multiple inheritance conflict.
Args:
model: Model name
healthy_deployments: List of healthy deployments
messages: Messages (unused for vector stores, required by CustomLogger interface)
request_kwargs: Request kwargs containing vector_store_id and mappings
parent_otel_span: OpenTelemetry span for tracing
Returns:
Filtered list of deployments
"""

View File

@ -2,6 +2,7 @@
Enterprise internal user management endpoints
"""
from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy._types import UserAPIKeyAuth

View File

@ -147,12 +147,12 @@ async def list_vector_stores(
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
prisma_client=prisma_client
)
# Also clean up in-memory registry to remove any deleted vector stores
if litellm.vector_store_registry is not None:
db_vector_store_ids = {
vs.get("vector_store_id")
for vs in vector_stores_from_db
vs.get("vector_store_id")
for vs in vector_stores_from_db
if vs.get("vector_store_id")
}
# Remove any in-memory vector stores that no longer exist in database

View File

@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
soft_budget_crossed = "Soft Budget Crossed"
max_budget_alert = "Max Budget Alert"
class EmailEventSettings(BaseModel):
event: EmailEvent
enabled: bool
class EmailEventSettingsUpdateRequest(BaseModel):
settings: List[EmailEventSettings]
class EmailEventSettingsResponse(BaseModel):
settings: List[EmailEventSettings]
class DefaultEmailSettings(BaseModel):
"""Default settings for email events"""
settings: Dict[EmailEvent, bool] = Field(
default_factory=lambda: {
EmailEvent.virtual_key_created: True, # On by default
@ -65,12 +57,10 @@ class DefaultEmailSettings(BaseModel):
EmailEvent.max_budget_alert: True, # On by default
}
)
def to_dict(self) -> Dict[str, bool]:
"""Convert to dictionary with string keys for storage"""
return {event.value: enabled for event, enabled in self.settings.items()}
@classmethod
def get_defaults(cls) -> Dict[str, bool]:
"""Get the default settings as a dictionary with string keys"""
return cls().to_dict()
return cls().to_dict()

View File

@ -6,6 +6,7 @@ Always uses fastuuid for performance.
import fastuuid as _uuid # type: ignore
# Expose a module-like alias so callers can use: uuid.uuid4()
uuid = _uuid

View File

@ -9,6 +9,7 @@ from typing import Dict, Optional
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
# HTTP status code -> Anthropic error type
# Source: https://docs.anthropic.com/en/api/errors
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {

View File

@ -2,6 +2,7 @@
from typing_extensions import Literal, Required, TypedDict
# Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors
AnthropicErrorType = Literal[

View File

@ -97,7 +97,7 @@ def _build_reasoning_item(
def _reasoning_item_to_response_input(
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]],
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
) -> Dict[str, Any]:
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
r_input: Dict[str, Any] = {

View File

@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
import json
import re
_CODE_KEYWORDS = re.compile(
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
)

View File

@ -1,5 +1,6 @@
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
FileContentProvider = Literal[
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
]

View File

@ -1,10 +1,10 @@
"""
Google GenAI Adapters for LiteLLM
This module provides adapters for transforming Google GenAI generate_content requests
This module provides adapters for transforming Google GenAI generate_content requests
to/from LiteLLM completion format with full support for:
- Text content transformation
- Tool calling (function declarations, function calls, function responses)
- Tool calling (function declarations, function calls, function responses)
- Streaming (both regular and tool calling)
- Mixed content (text + tool calls)
"""

View File

@ -1,9 +1,9 @@
"""
Handles Batching + sending Httpx Post requests to slack
Handles Batching + sending Httpx Post requests to slack
Slack alerts are sent every 10s or when events are greater than X events
Slack alerts are sent every 10s or when events are greater than X events
see custom_batch_logger.py for more details / defaults
see custom_batch_logger.py for more details / defaults
"""
from typing import TYPE_CHECKING, Any

View File

@ -18,7 +18,7 @@ else:
def process_slack_alerting_variables(
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]],
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
"""
process alert_to_webhook_url

View File

@ -1,5 +1,5 @@
"""
Base class for Additional Logging Utils for CustomLoggers
Base class for Additional Logging Utils for CustomLoggers
- Health Check for the logging util
- Get Request / Response Payload for the logging util

View File

@ -1,5 +1,5 @@
"""
Custom Logger that handles batching logic
Custom Logger that handles batching logic
Use this if you want your logs to be stored in memory and flushed periodically.
"""

View File

@ -9,6 +9,7 @@ import polars as pl
from .schema import FOCUS_NORMALIZED_SCHEMA
_TAG_KEYS = (
"team_id",
"team_alias",

View File

@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
def get_traces_and_spans_from_payload(
payload: List[Dict[str, Any]],
payload: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Separate traces and spans from payload.

View File

@ -1,8 +1,8 @@
"""
s3 Bucket Logging Integration
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
"""

View File

@ -5,28 +5,28 @@ This module provides SDK methods for Google's Interactions API.
Usage:
import litellm
# Create an interaction with a model
response = litellm.interactions.create(
model="gemini-2.5-flash",
input="Hello, how are you?"
)
# Create an interaction with an agent
response = litellm.interactions.create(
agent="deep-research-pro-preview-12-2025",
input="Research the current state of cancer research"
)
# Async version
response = await litellm.interactions.acreate(...)
# Get an interaction
response = litellm.interactions.get(interaction_id="...")
# Delete an interaction
result = litellm.interactions.delete(interaction_id="...")
# Cancel an interaction
result = litellm.interactions.cancel(interaction_id="...")

View File

@ -8,25 +8,25 @@ Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
Usage:
import litellm
# Create an interaction with a model
response = litellm.interactions.create(
model="gemini-2.5-flash",
input="Hello, how are you?"
)
# Create an interaction with an agent
response = litellm.interactions.create(
agent="deep-research-pro-preview-12-2025",
input="Research the current state of cancer research"
)
# Async version
response = await litellm.interactions.acreate(...)
# Get an interaction
response = litellm.interactions.get(interaction_id="...")
# Delete an interaction
result = litellm.interactions.delete(interaction_id="...")
"""

View File

@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = "redacted by litellm. \
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
error=str(e),
)
)
_metadata["raw_request"] = "Unable to Log \
raw request: {}".format(str(e))
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
self.logger_fn(

View File

@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
prompt_str = """Use this JSON schema:
```json
{}
```""".format(response_schema)
```""".format(
response_schema
)
return prompt_str

View File

@ -1,9 +1,9 @@
"""
This is a cache for LangfuseLoggers.
Langfuse Python SDK initializes a thread for each client.
Langfuse Python SDK initializes a thread for each client.
This ensures we do
This ensures we do
1. Proper cleanup of Langfuse initialized clients.
2. Re-use created langfuse clients.
"""

View File

@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
from litellm._logging import verbose_logger
# ---------------------------------------------------------------------------
# SSE parsing helpers (module-level to keep the class lean)
# ---------------------------------------------------------------------------

View File

@ -4,10 +4,10 @@ Support for o1 and o3 model families
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
- Temperature => drop param (if user opts in to dropping param)
"""

View File

@ -1,5 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
"""
from typing import Optional

View File

@ -1,5 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
Why separate file? Make it easy to see how transformation works
"""

View File

@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
from ...openai_like.chat.transformation import OpenAILikeChatConfig
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"

View File

@ -1,5 +1,5 @@
"""
Legacy /v1/embedding handler for Bedrock Cohere.
Legacy /v1/embedding handler for Bedrock Cohere.
"""
import json

View File

@ -13,6 +13,7 @@ from typing import Tuple
import httpx
# ---------------------------------------------------------------------------
# Pre-built response templates
# ---------------------------------------------------------------------------

View File

@ -1,5 +1,5 @@
"""
Cost calculator for Dashscope Chat models.
Cost calculator for Dashscope Chat models.
Handles tiered pricing and prompt caching scenarios.
"""

View File

@ -1,5 +1,5 @@
"""
Support for OpenAI's `/v1/chat/completions` endpoint.
Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as DataRobot is openai-compatible.
"""

View File

@ -1,5 +1,5 @@
"""
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
"""
from typing import Any, Dict, List, Optional, Union

View File

@ -1,5 +1,5 @@
"""
Cost calculator for DeepSeek Chat models.
Cost calculator for DeepSeek Chat models.
Handles prompt caching scenario.
"""

View File

@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
from ..common_utils import ElevenLabsException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent

View File

@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
def _usage_video_resolution_from_parameters(
parameters: Dict[str, Any],
parameters: Dict[str, Any]
) -> Optional[str]:
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
res = parameters.get("resolution")

View File

@ -1,5 +1,5 @@
"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""

View File

@ -1,5 +1,5 @@
"""
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
"""
Support for OpenAI's `/v1/chat/completions` endpoint.
Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as Novita AI is openai-compatible.

View File

@ -1,7 +1,7 @@
"""
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
This is OpenAI compatible
This is OpenAI compatible
This file only contains param mapping logic

View File

@ -1,7 +1,7 @@
"""
Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer
This is OpenAI compatible
This is OpenAI compatible
This file only contains param mapping logic

View File

@ -1,14 +1,14 @@
"""
Support for o1/o3 model family
Support for o1/o3 model family
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
"""
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload

View File

@ -201,7 +201,7 @@ class BaseOpenAILLM:
@staticmethod
def get_openai_client_initialization_param_fields(
client_type: Literal["openai", "azure"],
client_type: Literal["openai", "azure"]
) -> Tuple[str, ...]:
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
if client_type == "openai":

View File

@ -49,6 +49,7 @@ from litellm.types.utils import (
)
from litellm.llms.openrouter.common_utils import OpenRouterException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:

View File

@ -1,7 +1,7 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
In the Huggingface TGI format.
In the Huggingface TGI format.
"""
import json

View File

@ -1,7 +1,7 @@
"""
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
In the Huggingface TGI format.
In the Huggingface TGI format.
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union

View File

@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
def _parse_service_key_once(
service_key: Optional[Union[str, dict]],
service_key: Optional[Union[str, dict]]
) -> Optional[Dict[str, Any]]:
"""
Pre-parse service_key if it's a string to avoid repeated JSON parsing.

View File

@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
from ..utils import SnowflakeBaseConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

View File

@ -1,5 +1,5 @@
"""
Support for OpenAI's `/v1/chat/completions` endpoint.
Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.

View File

@ -1,5 +1,5 @@
"""
Support for OpenAI's `/v1/embeddings` endpoint.
Support for OpenAI's `/v1/embeddings` endpoint.
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.

View File

@ -1,5 +1,5 @@
"""
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""

View File

@ -1,5 +1,5 @@
"""
Transformation logic for context caching.
Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
@ -19,7 +19,7 @@ from ..gemini.transformation import (
def get_first_continuous_block_idx(
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message)
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
) -> int:
"""
Find the array index that ends the first continuous sequence of message blocks.

View File

@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
contents.append(ContentType(role="user", parts=tool_call_responses))
if len(contents) == 0:
verbose_logger.warning("""
verbose_logger.warning(
"""
No contents in messages. Contents are required. See
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
If the original request did not comply to OpenAI API requirements it should have failed by now,
but LiteLLM does not check for missing messages.
Setting an empty content to prevent an 400 error.
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
""")
"""
)
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
return contents
except Exception as e:

View File

@ -1,5 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
Why separate file? Make it easy to see how transformation works
"""

View File

@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############
####### Send the request ###################
if _is_async is True:
return self.async_audio_speech( # type: ignore
return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request
)
sync_handler = _get_httpx_client()

View File

@ -1,5 +1,5 @@
"""
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
"""

View File

@ -1,6 +1,6 @@
"""
This module is used to transform the request and response for the Voyage contextualized embeddings API.
This would be used for all the contextualized embeddings models in Voyage.
This module is used to transform the request and response for the Voyage contextualized embeddings API.
This would be used for all the contextualized embeddings models in Voyage.
"""
from typing import List, Optional, Union

View File

@ -324,7 +324,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_chat_completion_request_schema(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_responses_api_request_schema(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_llm_api_request_schema_body(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add LLM API request schema bodies to OpenAPI specification for documentation.

View File

@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
async def convert_upload_files_to_file_data(
form_data: Dict[str, Any],
form_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert FastAPI UploadFile objects to file data tuples for litellm.

View File

@ -1,5 +1,5 @@
"""
Contains utils used by OpenAI compatible endpoints
Contains utils used by OpenAI compatible endpoints
"""
from typing import Optional, Set

View File

@ -1,5 +1,5 @@
"""
What is this?
What is this?
CRUD endpoints for managing pass-through endpoints
"""

View File

@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
# If an error occurs, the view does not exist, so create it
await db.execute_raw("""
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
""")
"""
)
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")

View File

@ -10,6 +10,7 @@ every text fragment.
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
# Call types whose body carries free-form chat / prompt text that
# text-content guardrails (banned keywords, content moderation, secret
# detection, …) should inspect. The proxy ingress passes ``route_type``

View File

@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
from .akto import AktoGuardrail
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams

View File

@ -6,7 +6,7 @@ The actual skill logic is in litellm/llms/litellm_proxy/skills/.
Usage:
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
# Register hook in proxy
litellm.callbacks.append(SkillsInjectionHook())
"""

View File

@ -1,9 +1,9 @@
"""
BUDGET MANAGEMENT
All /budget management endpoints
All /budget management endpoints
/budget/new
/budget/new
/budget/info
/budget/update
/budget/delete

View File

@ -1,9 +1,9 @@
"""
CUSTOMER MANAGEMENT
All /customer management endpoints
All /customer management endpoints
/customer/new
/customer/new
/customer/info
/customer/update
/customer/delete

View File

@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
"""
def _get_team_public_model_name(
model_info: Optional[Union[dict, str]],
model_info: Optional[Union[dict, str]]
) -> Optional[str]:
if isinstance(model_info, dict):
value = model_info.get("team_public_model_name")

View File

@ -7,7 +7,7 @@ variables.
Environment Variables:
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
If these are not set, the default Microsoft endpoints are used.

View File

@ -4347,7 +4347,9 @@ async def list_team(
except Exception as e:
team_exception = """Invalid team object for team_id: {}. team_object={}.
Error: {}
""".format(team.team_id, team.model_dump(), str(e))
""".format(
team.team_id, team.model_dump(), str(e)
)
verbose_proxy_logger.exception(team_exception)
continue
# Sort the responses by team_alias

View File

@ -3,7 +3,7 @@ User Agent Analytics Endpoints
This module provides optimized endpoints for tracking user agent activity metrics including:
- Daily Active Users (DAU) by tags for configurable number of days
- Weekly Active Users (WAU) by tags for configurable number of weeks
- Weekly Active Users (WAU) by tags for configurable number of weeks
- Monthly Active Users (MAU) by tags for configurable number of months
- Summary analytics by tags

View File

@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import StandardPassThroughResponseObject
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
"POST /v0/agents": "cursor:agent:create",
"GET /v0/agents": "cursor:agent:list",

View File

@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
_endpoint_str = (
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
)
curl_command = _endpoint_str + """
curl_command = (
_endpoint_str
+ """
--header 'Content-Type: application/json' \\
--data ' {
"model": "gpt-3.5-turbo",
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
}'
\n
"""
)
print() # noqa
print( # noqa
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
print(f"""
print( # noqa
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa # noqa
"""
) # noqa
@staticmethod
def _is_port_in_use(port):

View File

@ -2688,9 +2688,11 @@ def run_ollama_serve():
with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
verbose_proxy_logger.debug(f"""
verbose_proxy_logger.debug(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""")
"""
)
def _get_process_rss_mb() -> Optional[float]:

View File

@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw("""
response = await prisma_client.db.query_raw(
"""
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
GROUP BY individual_request_tag;
""")
"""
)
return response

View File

@ -2712,7 +2712,8 @@ class PrismaClient:
required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
ret = await self.db.query_raw(f"""
ret = await self.db.query_raw(
f"""
WITH existing_views AS (
SELECT viewname
FROM pg_views
@ -2724,7 +2725,8 @@ class PrismaClient:
(SELECT COUNT(*) FROM existing_views) AS view_count,
ARRAY_AGG(viewname) AS view_names
FROM existing_views
""")
"""
)
expected_total_views = len(expected_views)
if ret[0]["view_count"] == expected_total_views:
verbose_proxy_logger.info("All necessary views exist!")
@ -2733,7 +2735,8 @@ class PrismaClient:
## check if required view exists ##
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw("""
await self.db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -2743,7 +2746,8 @@ class PrismaClient:
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""")
"""
)
verbose_proxy_logger.info(
"LiteLLM_VerificationTokenView Created in DB!"

View File

@ -1,5 +1,5 @@
"""
What is this?
What is this?
Logging Pass-Through Endpoints
"""

View File

@ -826,7 +826,7 @@ class Router:
@staticmethod
def _normalize_strategy(
strategy: Union[RoutingStrategy, str, None],
strategy: Union[RoutingStrategy, str, None]
) -> Optional[str]:
if strategy is None:
return None

View File

@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
def _recent_tool_results(
messages: Optional[List[Dict[str, Any]]],
messages: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Extract the current turn's tool result payloads from the request messages.

Some files were not shown because too many files have changed in this diff Show More