Drop dep bumps + black-26 reformat to clear fork CI policy

PR was blocked by .github/workflows/guard-fork-dependencies.yml: fork PRs
cannot modify uv.lock. Reverting:

- uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295
  files of mechanical Black 26 reformat coupled to it
- pyproject.toml diskcache extra change (kept the runtime mitigation in
  litellm/caching/disk_cache.py via JSONDisk)

Kept:
- Dockerfile cache narrowing (drops ~660 MB of uv build cache that
  surfaced cached setuptools as CVE findings)
- litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872
- ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json:
  next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard)
- tests/test_litellm/caching/test_disk_cache.py
- tests/code_coverage_tests/liccheck.ini: harmless black authorization

Black + gitpython + langchain dep upgrades will need a follow-up from a
maintainer pushing a branch in the canonical BerriAI/litellm repo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
user 2026-05-07 23:04:52 +00:00
parent 63bda3f001
commit 5bafa8b3a2
No known key found for this signature in database
292 changed files with 1814 additions and 2305 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import aiohttp import aiohttp
import json import json
# Asynchronously fetch data from a given URL # Asynchronously fetch data from a given URL
async def fetch_data(url): async def fetch_data(url):
try: try:
@ -16,24 +15,22 @@ async def fetch_data(url):
resp_json = await resp.json() resp_json = await resp.json()
print("Fetch the data from URL.") print("Fetch the data from URL.")
# Return the 'data' field from the JSON response # Return the 'data' field from the JSON response
return resp_json["data"] return resp_json['data']
except Exception as e: except Exception as e:
# Print an error message if fetching data fails # Print an error message if fetching data fails
print("Error fetching data from URL:", e) print("Error fetching data from URL:", e)
return None return None
# Synchronize local data with remote data # Synchronize local data with remote data
def sync_local_data_with_remote(local_data, remote_data): def sync_local_data_with_remote(local_data, remote_data):
# Update existing keys in local_data with values from remote_data # Update existing keys in local_data with values from remote_data
for key in set(local_data) & set(remote_data): for key in (set(local_data) & set(remote_data)):
local_data[key].update(remote_data[key]) local_data[key].update(remote_data[key])
# Add new keys from remote_data to local_data # Add new keys from remote_data to local_data
for key in set(remote_data) - set(local_data): for key in (set(remote_data) - set(local_data)):
local_data[key] = remote_data[key] local_data[key] = remote_data[key]
# Write data to the json file # Write data to the json file
def write_to_file(file_path, data): def write_to_file(file_path, data):
try: try:
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
# Print an error message if writing to file fails # Print an error message if writing to file fails
print("Error updating JSON file:", e) print("Error updating JSON file:", e)
# Update the existing models and add the missing models for OpenRouter # Update the existing models and add the missing models for OpenRouter
def transform_openrouter_data(data): def transform_openrouter_data(data):
transformed = {} transformed = {}
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
} }
# Add 'max_output_tokens' as a field if it is not None # Add 'max_output_tokens' as a field if it is not None
if ( if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
"top_provider" in row obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
and "max_completion_tokens" in row["top_provider"]
and row["top_provider"]["max_completion_tokens"] is not None
):
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
# Add the field 'output_cost_per_token' # Add the field 'output_cost_per_token'
obj.update( obj.update({
{ "output_cost_per_token": float(row["pricing"]["completion"]),
"output_cost_per_token": float(row["pricing"]["completion"]), })
}
)
# Add field 'input_cost_per_image' if it exists and is non-zero # Add field 'input_cost_per_image' if it exists and is non-zero
if ( if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
"pricing" in row obj['input_cost_per_image'] = float(row["pricing"]["image"])
and "image" in row["pricing"]
and float(row["pricing"]["image"]) != 0.0
):
obj["input_cost_per_image"] = float(row["pricing"]["image"])
# Add the fields 'litellm_provider' and 'mode' # Add the fields 'litellm_provider' and 'mode'
obj.update({"litellm_provider": "openrouter", "mode": "chat"}) obj.update({
"litellm_provider": "openrouter",
"mode": "chat"
})
# Add the 'supports_vision' field if the modality is 'multimodal' # Add the 'supports_vision' field if the modality is 'multimodal'
if row.get("architecture", {}).get("modality") == "multimodal": if row.get('architecture', {}).get('modality') == 'multimodal':
obj["supports_vision"] = True obj['supports_vision'] = True
# Use a composite key to store the transformed object # Use a composite key to store the transformed object
transformed[f'openrouter/{row["id"]}'] = obj transformed[f'openrouter/{row["id"]}'] = obj
return transformed return transformed
# Update the existing models and add the missing models for Vercel AI Gateway # Update the existing models and add the missing models for Vercel AI Gateway
def transform_vercel_ai_gateway_data(data): def transform_vercel_ai_gateway_data(data):
transformed = {} transformed = {}
@ -101,30 +89,20 @@ def transform_vercel_ai_gateway_data(data):
"max_tokens": row["context_window"], "max_tokens": row["context_window"],
"input_cost_per_token": float(row["pricing"]["input"]), "input_cost_per_token": float(row["pricing"]["input"]),
"output_cost_per_token": float(row["pricing"]["output"]), "output_cost_per_token": float(row["pricing"]["output"]),
"max_output_tokens": row["max_tokens"], 'max_output_tokens': row['max_tokens'],
"max_input_tokens": row["context_window"], 'max_input_tokens': row["context_window"],
} }
# Handle cache pricing if available # Handle cache pricing if available
if "pricing" in row: if "pricing" in row:
if ( if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
"input_cache_read" in row["pricing"] obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
and row["pricing"]["input_cache_read"] is not None
): if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
obj["cache_read_input_token_cost"] = float( obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
f"{float(row['pricing']['input_cache_read']):e}"
)
if (
"input_cache_write" in row["pricing"]
and row["pricing"]["input_cache_write"] is not None
):
obj["cache_creation_input_token_cost"] = float(
f"{float(row['pricing']['input_cache_write']):e}"
)
mode = "embedding" if "embedding" in row["id"].lower() else "chat" mode = "embedding" if "embedding" in row["id"].lower() else "chat"
obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode}) obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode})
transformed[f'vercel_ai_gateway/{row["id"]}'] = obj transformed[f'vercel_ai_gateway/{row["id"]}'] = obj
@ -148,31 +126,24 @@ def load_local_data(file_path):
print("Error decoding JSON:", e) print("Error decoding JSON:", e)
return None return None
def main(): def main():
local_file_path = ( local_file_path = "model_prices_and_context_window.json" # Path to the local data file
"model_prices_and_context_window.json" # Path to the local data file openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
) vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
openrouter_url = (
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
)
vercel_ai_gateway_url = (
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
)
# Load local data from file # Load local data from file
local_data = load_local_data(local_file_path) local_data = load_local_data(local_file_path)
# Fetch OpenRouter data # Fetch OpenRouter data
openrouter_data = asyncio.run(fetch_data(openrouter_url)) openrouter_data = asyncio.run(fetch_data(openrouter_url))
# Transform the fetched OpenRouter data # Transform the fetched OpenRouter data
openrouter_data = transform_openrouter_data(openrouter_data) openrouter_data = transform_openrouter_data(openrouter_data)
# Fetch Vercel AI Gateway data # Fetch Vercel AI Gateway data
vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url)) vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url))
# Transform the fetched Vercel AI Gateway data # Transform the fetched Vercel AI Gateway data
vercel_data = transform_vercel_ai_gateway_data(vercel_data) vercel_data = transform_vercel_ai_gateway_data(vercel_data)
# Combine both datasets # Combine both datasets
all_remote_data = {**openrouter_data, **vercel_data} all_remote_data = {**openrouter_data, **vercel_data}
@ -183,7 +154,6 @@ def main():
else: else:
print("Failed to fetch model data from either local file or URL.") print("Failed to fetch model data from either local file or URL.")
# Entry point of the script # Entry point of the script
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -16,75 +16,64 @@ from pathlib import Path
import json import json
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
# ANSI color codes for terminal output # ANSI color codes for terminal output
class Colors: class Colors:
GREEN = "\033[92m" GREEN = '\033[92m'
RED = "\033[91m" RED = '\033[91m'
YELLOW = "\033[93m" YELLOW = '\033[93m'
BLUE = "\033[94m" BLUE = '\033[94m'
PURPLE = "\033[95m" PURPLE = '\033[95m'
CYAN = "\033[96m" CYAN = '\033[96m'
RESET = "\033[0m" RESET = '\033[0m'
BOLD = "\033[1m" BOLD = '\033[1m'
def print_colored(message: str, color: str = Colors.RESET): def print_colored(message: str, color: str = Colors.RESET):
"""Print colored message to terminal""" """Print colored message to terminal"""
print(f"{color}{message}{Colors.RESET}") print(f"{color}{message}{Colors.RESET}")
def get_provider_from_test_file(test_file: str) -> str: def get_provider_from_test_file(test_file: str) -> str:
"""Map test file names to provider names""" """Map test file names to provider names"""
provider_mapping = { provider_mapping = {
"test_anthropic": "Anthropic", 'test_anthropic': 'Anthropic',
"test_azure": "Azure", 'test_azure': 'Azure',
"test_bedrock": "AWS Bedrock", 'test_bedrock': 'AWS Bedrock',
"test_openai": "OpenAI", 'test_openai': 'OpenAI',
"test_vertex": "Google Vertex AI", 'test_vertex': 'Google Vertex AI',
"test_gemini": "Google Vertex AI", 'test_gemini': 'Google Vertex AI',
"test_cohere": "Cohere", 'test_cohere': 'Cohere',
"test_databricks": "Databricks", 'test_databricks': 'Databricks',
"test_groq": "Groq", 'test_groq': 'Groq',
"test_together": "Together AI", 'test_together': 'Together AI',
"test_mistral": "Mistral", 'test_mistral': 'Mistral',
"test_deepseek": "DeepSeek", 'test_deepseek': 'DeepSeek',
"test_replicate": "Replicate", 'test_replicate': 'Replicate',
"test_huggingface": "HuggingFace", 'test_huggingface': 'HuggingFace',
"test_fireworks": "Fireworks AI", 'test_fireworks': 'Fireworks AI',
"test_perplexity": "Perplexity", 'test_perplexity': 'Perplexity',
"test_cloudflare": "Cloudflare", 'test_cloudflare': 'Cloudflare',
"test_voyage": "Voyage AI", 'test_voyage': 'Voyage AI',
"test_xai": "xAI", 'test_xai': 'xAI',
"test_nvidia": "NVIDIA", 'test_nvidia': 'NVIDIA',
"test_watsonx": "IBM watsonx", 'test_watsonx': 'IBM watsonx',
"test_azure_ai": "Azure AI", 'test_azure_ai': 'Azure AI',
"test_snowflake": "Snowflake", 'test_snowflake': 'Snowflake',
"test_infinity": "Infinity", 'test_infinity': 'Infinity',
"test_jina": "Jina AI", 'test_jina': 'Jina AI',
"test_deepgram": "Deepgram", 'test_deepgram': 'Deepgram',
"test_clarifai": "Clarifai", 'test_clarifai': 'Clarifai',
"test_triton": "Triton", 'test_triton': 'Triton',
} }
for key, provider in provider_mapping.items(): for key, provider in provider_mapping.items():
if key in test_file: if key in test_file:
return provider return provider
# For cross-provider test files # For cross-provider test files
if any( if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory',
name in test_file 'test_router', 'test_text_completion']):
for name in [ return f'Cross-Provider Tests ({test_file})'
"test_optional_params",
"test_prompt_factory", return 'Other Tests'
"test_router",
"test_text_completion",
]
):
return f"Cross-Provider Tests ({test_file})"
return "Other Tests"
def format_duration(seconds: float) -> str: def format_duration(seconds: float) -> str:
"""Format duration in human-readable format""" """Format duration in human-readable format"""
@ -100,355 +89,290 @@ def format_duration(seconds: float) -> str:
return f"{hours}h {minutes}m" return f"{hours}h {minutes}m"
def generate_markdown_report( def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None):
junit_xml_path: str, output_path: str, tag: str = None, commit: str = None
):
"""Generate a beautiful markdown report from JUnit XML""" """Generate a beautiful markdown report from JUnit XML"""
try: try:
tree = ET.parse(junit_xml_path) tree = ET.parse(junit_xml_path)
root = tree.getroot() root = tree.getroot()
# Handle both testsuite and testsuites root # Handle both testsuite and testsuites root
if root.tag == "testsuites": if root.tag == 'testsuites':
suites = root.findall("testsuite") suites = root.findall('testsuite')
else: else:
suites = [root] suites = [root]
# Overall statistics # Overall statistics
total_tests = 0 total_tests = 0
total_failures = 0 total_failures = 0
total_errors = 0 total_errors = 0
total_skipped = 0 total_skipped = 0
total_time = 0.0 total_time = 0.0
# Provider breakdown # Provider breakdown
provider_stats = defaultdict( provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0})
lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
)
provider_tests = defaultdict(list) provider_tests = defaultdict(list)
for suite in suites: for suite in suites:
total_tests += int(suite.get("tests", 0)) total_tests += int(suite.get('tests', 0))
total_failures += int(suite.get("failures", 0)) total_failures += int(suite.get('failures', 0))
total_errors += int(suite.get("errors", 0)) total_errors += int(suite.get('errors', 0))
total_skipped += int(suite.get("skipped", 0)) total_skipped += int(suite.get('skipped', 0))
total_time += float(suite.get("time", 0)) total_time += float(suite.get('time', 0))
for testcase in suite.findall("testcase"): for testcase in suite.findall('testcase'):
classname = testcase.get("classname", "") classname = testcase.get('classname', '')
test_name = testcase.get("name", "") test_name = testcase.get('name', '')
test_time = float(testcase.get("time", 0)) test_time = float(testcase.get('time', 0))
# Extract test file name from classname # Extract test file name from classname
if "." in classname: if '.' in classname:
parts = classname.split(".") parts = classname.split('.')
test_file = parts[-2] if len(parts) > 1 else "unknown" test_file = parts[-2] if len(parts) > 1 else 'unknown'
else: else:
test_file = "unknown" test_file = 'unknown'
provider = get_provider_from_test_file(test_file) provider = get_provider_from_test_file(test_file)
provider_stats[provider]["time"] += test_time provider_stats[provider]['time'] += test_time
# Check test status # Check test status
if testcase.find("failure") is not None: if testcase.find('failure') is not None:
provider_stats[provider]["failed"] += 1 provider_stats[provider]['failed'] += 1
failure = testcase.find("failure") failure = testcase.find('failure')
failure_msg = ( failure_msg = failure.get('message', '') if failure is not None else ''
failure.get("message", "") if failure is not None else "" provider_tests[provider].append({
) 'name': test_name,
provider_tests[provider].append( 'status': 'FAILED',
{ 'time': test_time,
"name": test_name, 'message': failure_msg
"status": "FAILED", })
"time": test_time, elif testcase.find('error') is not None:
"message": failure_msg, provider_stats[provider]['errors'] += 1
} error = testcase.find('error')
) error_msg = error.get('message', '') if error is not None else ''
elif testcase.find("error") is not None: provider_tests[provider].append({
provider_stats[provider]["errors"] += 1 'name': test_name,
error = testcase.find("error") 'status': 'ERROR',
error_msg = error.get("message", "") if error is not None else "" 'time': test_time,
provider_tests[provider].append( 'message': error_msg
{ })
"name": test_name, elif testcase.find('skipped') is not None:
"status": "ERROR", provider_stats[provider]['skipped'] += 1
"time": test_time, skip = testcase.find('skipped')
"message": error_msg, skip_msg = skip.get('message', '') if skip is not None else ''
} provider_tests[provider].append({
) 'name': test_name,
elif testcase.find("skipped") is not None: 'status': 'SKIPPED',
provider_stats[provider]["skipped"] += 1 'time': test_time,
skip = testcase.find("skipped") 'message': skip_msg
skip_msg = skip.get("message", "") if skip is not None else "" })
provider_tests[provider].append(
{
"name": test_name,
"status": "SKIPPED",
"time": test_time,
"message": skip_msg,
}
)
else: else:
provider_stats[provider]["passed"] += 1 provider_stats[provider]['passed'] += 1
provider_tests[provider].append( provider_tests[provider].append({
{ 'name': test_name,
"name": test_name, 'status': 'PASSED',
"status": "PASSED", 'time': test_time,
"time": test_time, 'message': ''
"message": "", })
}
)
passed = total_tests - total_failures - total_errors - total_skipped passed = total_tests - total_failures - total_errors - total_skipped
# Generate the markdown report # Generate the markdown report
with open(output_path, "w") as f: with open(output_path, 'w') as f:
# Header # Header
f.write("# LLM Translation Test Results\n\n") f.write("# LLM Translation Test Results\n\n")
# Metadata table # Metadata table
f.write("## Test Run Information\n\n") f.write("## Test Run Information\n\n")
f.write("| Field | Value |\n") f.write("| Field | Value |\n")
f.write("|-------|-------|\n") f.write("|-------|-------|\n")
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n") f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
f.write( f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n")
f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n"
)
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n") f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
f.write(f"| **Duration** | {format_duration(total_time)} |\n") f.write(f"| **Duration** | {format_duration(total_time)} |\n")
f.write("\n") f.write("\n")
# Overall statistics with visual elements # Overall statistics with visual elements
f.write("## Overall Statistics\n\n") f.write("## Overall Statistics\n\n")
# Summary box # Summary box
f.write("```\n") f.write("```\n")
f.write(f"Total Tests: {total_tests}\n") f.write(f"Total Tests: {total_tests}\n")
f.write( f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n" f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
) f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write( f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write("```\n\n") f.write("```\n\n")
# Provider summary table # Provider summary table
f.write("## Results by Provider\n\n") f.write("## Results by Provider\n\n")
f.write( f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
"| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n" f.write("|----------|-------|------|------|-------|------|-----------|----------|")
)
f.write(
"|----------|-------|------|------|-------|------|-----------|----------|"
)
# Sort providers: specific providers first, then cross-provider tests # Sort providers: specific providers first, then cross-provider tests
sorted_providers = [] sorted_providers = []
cross_provider = [] cross_provider = []
for p in sorted(provider_stats.keys()): for p in sorted(provider_stats.keys()):
if "Cross-Provider" in p or p == "Other Tests": if 'Cross-Provider' in p or p == 'Other Tests':
cross_provider.append(p) cross_provider.append(p)
else: else:
sorted_providers.append(p) sorted_providers.append(p)
all_providers = sorted_providers + cross_provider all_providers = sorted_providers + cross_provider
for provider in all_providers: for provider in all_providers:
stats = provider_stats[provider] stats = provider_stats[provider]
total = ( total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
stats["passed"] pass_rate = (stats['passed'] / total * 100) if total > 0 else 0
+ stats["failed"]
+ stats["errors"] f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
+ stats["skipped"]
)
pass_rate = (stats["passed"] / total * 100) if total > 0 else 0
f.write(
f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | "
)
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ") f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
f.write(f"{format_duration(stats['time'])} |") f.write(f"{format_duration(stats['time'])} |")
# Detailed test results by provider # Detailed test results by provider
f.write("\n\n## Detailed Test Results\n\n") f.write("\n\n## Detailed Test Results\n\n")
for provider in sorted_providers: for provider in sorted_providers:
if provider_tests[provider]: if provider_tests[provider]:
stats = provider_stats[provider] stats = provider_stats[provider]
total = ( total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
f.write(f"### {provider}\n\n") f.write(f"### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ") f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write( f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ")
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) "
)
f.write(f"in {format_duration(stats['time'])}\n\n") f.write(f"in {format_duration(stats['time'])}\n\n")
# Group tests by status # Group tests by status
tests_by_status = defaultdict(list) tests_by_status = defaultdict(list)
for test in provider_tests[provider]: for test in provider_tests[provider]:
tests_by_status[test["status"]].append(test) tests_by_status[test['status']].append(test)
# Show failed tests first (if any) # Show failed tests first (if any)
if tests_by_status["FAILED"]: if tests_by_status['FAILED']:
f.write("<details>\n<summary>Failed Tests</summary>\n\n") f.write("<details>\n<summary>Failed Tests</summary>\n\n")
for test in tests_by_status["FAILED"]: for test in tests_by_status['FAILED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
if test["message"]: if test['message']:
# Truncate long error messages # Truncate long error messages
msg = ( msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message']
test["message"][:200] + "..."
if len(test["message"]) > 200
else test["message"]
)
f.write(f" > {msg}\n") f.write(f" > {msg}\n")
f.write("\n</details>\n\n") f.write("\n</details>\n\n")
# Show errors (if any) # Show errors (if any)
if tests_by_status["ERROR"]: if tests_by_status['ERROR']:
f.write("<details>\n<summary>Error Tests</summary>\n\n") f.write("<details>\n<summary>Error Tests</summary>\n\n")
for test in tests_by_status["ERROR"]: for test in tests_by_status['ERROR']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n") f.write("\n</details>\n\n")
# Show passed tests in collapsible section # Show passed tests in collapsible section
if tests_by_status["PASSED"]: if tests_by_status['PASSED']:
f.write("<details>\n<summary>Passed Tests</summary>\n\n") f.write("<details>\n<summary>Passed Tests</summary>\n\n")
for test in tests_by_status["PASSED"]: for test in tests_by_status['PASSED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n") f.write("\n</details>\n\n")
# Show skipped tests (if any) # Show skipped tests (if any)
if tests_by_status["SKIPPED"]: if tests_by_status['SKIPPED']:
f.write("<details>\n<summary>Skipped Tests</summary>\n\n") f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
for test in tests_by_status["SKIPPED"]: for test in tests_by_status['SKIPPED']:
f.write(f"- `{test['name']}`\n") f.write(f"- `{test['name']}`\n")
f.write("\n</details>\n\n") f.write("\n</details>\n\n")
# Cross-provider tests in a separate section # Cross-provider tests in a separate section
if cross_provider: if cross_provider:
f.write("### Cross-Provider Tests\n\n") f.write("### Cross-Provider Tests\n\n")
for provider in cross_provider: for provider in cross_provider:
if provider_tests[provider]: if provider_tests[provider]:
stats = provider_stats[provider] stats = provider_stats[provider]
total = ( total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
f.write(f"#### {provider}\n\n") f.write(f"#### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ") f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write( f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n")
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n"
)
# For cross-provider tests, just show counts # For cross-provider tests, just show counts
f.write(f"- Passed: {stats['passed']}\n") f.write(f"- Passed: {stats['passed']}\n")
if stats["failed"] > 0: if stats['failed'] > 0:
f.write(f"- Failed: {stats['failed']}\n") f.write(f"- Failed: {stats['failed']}\n")
if stats["errors"] > 0: if stats['errors'] > 0:
f.write(f"- Errors: {stats['errors']}\n") f.write(f"- Errors: {stats['errors']}\n")
if stats["skipped"] > 0: if stats['skipped'] > 0:
f.write(f"- Skipped: {stats['skipped']}\n") f.write(f"- Skipped: {stats['skipped']}\n")
f.write("\n") f.write("\n")
print_colored(f"Report generated: {output_path}", Colors.GREEN) print_colored(f"Report generated: {output_path}", Colors.GREEN)
except Exception as e: except Exception as e:
print_colored(f"Error generating report: {e}", Colors.RED) print_colored(f"Error generating report: {e}", Colors.RED)
raise raise
def run_tests(test_path: str = "tests/llm_translation/",
def run_tests( junit_xml: str = "test-results/junit.xml",
test_path: str = "tests/llm_translation/", report_path: str = "test-results/llm_translation_report.md",
junit_xml: str = "test-results/junit.xml", tag: str = None,
report_path: str = "test-results/llm_translation_report.md", commit: str = None) -> int:
tag: str = None,
commit: str = None,
) -> int:
"""Run the LLM translation tests and generate report""" """Run the LLM translation tests and generate report"""
# Create test results directory # Create test results directory
os.makedirs(os.path.dirname(junit_xml), exist_ok=True) os.makedirs(os.path.dirname(junit_xml), exist_ok=True)
print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE) print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE)
print_colored(f"Test directory: {test_path}", Colors.CYAN) print_colored(f"Test directory: {test_path}", Colors.CYAN)
print_colored(f"Output: {junit_xml}", Colors.CYAN) print_colored(f"Output: {junit_xml}", Colors.CYAN)
print() print()
# Run pytest # Run pytest
cmd = [ cmd = [
"uv", "uv", "run", "--no-sync", "pytest", test_path,
"run",
"--no-sync",
"pytest",
test_path,
f"--junitxml={junit_xml}", f"--junitxml={junit_xml}",
"-v", "-v",
"--tb=short", "--tb=short",
"--maxfail=500", "--maxfail=500",
"-n", "-n", "auto"
"auto",
] ]
# Add timeout if pytest-timeout is installed # Add timeout if pytest-timeout is installed
try: try:
subprocess.run( subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"], capture_output=True, check=True)
capture_output=True,
check=True,
)
cmd.extend(["--timeout=300"]) cmd.extend(["--timeout=300"])
except: except:
print_colored( print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
"Warning: pytest-timeout not installed, skipping timeout option",
Colors.YELLOW,
)
print_colored("Running pytest with command:", Colors.YELLOW) print_colored("Running pytest with command:", Colors.YELLOW)
print(f" {' '.join(cmd)}") print(f" {' '.join(cmd)}")
print() print()
# Run the tests # Run the tests
result = subprocess.run(cmd, capture_output=False) result = subprocess.run(cmd, capture_output=False)
# Generate the report regardless of test outcome # Generate the report regardless of test outcome
if os.path.exists(junit_xml): if os.path.exists(junit_xml):
print() print()
print_colored("Generating test report...", Colors.BLUE) print_colored("Generating test report...", Colors.BLUE)
generate_markdown_report(junit_xml, report_path, tag, commit) generate_markdown_report(junit_xml, report_path, tag, commit)
# Print summary to console # Print summary to console
print() print()
print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE) print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE)
# Parse XML for quick summary # Parse XML for quick summary
tree = ET.parse(junit_xml) tree = ET.parse(junit_xml)
root = tree.getroot() root = tree.getroot()
if root.tag == "testsuites": if root.tag == 'testsuites':
suites = root.findall("testsuite") suites = root.findall('testsuite')
else: else:
suites = [root] suites = [root]
total = sum(int(s.get("tests", 0)) for s in suites) total = sum(int(s.get('tests', 0)) for s in suites)
failures = sum(int(s.get("failures", 0)) for s in suites) failures = sum(int(s.get('failures', 0)) for s in suites)
errors = sum(int(s.get("errors", 0)) for s in suites) errors = sum(int(s.get('errors', 0)) for s in suites)
skipped = sum(int(s.get("skipped", 0)) for s in suites) skipped = sum(int(s.get('skipped', 0)) for s in suites)
passed = total - failures - errors - skipped passed = total - failures - errors - skipped
print(f" Total: {total}") print(f" Total: {total}")
print_colored(f" Passed: {passed}", Colors.GREEN) print_colored(f" Passed: {passed}", Colors.GREEN)
if failures > 0: if failures > 0:
@ -457,75 +381,59 @@ def run_tests(
print_colored(f" Errors: {errors}", Colors.RED) print_colored(f" Errors: {errors}", Colors.RED)
if skipped > 0: if skipped > 0:
print_colored(f" Skipped: {skipped}", Colors.YELLOW) print_colored(f" Skipped: {skipped}", Colors.YELLOW)
if total > 0: if total > 0:
pass_rate = (passed / total) * 100 pass_rate = (passed / total) * 100
color = ( color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
Colors.GREEN
if pass_rate >= 80
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
)
print_colored(f" Pass Rate: {pass_rate:.1f}%", color) print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
else: else:
print_colored("No test results found!", Colors.RED) print_colored("No test results found!", Colors.RED)
print() print()
print_colored("Test run complete!", Colors.BOLD + Colors.GREEN) print_colored("Test run complete!", Colors.BOLD + Colors.GREEN)
return result.returncode return result.returncode
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser(description="Run LLM Translation Tests") parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
parser.add_argument( parser.add_argument("--test-path", default="tests/llm_translation/",
"--test-path", default="tests/llm_translation/", help="Path to test directory" help="Path to test directory")
) parser.add_argument("--junit-xml", default="test-results/junit.xml",
parser.add_argument( help="Path for JUnit XML output")
"--junit-xml", parser.add_argument("--report", default="test-results/llm_translation_report.md",
default="test-results/junit.xml", help="Path for markdown report")
help="Path for JUnit XML output",
)
parser.add_argument(
"--report",
default="test-results/llm_translation_report.md",
help="Path for markdown report",
)
parser.add_argument("--tag", help="Git tag or version") parser.add_argument("--tag", help="Git tag or version")
parser.add_argument("--commit", help="Git commit SHA") parser.add_argument("--commit", help="Git commit SHA")
args = parser.parse_args() args = parser.parse_args()
# Get git info if not provided # Get git info if not provided
if not args.commit: if not args.commit:
try: try:
result = subprocess.run( result = subprocess.run(["git", "rev-parse", "HEAD"],
["git", "rev-parse", "HEAD"], capture_output=True, text=True capture_output=True, text=True)
)
if result.returncode == 0: if result.returncode == 0:
args.commit = result.stdout.strip() args.commit = result.stdout.strip()
except: except:
pass pass
if not args.tag: if not args.tag:
try: try:
result = subprocess.run( result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"],
["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True)
capture_output=True,
text=True,
)
if result.returncode == 0: if result.returncode == 0:
args.tag = result.stdout.strip() args.tag = result.stdout.strip()
except: except:
pass pass
exit_code = run_tests( exit_code = run_tests(
test_path=args.test_path, test_path=args.test_path,
junit_xml=args.junit_xml, junit_xml=args.junit_xml,
report_path=args.report, report_path=args.report,
tag=args.tag, tag=args.tag,
commit=args.commit, commit=args.commit
) )
sys.exit(exit_code) sys.exit(exit_code)

View File

@ -9,6 +9,7 @@ from pathlib import Path
import testing.postgresql import testing.postgresql
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE) DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
DEFAULT_BASE_BRANCH = "litellm_internal_staging" DEFAULT_BASE_BRANCH = "litellm_internal_staging"

View File

@ -6,6 +6,7 @@ from tabulate import tabulate
from termcolor import colored from termcolor import colored
import os import os
# Define the list of models to benchmark # Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers # select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ["gpt-3.5-turbo", "claude-2"] models = ["gpt-3.5-turbo", "claude-2"]

View File

@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import uvicorn import uvicorn
app = FastAPI( app = FastAPI(
title="Braintrust Prompt Wrapper", title="Braintrust Prompt Wrapper",
description="Wrapper server for Braintrust prompts to work with LiteLLM", description="Wrapper server for Braintrust prompts to work with LiteLLM",

View File

@ -2,7 +2,7 @@
Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway
This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy. This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy.
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI, LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
and Azure realtime APIs without changing your agent code. and Azure realtime APIs without changing your agent code.
""" """

View File

@ -1,14 +1,14 @@
""" """
LiteLLM Migration Script! LiteLLM Migration Script!
Takes a config.yaml and calls /model/new Takes a config.yaml and calls /model/new
Inputs: Inputs:
- File path to config.yaml - File path to config.yaml
- Proxy base url to your hosted proxy - Proxy base url to your hosted proxy
Step 1: Reads your config.yaml Step 1: Reads your config.yaml
Step 2: reads `model_list` and loops through all models Step 2: reads `model_list` and loops through all models
Step 3: calls `<proxy-base-url>/model/new` for each model Step 3: calls `<proxy-base-url>/model/new` for each model
""" """

View File

@ -518,7 +518,8 @@ if __name__ == "__main__":
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply") print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
print("=" * 80) print("=" * 80)
print("\nExample curl command:") print("\nExample curl command:")
print(f""" print(
f"""
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
-H "Authorization: Bearer {bearer_token}" \\ -H "Authorization: Bearer {bearer_token}" \\
-H "Content-Type: application/json" \\ -H "Content-Type: application/json" \\
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
}} }}
] ]
}}' }}'
""") """
)
print("=" * 80) print("=" * 80)
uvicorn.run(app, host=host, port=port) uvicorn.run(app, host=host, port=port)

View File

@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
print("LiteLLM_VerificationTokenView Exists!") # noqa print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception: except Exception:
# If an error occurs, the view does not exist, so create it # If an error occurs, the view does not exist, so create it
await db.execute_raw(""" await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT SELECT
v.*, v.*,
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
t.rpm_limit AS team_rpm_limit t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""") """
)
print("LiteLLM_VerificationTokenView Created!") # noqa print("LiteLLM_VerificationTokenView Created!") # noqa

View File

@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
class EnterpriseCallbackControls: class EnterpriseCallbackControls:
@staticmethod @staticmethod
def is_callback_disabled_dynamically( def is_callback_disabled_dynamically(
callback: litellm.CALLBACK_TYPES, callback: litellm.CALLBACK_TYPES,
litellm_params: dict, litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams, standard_callback_dynamic_params: StandardCallbackDynamicParams
) -> bool: ) -> bool:
""" """
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params. Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
Args: Args:
callback: The callback to check (can be string, CustomLogger instance, or callable) callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info litellm_params: Parameters containing proxy server request info
Returns: Returns:
bool: True if the callback should be disabled, False otherwise bool: True if the callback should be disabled, False otherwise
""" """
from litellm.litellm_core_utils.custom_logger_registry import ( from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry, CustomLoggerRegistry,
)
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
litellm_params, standard_callback_dynamic_params
) )
verbose_logger.debug(
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
)
verbose_logger.debug(
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
)
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if (
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
):
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = (
CustomLoggerRegistry.get_callback_str_from_class_type(
callback.__class__
)
)
if (
callback_str is not None
and callback_str.lower() in disabled_callbacks
):
verbose_logger.debug(
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
return False
except Exception as e:
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
return False
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
if callback_str is not None and callback_str.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
return False
except Exception as e:
verbose_logger.debug(
f"Error checking disabled callbacks header: {str(e)}"
)
return False
@staticmethod @staticmethod
def get_disabled_callbacks( def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> Optional[List[str]]:
""" """
Get the disabled callbacks from the standard callback dynamic params. Get the disabled callbacks from the standard callback dynamic params.
""" """
@ -92,24 +71,18 @@ class EnterpriseCallbackControls:
request_headers = get_proxy_server_request_headers(litellm_params) request_headers = get_proxy_server_request_headers(litellm_params)
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None) disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
if disabled_callbacks is not None: if disabled_callbacks is not None:
disabled_callbacks = set( disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
)
return list(disabled_callbacks) return list(disabled_callbacks)
######################################################### #########################################################
# check if disabled via request body # check if disabled via request body
######################################################### #########################################################
if ( if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
is not None
):
return standard_callback_dynamic_params.get(
"litellm_disabled_callbacks", None
)
return None return None
@staticmethod @staticmethod
def _should_allow_dynamic_callback_disabling(): def _should_allow_dynamic_callback_disabling():
import litellm import litellm
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
# Check if admin has disabled this feature # Check if admin has disabled this feature
if litellm.allow_dynamic_callback_disabling is not True: if litellm.allow_dynamic_callback_disabling is not True:
verbose_logger.debug( verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
)
return False return False
if premium_user: if premium_user:
return True return True
verbose_logger.warning( verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}" return False
)
return False

View File

@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
) )
# Calculate percentage and alert threshold # Calculate percentage and alert threshold
percentage = ( percentage = threshold_pct if threshold_pct is not None else int(
threshold_pct EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
if threshold_pct is not None
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
) )
threshold_fraction = percentage / 100.0 threshold_fraction = percentage / 100.0
alert_threshold_str = ( alert_threshold_str = (
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
continue continue
_id = user_info.token or user_info.user_id or "default_id" _id = user_info.token or user_info.user_id or "default_id"
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}" _cache_key = (
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
)
result = await _cache.async_get_cache(key=_cache_key) result = await _cache.async_get_cache(key=_cache_key)
if result is not None: if result is not None:
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
continue continue
recipient_emails = list(set(emails)) recipient_emails = list(set(emails))
event_message = ( event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
)
webhook_event = WebhookEvent( webhook_event = WebhookEvent(
event="max_budget_alert", event="max_budget_alert",
event_message=event_message, event_message=event_message,

View File

@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
from .base_email import BaseEmailLogger from .base_email import BaseEmailLogger
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send" SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
@ -78,4 +79,4 @@ class SendGridEmailLogger(BaseEmailLogger):
verbose_logger.debug( verbose_logger.debug(
f"SendGrid response status={response.status_code}, body={response.text}" f"SendGrid response status={response.status_code}, body={response.text}"
) )
return return

View File

@ -1,7 +1,6 @@
""" """
This is the litellm SMTP email integration This is the litellm SMTP email integration
""" """
import asyncio import asyncio
from typing import List from typing import List

View File

@ -1,7 +1,6 @@
""" """
Enterprise specific logging utils Enterprise specific logging utils
""" """
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata

View File

@ -153,11 +153,11 @@ async def get_audit_logs(
# Return paginated response # Return paginated response
return PaginatedAuditLogResponse( return PaginatedAuditLogResponse(
audit_logs=( audit_logs=[
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs] AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
if audit_logs ]
else [] if audit_logs
), else [],
total=total_count, total=total_count,
page=page, page=page,
page_size=page_size, page_size=page_size,

View File

@ -7,4 +7,4 @@ including custom SSO handlers and advanced authentication features.
from .custom_sso_handler import EnterpriseCustomSSOHandler from .custom_sso_handler import EnterpriseCustomSSOHandler
__all__ = ["EnterpriseCustomSSOHandler"] __all__ = ["EnterpriseCustomSSOHandler"]

View File

@ -53,9 +53,7 @@ class CheckBatchCost:
"user_api_key_alias": getattr(user_row, "user_alias", None), "user_api_key_alias": getattr(user_row, "user_alias", None),
} }
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
)
return {} return {}
async def _cleanup_stale_managed_objects(self) -> None: async def _cleanup_stale_managed_objects(self) -> None:
@ -64,22 +62,11 @@ class CheckBatchCost:
in non-terminal states as 'stale_expired'. These will never complete and in non-terminal states as 'stale_expired'. These will never complete and
should not be polled. should not be polled.
""" """
cutoff = datetime.now(timezone.utc) - timedelta( cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
result = await self.prisma_client.db.litellm_managedobjecttable.update_many( result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={ where={
"file_purpose": "batch", "file_purpose": "batch",
"status": { "status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
"not_in": [
"completed",
"complete",
"failed",
"expired",
"cancelled",
"stale_expired",
]
},
"created_at": {"lt": cutoff}, "created_at": {"lt": cutoff},
}, },
data={"status": "stale_expired"}, data={"status": "stale_expired"},
@ -133,12 +120,9 @@ class CheckBatchCost:
try: try:
from litellm.integrations.prometheus import PrometheusLogger from litellm.integrations.prometheus import PrometheusLogger
prom_logger = PrometheusLogger.get_instance() prom_logger = PrometheusLogger.get_instance()
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
f"CheckBatchCost: could not get Prometheus logger: {e}"
)
prom_logger = None prom_logger = None
processed_models: List[Tuple[Optional[str], Optional[str]]] = [] processed_models: List[Tuple[Optional[str], Optional[str]]] = []
@ -177,11 +161,7 @@ class CheckBatchCost:
order={"created_at": "asc"}, order={"created_at": "asc"},
) )
except Exception as query_err: except Exception as query_err:
if ( if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
"batch_processed" not in str(query_err).lower()
and "unknown column" not in str(query_err).lower()
and "does not exist" not in str(query_err).lower()
):
raise raise
# Permanent schema gap — cache the result so future cycles skip straight to fallback # Permanent schema gap — cache the result so future cycles skip straight to fallback
self._has_batch_processed_column = False self._has_batch_processed_column = False
@ -236,13 +216,14 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}" f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
) )
if prom_logger: if prom_logger:
prom_logger.record_check_batch_cost_error( prom_logger.record_check_batch_cost_error("provider_retrieval_error")
"provider_retrieval_error"
)
continue continue
## RETRIEVE THE BATCH JOB OUTPUT FILE ## RETRIEVE THE BATCH JOB OUTPUT FILE
if response.status == "completed" and response.output_file_id is not None: if (
response.status == "completed"
and response.output_file_id is not None
):
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"Batch ID: {batch_id} is complete, tracking cost and usage" f"Batch ID: {batch_id} is complete, tracking cost and usage"
) )
@ -269,25 +250,20 @@ class CheckBatchCost:
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id) decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
if decoded: if decoded:
try: try:
raw_output_file_id = decoded.split("llm_output_file_id,")[ raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
1
].split(";")[0]
except (IndexError, AttributeError): except (IndexError, AttributeError):
pass pass
credentials = ( credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
self.llm_router.get_deployment_credentials_with_provider(model_id)
or {}
)
_file_content = await afile_content( _file_content = await afile_content(
file_id=raw_output_file_id, file_id=raw_output_file_id,
**credentials, **credentials,
) )
# Access content - handle both direct attribute and method call # Access content - handle both direct attribute and method call
if hasattr(_file_content, "content"): if hasattr(_file_content, 'content'):
content_bytes = _file_content.content # type: ignore[union-attr] content_bytes = _file_content.content # type: ignore[union-attr]
elif hasattr(_file_content, "read"): elif hasattr(_file_content, 'read'):
content_bytes = await _file_content.read() # type: ignore[misc] content_bytes = await _file_content.read() # type: ignore[misc]
else: else:
content_bytes = _file_content # type: ignore[assignment] content_bytes = _file_content # type: ignore[assignment]
@ -314,9 +290,7 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because it is not a valid deployment info" f"Skipping job {unified_object_id} because it is not a valid deployment info"
) )
if prom_logger: if prom_logger:
prom_logger.record_check_batch_cost_error( prom_logger.record_check_batch_cost_error("deployment_not_found")
"deployment_not_found"
)
continue continue
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
litellm_model_name = deployment_info.litellm_params.model litellm_model_name = deployment_info.litellm_params.model
@ -328,11 +302,7 @@ class CheckBatchCost:
# Pass deployment model_info so custom batch pricing # Pass deployment model_info so custom batch pricing
# (input_cost_per_token_batches etc.) is used for cost calc # (input_cost_per_token_batches etc.) is used for cost calc
deployment_model_info = ( deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
deployment_info.model_info.model_dump()
if deployment_info.model_info
else {}
)
batch_cost, batch_usage, batch_models = ( batch_cost, batch_usage, batch_models = (
await calculate_batch_cost_and_usage( await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict, file_content_dictionary=file_content_as_dict,
@ -379,9 +349,7 @@ class CheckBatchCost:
# Record batch duration (completed_at - created_at) # Record batch duration (completed_at - created_at)
if prom_logger and response.completed_at and response.created_at: if prom_logger and response.completed_at and response.created_at:
duration_seconds = float( duration_seconds = float(response.completed_at - response.created_at)
response.completed_at - response.created_at
)
if duration_seconds >= 0: if duration_seconds >= 0:
prom_logger.record_managed_batch_duration( prom_logger.record_managed_batch_duration(
duration_seconds=duration_seconds, duration_seconds=duration_seconds,
@ -390,9 +358,7 @@ class CheckBatchCost:
) )
# Track this job for the final metrics summary # Track this job for the final metrics summary
processed_models.append( processed_models.append((model_name, str(llm_provider) if llm_provider else None))
(model_name, str(llm_provider) if llm_provider else None)
)
# mark the job as complete # mark the job as complete
try: try:

View File

@ -33,7 +33,9 @@ class CheckResponsesCost:
self.prisma_client: PrismaClient = prisma_client self.prisma_client: PrismaClient = prisma_client
self.llm_router: Router = llm_router self.llm_router: Router = llm_router
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int: async def _expire_stale_rows(
self, cutoff: datetime, batch_size: int
) -> int:
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'. """Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
Isolated so it can be swapped / mocked in tests without touching the Isolated so it can be swapped / mocked in tests without touching the
@ -72,9 +74,7 @@ class CheckResponsesCost:
rows per invocation to avoid overwhelming the DB when there is a large rows per invocation to avoid overwhelming the DB when there is a large
backlog. backlog.
""" """
cutoff = datetime.now(timezone.utc) - timedelta( cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE) result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
if result > 0: if result > 0:
verbose_proxy_logger.warning( verbose_proxy_logger.warning(
@ -105,7 +105,7 @@ class CheckResponsesCost:
take=MAX_OBJECTS_PER_POLL_CYCLE, take=MAX_OBJECTS_PER_POLL_CYCLE,
order={"created_at": "asc"}, order={"created_at": "asc"},
) )
verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check") verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check")
completed_jobs = [] completed_jobs = []
@ -120,33 +120,29 @@ class CheckResponsesCost:
# Get the stored response object to extract model information # Get the stored response object to extract model information
stored_response = job.file_object stored_response = job.file_object
model_name = stored_response.get("model", None) model_name = stored_response.get("model", None)
# Decrypt the response ID # Decrypt the response ID
responses_id_security, _, _ = ( responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
)
# Prepare metadata with model information for cost tracking # Prepare metadata with model information for cost tracking
litellm_metadata = { litellm_metadata = {
"user_api_key_user_id": job.created_by or "default-user-id", "user_api_key_user_id": job.created_by or "default-user-id",
} }
# Add model information if available # Add model information if available
if model_name: if model_name:
litellm_metadata["model"] = model_name litellm_metadata["model"] = model_name
litellm_metadata["model_group"] = ( litellm_metadata["model_group"] = model_name # Use same value for model_group
model_name # Use same value for model_group
)
response = await litellm.aget_responses( response = await litellm.aget_responses(
response_id=responses_id_security, response_id=responses_id_security,
litellm_metadata=litellm_metadata, litellm_metadata=litellm_metadata,
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Response {unified_object_id} status: {response.status}, model: {model_name}" f"Response {unified_object_id} status: {response.status}, model: {model_name}"
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"Skipping job {unified_object_id} due to error: {e}" f"Skipping job {unified_object_id} due to error: {e}"
@ -159,7 +155,7 @@ class CheckResponsesCost:
f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses." f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses."
) )
completed_jobs.append(job) completed_jobs.append(job)
elif response.status in ["failed", "cancelled"]: elif response.status in ["failed", "cancelled"]:
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"Response {unified_object_id} has status {response.status}, marking as complete" f"Response {unified_object_id} has status {response.status}, marking as complete"
@ -175,3 +171,4 @@ class CheckResponsesCost:
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"Marked {len(completed_jobs)} response jobs as completed" f"Marked {len(completed_jobs)} response jobs as completed"
) )

View File

@ -41,7 +41,7 @@ class _PROXY_LiteLLMManagedVectorStores(
): ):
""" """
Managed vector stores with target_model_names support. Managed vector stores with target_model_names support.
This class provides functionality to: This class provides functionality to:
- Create vector stores across multiple models - Create vector stores across multiple models
- Retrieve vector stores by unified ID - Retrieve vector stores by unified ID
@ -77,14 +77,14 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> str: ) -> str:
""" """
Generate the format string for the unified vector store ID. Generate the format string for the unified vector store ID.
Format: Format:
litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id> litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id>
""" """
# VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary # VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary
# Extract provider resource ID from the response # Extract provider resource ID from the response
provider_resource_id = resource_object.get("id", "") provider_resource_id = resource_object.get("id", "")
# Model ID is stored in hidden params if the response object supports it # Model ID is stored in hidden params if the response object supports it
# For TypedDict responses, we need to check if _hidden_params was added # For TypedDict responses, we need to check if _hidden_params was added
hidden_params: Dict[str, Any] = {} hidden_params: Dict[str, Any] = {}
@ -109,18 +109,20 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> VectorStoreCreateResponse: ) -> VectorStoreCreateResponse:
""" """
Create a vector store for a specific model. Create a vector store for a specific model.
Args: Args:
llm_router: LiteLLM router instance llm_router: LiteLLM router instance
model: Model name to create vector store for model: Model name to create vector store for
request_data: Request data for vector store creation request_data: Request data for vector store creation
litellm_parent_otel_span: OpenTelemetry span for tracing litellm_parent_otel_span: OpenTelemetry span for tracing
Returns: Returns:
VectorStoreCreateResponse from the provider VectorStoreCreateResponse from the provider
""" """
# Use the router to create the vector store # Use the router to create the vector store
response = await llm_router.avector_store_create(model=model, **request_data) response = await llm_router.avector_store_create(
model=model, **request_data
)
return response return response
# ============================================================================ # ============================================================================
@ -137,14 +139,14 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> VectorStoreCreateResponse: ) -> VectorStoreCreateResponse:
""" """
Create a vector store across multiple models. Create a vector store across multiple models.
Args: Args:
create_request: Vector store creation request parameters create_request: Vector store creation request parameters
llm_router: LiteLLM router instance llm_router: LiteLLM router instance
target_model_names_list: List of target model names target_model_names_list: List of target model names
litellm_parent_otel_span: OpenTelemetry span for tracing litellm_parent_otel_span: OpenTelemetry span for tracing
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
Returns: Returns:
VectorStoreCreateResponse with unified ID VectorStoreCreateResponse with unified ID
""" """
@ -194,7 +196,7 @@ class _PROXY_LiteLLMManagedVectorStores(
# VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID # VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID
response = responses[0].copy() response = responses[0].copy()
response["id"] = unified_id response["id"] = unified_id
verbose_logger.info( verbose_logger.info(
f"Successfully created managed vector store with unified ID: {unified_id}" f"Successfully created managed vector store with unified ID: {unified_id}"
) )
@ -210,13 +212,13 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
List vector stores created by a user. List vector stores created by a user.
Args: Args:
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
limit: Maximum number of vector stores to return limit: Maximum number of vector stores to return
after: Cursor for pagination after: Cursor for pagination
order: Sort order ('asc' or 'desc') order: Sort order ('asc' or 'desc')
Returns: Returns:
Dictionary with list of vector stores and pagination info Dictionary with list of vector stores and pagination info
""" """
@ -236,23 +238,23 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> bool: ) -> bool:
""" """
Check if user has access to a vector store. Check if user has access to a vector store.
Args: Args:
vector_store_id: The unified vector store ID vector_store_id: The unified vector store ID
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
Returns: Returns:
True if user has access, False otherwise True if user has access, False otherwise
""" """
is_unified_id = is_base64_encoded_unified_id(vector_store_id) is_unified_id = is_base64_encoded_unified_id(vector_store_id)
if is_unified_id: if is_unified_id:
# Check access for managed vector store # Check access for managed vector store
return await self.can_user_access_unified_resource_id( return await self.can_user_access_unified_resource_id(
vector_store_id, vector_store_id,
user_api_key_dict, user_api_key_dict,
) )
# Not a managed vector store, allow access # Not a managed vector store, allow access
return True return True
@ -261,22 +263,24 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> bool: ) -> bool:
""" """
Check if user has access to a managed vector store in request data. Check if user has access to a managed vector store in request data.
Args: Args:
data: Request data containing vector_store_id data: Request data containing vector_store_id
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
Returns: Returns:
True if this is a managed vector store and user has access True if this is a managed vector store and user has access
Raises: Raises:
HTTPException: If user doesn't have access HTTPException: If user doesn't have access
""" """
vector_store_id = cast(Optional[str], data.get("vector_store_id")) vector_store_id = cast(Optional[str], data.get("vector_store_id"))
is_unified_id = ( is_unified_id = (
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False is_base64_encoded_unified_id(vector_store_id)
if vector_store_id
else False
) )
if is_unified_id and vector_store_id: if is_unified_id and vector_store_id:
if await self.can_user_access_unified_resource_id( if await self.can_user_access_unified_resource_id(
vector_store_id, user_api_key_dict vector_store_id, user_api_key_dict
@ -287,7 +291,7 @@ class _PROXY_LiteLLMManagedVectorStores(
status_code=403, status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
) )
return False return False
# ============================================================================ # ============================================================================
@ -303,18 +307,18 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Union[Exception, str, Dict, None]: ) -> Union[Exception, str, Dict, None]:
""" """
Pre-call hook to handle vector store operations. Pre-call hook to handle vector store operations.
This hook intercepts vector store requests and: This hook intercepts vector store requests and:
- Validates access for managed vector stores - Validates access for managed vector stores
- Transforms unified IDs to provider-specific IDs - Transforms unified IDs to provider-specific IDs
- Adds model routing information - Adds model routing information
Args: Args:
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
cache: Cache instance cache: Cache instance
data: Request data data: Request data
call_type: Type of call being made call_type: Type of call being made
Returns: Returns:
Modified request data or None Modified request data or None
""" """
@ -326,40 +330,40 @@ class _PROXY_LiteLLMManagedVectorStores(
# Handle vector store search operations # Handle vector store search operations
if call_type == "avector_store_search": if call_type == "avector_store_search":
vector_store_id = data.get("vector_store_id") vector_store_id = data.get("vector_store_id")
if vector_store_id: if vector_store_id:
# Check if it's a managed vector store ID # Check if it's a managed vector store ID
decoded_id = is_base64_encoded_unified_id(vector_store_id) decoded_id = is_base64_encoded_unified_id(vector_store_id)
if decoded_id: if decoded_id:
verbose_logger.debug( verbose_logger.debug(
f"Processing managed vector store search: {vector_store_id}" f"Processing managed vector store search: {vector_store_id}"
) )
# Check access # Check access
has_access = await self.can_user_access_unified_resource_id( has_access = await self.can_user_access_unified_resource_id(
vector_store_id, user_api_key_dict vector_store_id, user_api_key_dict
) )
if not has_access: if not has_access:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
) )
# Parse the unified ID to extract components # Parse the unified ID to extract components
parsed_id = parse_unified_id(vector_store_id) parsed_id = parse_unified_id(vector_store_id)
if parsed_id: if parsed_id:
# Extract the model ID and provider resource ID # Extract the model ID and provider resource ID
model_id = parsed_id.get("model_id") model_id = parsed_id.get("model_id")
provider_resource_id = parsed_id.get("provider_resource_id") provider_resource_id = parsed_id.get("provider_resource_id")
target_model_names = parsed_id.get("target_model_names", []) target_model_names = parsed_id.get("target_model_names", [])
verbose_logger.debug( verbose_logger.debug(
f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}" f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}"
) )
# Determine which model to use for routing # Determine which model to use for routing
# Priority: model_id (deployment ID) > first target_model_name # Priority: model_id (deployment ID) > first target_model_name
routing_model = None routing_model = None
@ -367,28 +371,28 @@ class _PROXY_LiteLLMManagedVectorStores(
routing_model = model_id routing_model = model_id
elif target_model_names and len(target_model_names) > 0: elif target_model_names and len(target_model_names) > 0:
routing_model = target_model_names[0] routing_model = target_model_names[0]
# Set the model for routing # Set the model for routing
if routing_model: if routing_model:
data["model"] = routing_model data["model"] = routing_model
verbose_logger.info( verbose_logger.info(
f"Routing vector store search to model: {routing_model}" f"Routing vector store search to model: {routing_model}"
) )
# Replace the unified ID with the provider-specific ID # Replace the unified ID with the provider-specific ID
if provider_resource_id: if provider_resource_id:
data["vector_store_id"] = provider_resource_id data["vector_store_id"] = provider_resource_id
verbose_logger.debug( verbose_logger.debug(
f"Replaced unified ID with provider resource ID: {provider_resource_id}" f"Replaced unified ID with provider resource ID: {provider_resource_id}"
) )
# Handle vector store retrieve/delete operations # Handle vector store retrieve/delete operations
elif call_type in ("avector_store_retrieve", "avector_store_delete"): elif call_type in ("avector_store_retrieve", "avector_store_delete"):
await self.check_managed_vector_store_access(data, user_api_key_dict) await self.check_managed_vector_store_access(data, user_api_key_dict)
# If it's a managed vector store, we'll handle it in the endpoint # If it's a managed vector store, we'll handle it in the endpoint
# No need to transform here as the endpoint will route to the hook # No need to transform here as the endpoint will route to the hook
return data return data
# ============================================================================ # ============================================================================
@ -403,15 +407,15 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> Any: ) -> Any:
""" """
Post-call hook to transform responses. Post-call hook to transform responses.
This hook can be used to transform responses if needed. This hook can be used to transform responses if needed.
For now, it just passes through the response. For now, it just passes through the response.
Args: Args:
data: Request data data: Request data
user_api_key_dict: User API key authentication details user_api_key_dict: User API key authentication details
response: Response from the provider response: Response from the provider
Returns: Returns:
Potentially modified response Potentially modified response
""" """
@ -432,21 +436,21 @@ class _PROXY_LiteLLMManagedVectorStores(
) -> List[Dict]: ) -> List[Dict]:
""" """
Filter deployments based on vector store availability. Filter deployments based on vector store availability.
This is used by the router to select only deployments that have This is used by the router to select only deployments that have
the vector store available. the vector store available.
Note: This method signature is a compromise between CustomLogger and BaseManagedResource Note: This method signature is a compromise between CustomLogger and BaseManagedResource
parent classes which have incompatible signatures. The type: ignore[override] is necessary parent classes which have incompatible signatures. The type: ignore[override] is necessary
due to this multiple inheritance conflict. due to this multiple inheritance conflict.
Args: Args:
model: Model name model: Model name
healthy_deployments: List of healthy deployments healthy_deployments: List of healthy deployments
messages: Messages (unused for vector stores, required by CustomLogger interface) messages: Messages (unused for vector stores, required by CustomLogger interface)
request_kwargs: Request kwargs containing vector_store_id and mappings request_kwargs: Request kwargs containing vector_store_id and mappings
parent_otel_span: OpenTelemetry span for tracing parent_otel_span: OpenTelemetry span for tracing
Returns: Returns:
Filtered list of deployments Filtered list of deployments
""" """

View File

@ -2,6 +2,7 @@
Enterprise internal user management endpoints Enterprise internal user management endpoints
""" """
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth

View File

@ -147,12 +147,12 @@ async def list_vector_stores(
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
prisma_client=prisma_client prisma_client=prisma_client
) )
# Also clean up in-memory registry to remove any deleted vector stores # Also clean up in-memory registry to remove any deleted vector stores
if litellm.vector_store_registry is not None: if litellm.vector_store_registry is not None:
db_vector_store_ids = { db_vector_store_ids = {
vs.get("vector_store_id") vs.get("vector_store_id")
for vs in vector_stores_from_db for vs in vector_stores_from_db
if vs.get("vector_store_id") if vs.get("vector_store_id")
} }
# Remove any in-memory vector stores that no longer exist in database # Remove any in-memory vector stores that no longer exist in database

View File

@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
soft_budget_crossed = "Soft Budget Crossed" soft_budget_crossed = "Soft Budget Crossed"
max_budget_alert = "Max Budget Alert" max_budget_alert = "Max Budget Alert"
class EmailEventSettings(BaseModel): class EmailEventSettings(BaseModel):
event: EmailEvent event: EmailEvent
enabled: bool enabled: bool
class EmailEventSettingsUpdateRequest(BaseModel): class EmailEventSettingsUpdateRequest(BaseModel):
settings: List[EmailEventSettings] settings: List[EmailEventSettings]
class EmailEventSettingsResponse(BaseModel): class EmailEventSettingsResponse(BaseModel):
settings: List[EmailEventSettings] settings: List[EmailEventSettings]
class DefaultEmailSettings(BaseModel): class DefaultEmailSettings(BaseModel):
"""Default settings for email events""" """Default settings for email events"""
settings: Dict[EmailEvent, bool] = Field( settings: Dict[EmailEvent, bool] = Field(
default_factory=lambda: { default_factory=lambda: {
EmailEvent.virtual_key_created: True, # On by default EmailEvent.virtual_key_created: True, # On by default
@ -65,12 +57,10 @@ class DefaultEmailSettings(BaseModel):
EmailEvent.max_budget_alert: True, # On by default EmailEvent.max_budget_alert: True, # On by default
} }
) )
def to_dict(self) -> Dict[str, bool]: def to_dict(self) -> Dict[str, bool]:
"""Convert to dictionary with string keys for storage""" """Convert to dictionary with string keys for storage"""
return {event.value: enabled for event, enabled in self.settings.items()} return {event.value: enabled for event, enabled in self.settings.items()}
@classmethod @classmethod
def get_defaults(cls) -> Dict[str, bool]: def get_defaults(cls) -> Dict[str, bool]:
"""Get the default settings as a dictionary with string keys""" """Get the default settings as a dictionary with string keys"""
return cls().to_dict() return cls().to_dict()

View File

@ -6,6 +6,7 @@ Always uses fastuuid for performance.
import fastuuid as _uuid # type: ignore import fastuuid as _uuid # type: ignore
# Expose a module-like alias so callers can use: uuid.uuid4() # Expose a module-like alias so callers can use: uuid.uuid4()
uuid = _uuid uuid = _uuid

View File

@ -9,6 +9,7 @@ from typing import Dict, Optional
from .exceptions import AnthropicErrorResponse, AnthropicErrorType from .exceptions import AnthropicErrorResponse, AnthropicErrorType
# HTTP status code -> Anthropic error type # HTTP status code -> Anthropic error type
# Source: https://docs.anthropic.com/en/api/errors # Source: https://docs.anthropic.com/en/api/errors
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = { ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {

View File

@ -2,6 +2,7 @@
from typing_extensions import Literal, Required, TypedDict from typing_extensions import Literal, Required, TypedDict
# Known Anthropic error types # Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors # Source: https://docs.anthropic.com/en/api/errors
AnthropicErrorType = Literal[ AnthropicErrorType = Literal[

View File

@ -97,7 +97,7 @@ def _build_reasoning_item(
def _reasoning_item_to_response_input( def _reasoning_item_to_response_input(
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]], r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item.""" """Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
r_input: Dict[str, Any] = { r_input: Dict[str, Any] = {

View File

@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
import json import json
import re import re
_CODE_KEYWORDS = re.compile( _CODE_KEYWORDS = re.compile(
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b" r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
) )

View File

@ -1,5 +1,6 @@
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
FileContentProvider = Literal[ FileContentProvider = Literal[
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus" "openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
] ]

View File

@ -1,10 +1,10 @@
""" """
Google GenAI Adapters for LiteLLM Google GenAI Adapters for LiteLLM
This module provides adapters for transforming Google GenAI generate_content requests This module provides adapters for transforming Google GenAI generate_content requests
to/from LiteLLM completion format with full support for: to/from LiteLLM completion format with full support for:
- Text content transformation - Text content transformation
- Tool calling (function declarations, function calls, function responses) - Tool calling (function declarations, function calls, function responses)
- Streaming (both regular and tool calling) - Streaming (both regular and tool calling)
- Mixed content (text + tool calls) - Mixed content (text + tool calls)
""" """

View File

@ -1,9 +1,9 @@
""" """
Handles Batching + sending Httpx Post requests to slack Handles Batching + sending Httpx Post requests to slack
Slack alerts are sent every 10s or when events are greater than X events Slack alerts are sent every 10s or when events are greater than X events
see custom_batch_logger.py for more details / defaults see custom_batch_logger.py for more details / defaults
""" """
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any

View File

@ -18,7 +18,7 @@ else:
def process_slack_alerting_variables( def process_slack_alerting_variables(
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]], alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]: ) -> Optional[Dict[AlertType, Union[List[str], str]]]:
""" """
process alert_to_webhook_url process alert_to_webhook_url

View File

@ -1,5 +1,5 @@
""" """
Base class for Additional Logging Utils for CustomLoggers Base class for Additional Logging Utils for CustomLoggers
- Health Check for the logging util - Health Check for the logging util
- Get Request / Response Payload for the logging util - Get Request / Response Payload for the logging util

View File

@ -1,5 +1,5 @@
""" """
Custom Logger that handles batching logic Custom Logger that handles batching logic
Use this if you want your logs to be stored in memory and flushed periodically. Use this if you want your logs to be stored in memory and flushed periodically.
""" """

View File

@ -9,6 +9,7 @@ import polars as pl
from .schema import FOCUS_NORMALIZED_SCHEMA from .schema import FOCUS_NORMALIZED_SCHEMA
_TAG_KEYS = ( _TAG_KEYS = (
"team_id", "team_id",
"team_alias", "team_alias",

View File

@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
def get_traces_and_spans_from_payload( def get_traces_and_spans_from_payload(
payload: List[Dict[str, Any]], payload: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
""" """
Separate traces and spans from payload. Separate traces and spans from payload.

View File

@ -1,8 +1,8 @@
""" """
s3 Bucket Logging Integration s3 Bucket Logging Integration
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
""" """

View File

@ -5,28 +5,28 @@ This module provides SDK methods for Google's Interactions API.
Usage: Usage:
import litellm import litellm
# Create an interaction with a model # Create an interaction with a model
response = litellm.interactions.create( response = litellm.interactions.create(
model="gemini-2.5-flash", model="gemini-2.5-flash",
input="Hello, how are you?" input="Hello, how are you?"
) )
# Create an interaction with an agent # Create an interaction with an agent
response = litellm.interactions.create( response = litellm.interactions.create(
agent="deep-research-pro-preview-12-2025", agent="deep-research-pro-preview-12-2025",
input="Research the current state of cancer research" input="Research the current state of cancer research"
) )
# Async version # Async version
response = await litellm.interactions.acreate(...) response = await litellm.interactions.acreate(...)
# Get an interaction # Get an interaction
response = litellm.interactions.get(interaction_id="...") response = litellm.interactions.get(interaction_id="...")
# Delete an interaction # Delete an interaction
result = litellm.interactions.delete(interaction_id="...") result = litellm.interactions.delete(interaction_id="...")
# Cancel an interaction # Cancel an interaction
result = litellm.interactions.cancel(interaction_id="...") result = litellm.interactions.cancel(interaction_id="...")

View File

@ -8,25 +8,25 @@ Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
Usage: Usage:
import litellm import litellm
# Create an interaction with a model # Create an interaction with a model
response = litellm.interactions.create( response = litellm.interactions.create(
model="gemini-2.5-flash", model="gemini-2.5-flash",
input="Hello, how are you?" input="Hello, how are you?"
) )
# Create an interaction with an agent # Create an interaction with an agent
response = litellm.interactions.create( response = litellm.interactions.create(
agent="deep-research-pro-preview-12-2025", agent="deep-research-pro-preview-12-2025",
input="Research the current state of cancer research" input="Research the current state of cancer research"
) )
# Async version # Async version
response = await litellm.interactions.acreate(...) response = await litellm.interactions.acreate(...)
# Get an interaction # Get an interaction
response = litellm.interactions.get(interaction_id="...") response = litellm.interactions.get(interaction_id="...")
# Delete an interaction # Delete an interaction
result = litellm.interactions.delete(interaction_id="...") result = litellm.interactions.delete(interaction_id="...")
""" """

View File

@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
try: try:
# [Non-blocking Extra Debug Information in metadata] # [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True: if turn_off_message_logging is True:
_metadata["raw_request"] = "redacted by litellm. \ _metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'" 'litellm.turn_off_message_logging=True'"
)
else: else:
curl_command = self._get_request_curl_command( curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""), api_base=additional_args.get("api_base", ""),
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
error=str(e), error=str(e),
) )
) )
_metadata["raw_request"] = "Unable to Log \ _metadata["raw_request"] = (
raw request: {}".format(str(e)) "Unable to Log \
raw request: {}".format(
str(e)
)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn): if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try: try:
self.logger_fn( self.logger_fn(

View File

@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
prompt_str = """Use this JSON schema: prompt_str = """Use this JSON schema:
```json ```json
{} {}
```""".format(response_schema) ```""".format(
response_schema
)
return prompt_str return prompt_str

View File

@ -1,9 +1,9 @@
""" """
This is a cache for LangfuseLoggers. This is a cache for LangfuseLoggers.
Langfuse Python SDK initializes a thread for each client. Langfuse Python SDK initializes a thread for each client.
This ensures we do This ensures we do
1. Proper cleanup of Langfuse initialized clients. 1. Proper cleanup of Langfuse initialized clients.
2. Re-use created langfuse clients. 2. Re-use created langfuse clients.
""" """

View File

@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# SSE parsing helpers (module-level to keep the class lean) # SSE parsing helpers (module-level to keep the class lean)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -4,10 +4,10 @@ Support for o1 and o3 model families
https://platform.openai.com/docs/guides/reasoning https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM: Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param) - modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user' - role: system ==> translate to role 'user'
- streaming => faked by LiteLLM - streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param) - Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param) - Logprobs => drop param (if user opts in to dropping param)
- Temperature => drop param (if user opts in to dropping param) - Temperature => drop param (if user opts in to dropping param)
""" """

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
""" """
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
""" """
from typing import Optional from typing import Optional

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """

View File

@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
from ...openai_like.chat.transformation import OpenAILikeChatConfig from ...openai_like.chat.transformation import OpenAILikeChatConfig
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1" BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"

View File

@ -1,5 +1,5 @@
""" """
Legacy /v1/embedding handler for Bedrock Cohere. Legacy /v1/embedding handler for Bedrock Cohere.
""" """
import json import json

View File

@ -13,6 +13,7 @@ from typing import Tuple
import httpx import httpx
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Pre-built response templates # Pre-built response templates
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -1,5 +1,5 @@
""" """
Cost calculator for Dashscope Chat models. Cost calculator for Dashscope Chat models.
Handles tiered pricing and prompt caching scenarios. Handles tiered pricing and prompt caching scenarios.
""" """

View File

@ -1,5 +1,5 @@
""" """
Support for OpenAI's `/v1/chat/completions` endpoint. Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as DataRobot is openai-compatible. Calls done in OpenAI/openai.py as DataRobot is openai-compatible.
""" """

View File

@ -1,5 +1,5 @@
""" """
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format. Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
""" """
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union

View File

@ -1,5 +1,5 @@
""" """
Cost calculator for DeepSeek Chat models. Cost calculator for DeepSeek Chat models.
Handles prompt caching scenario. Handles prompt caching scenario.
""" """

View File

@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
from ..common_utils import ElevenLabsException from ..common_utils import ElevenLabsException
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.llms.openai import HttpxBinaryResponseContent

View File

@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
def _usage_video_resolution_from_parameters( def _usage_video_resolution_from_parameters(
parameters: Dict[str, Any], parameters: Dict[str, Any]
) -> Optional[str]: ) -> Optional[str]:
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking.""" """Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
res = parameters.get("resolution") res = parameters.get("resolution")

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format. Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format. Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works

View File

@ -1,5 +1,5 @@
""" """
Support for OpenAI's `/v1/chat/completions` endpoint. Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as Novita AI is openai-compatible. Calls done in OpenAI/openai.py as Novita AI is openai-compatible.

View File

@ -1,7 +1,7 @@
""" """
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
This is OpenAI compatible This is OpenAI compatible
This file only contains param mapping logic This file only contains param mapping logic

View File

@ -1,7 +1,7 @@
""" """
Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer
This is OpenAI compatible This is OpenAI compatible
This file only contains param mapping logic This file only contains param mapping logic

View File

@ -1,14 +1,14 @@
""" """
Support for o1/o3 model family Support for o1/o3 model family
https://platform.openai.com/docs/guides/reasoning https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM: Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param) - modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user' - role: system ==> translate to role 'user'
- streaming => faked by LiteLLM - streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param) - Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param) - Logprobs => drop param (if user opts in to dropping param)
""" """
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload

View File

@ -201,7 +201,7 @@ class BaseOpenAILLM:
@staticmethod @staticmethod
def get_openai_client_initialization_param_fields( def get_openai_client_initialization_param_fields(
client_type: Literal["openai", "azure"], client_type: Literal["openai", "azure"]
) -> Tuple[str, ...]: ) -> Tuple[str, ...]:
"""Returns a tuple of fields that are used to initialize the OpenAI client""" """Returns a tuple of fields that are used to initialize the OpenAI client"""
if client_type == "openai": if client_type == "openai":

View File

@ -49,6 +49,7 @@ from litellm.types.utils import (
) )
from litellm.llms.openrouter.common_utils import OpenRouterException from litellm.llms.openrouter.common_utils import OpenRouterException
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else: else:

View File

@ -1,7 +1,7 @@
""" """
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke` Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
In the Huggingface TGI format. In the Huggingface TGI format.
""" """
import json import json

View File

@ -1,7 +1,7 @@
""" """
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke` Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
In the Huggingface TGI format. In the Huggingface TGI format.
""" """
from typing import TYPE_CHECKING, Any, List, Optional, Union from typing import TYPE_CHECKING, Any, List, Optional, Union

View File

@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
def _parse_service_key_once( def _parse_service_key_once(
service_key: Optional[Union[str, dict]], service_key: Optional[Union[str, dict]]
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Pre-parse service_key if it's a string to avoid repeated JSON parsing. Pre-parse service_key if it's a string to avoid repeated JSON parsing.

View File

@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
from ..utils import SnowflakeBaseConfig from ..utils import SnowflakeBaseConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

View File

@ -1,5 +1,5 @@
""" """
Support for OpenAI's `/v1/chat/completions` endpoint. Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.

View File

@ -1,5 +1,5 @@
""" """
Support for OpenAI's `/v1/embeddings` endpoint. Support for OpenAI's `/v1/embeddings` endpoint.
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """

View File

@ -1,5 +1,5 @@
""" """
Transformation logic for context caching. Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """
@ -19,7 +19,7 @@ from ..gemini.transformation import (
def get_first_continuous_block_idx( def get_first_continuous_block_idx(
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message) filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
) -> int: ) -> int:
""" """
Find the array index that ends the first continuous sequence of message blocks. Find the array index that ends the first continuous sequence of message blocks.

View File

@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
contents.append(ContentType(role="user", parts=tool_call_responses)) contents.append(ContentType(role="user", parts=tool_call_responses))
if len(contents) == 0: if len(contents) == 0:
verbose_logger.warning(""" verbose_logger.warning(
"""
No contents in messages. Contents are required. See No contents in messages. Contents are required. See
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body. https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
If the original request did not comply to OpenAI API requirements it should have failed by now, If the original request did not comply to OpenAI API requirements it should have failed by now,
but LiteLLM does not check for missing messages. but LiteLLM does not check for missing messages.
Setting an empty content to prevent an 400 error. Setting an empty content to prevent an 400 error.
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733 Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
""") """
)
contents.append(ContentType(role="user", parts=[PartType(text=" ")])) contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
return contents return contents
except Exception as e: except Exception as e:

View File

@ -1,5 +1,5 @@
""" """
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """

View File

@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############ ########## End of logging ############
####### Send the request ################### ####### Send the request ###################
if _is_async is True: if _is_async is True:
return self.async_audio_speech( # type: ignore return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request logging_obj=logging_obj, url=url, headers=headers, request=request
) )
sync_handler = _get_httpx_client() sync_handler = _get_httpx_client()

View File

@ -1,5 +1,5 @@
""" """
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead. NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
""" """

View File

@ -1,6 +1,6 @@
""" """
This module is used to transform the request and response for the Voyage contextualized embeddings API. This module is used to transform the request and response for the Voyage contextualized embeddings API.
This would be used for all the contextualized embeddings models in Voyage. This would be used for all the contextualized embeddings models in Voyage.
""" """
from typing import List, Optional, Union from typing import List, Optional, Union

View File

@ -324,7 +324,7 @@ class CustomOpenAPISpec:
@staticmethod @staticmethod
def add_chat_completion_request_schema( def add_chat_completion_request_schema(
openapi_schema: Dict[str, Any], openapi_schema: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation. Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
@staticmethod @staticmethod
def add_responses_api_request_schema( def add_responses_api_request_schema(
openapi_schema: Dict[str, Any], openapi_schema: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation. Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
@staticmethod @staticmethod
def add_llm_api_request_schema_body( def add_llm_api_request_schema_body(
openapi_schema: Dict[str, Any], openapi_schema: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Add LLM API request schema bodies to OpenAPI specification for documentation. Add LLM API request schema bodies to OpenAPI specification for documentation.

View File

@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
async def convert_upload_files_to_file_data( async def convert_upload_files_to_file_data(
form_data: Dict[str, Any], form_data: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Convert FastAPI UploadFile objects to file data tuples for litellm. Convert FastAPI UploadFile objects to file data tuples for litellm.

View File

@ -1,5 +1,5 @@
""" """
Contains utils used by OpenAI compatible endpoints Contains utils used by OpenAI compatible endpoints
""" """
from typing import Optional, Set from typing import Optional, Set

View File

@ -1,5 +1,5 @@
""" """
What is this? What is this?
CRUD endpoints for managing pass-through endpoints CRUD endpoints for managing pass-through endpoints
""" """

View File

@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise raise
# If an error occurs, the view does not exist, so create it # If an error occurs, the view does not exist, so create it
await db.execute_raw(""" await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT SELECT
v.*, v.*,
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
FROM "LiteLLM_VerificationToken" v FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id; LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
""") """
)
verbose_logger.debug("LiteLLM_VerificationTokenView Created!") verbose_logger.debug("LiteLLM_VerificationTokenView Created!")

View File

@ -10,6 +10,7 @@ every text fragment.
from typing import Any, Callable, Dict, FrozenSet, Iterator, List from typing import Any, Callable, Dict, FrozenSet, Iterator, List
# Call types whose body carries free-form chat / prompt text that # Call types whose body carries free-form chat / prompt text that
# text-content guardrails (banned keywords, content moderation, secret # text-content guardrails (banned keywords, content moderation, secret
# detection, …) should inspect. The proxy ingress passes ``route_type`` # detection, …) should inspect. The proxy ingress passes ``route_type``

View File

@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
from .akto import AktoGuardrail from .akto import AktoGuardrail
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams from litellm.types.guardrails import Guardrail, LitellmParams

View File

@ -6,7 +6,7 @@ The actual skill logic is in litellm/llms/litellm_proxy/skills/.
Usage: Usage:
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
# Register hook in proxy # Register hook in proxy
litellm.callbacks.append(SkillsInjectionHook()) litellm.callbacks.append(SkillsInjectionHook())
""" """

View File

@ -1,9 +1,9 @@
""" """
BUDGET MANAGEMENT BUDGET MANAGEMENT
All /budget management endpoints All /budget management endpoints
/budget/new /budget/new
/budget/info /budget/info
/budget/update /budget/update
/budget/delete /budget/delete

View File

@ -1,9 +1,9 @@
""" """
CUSTOMER MANAGEMENT CUSTOMER MANAGEMENT
All /customer management endpoints All /customer management endpoints
/customer/new /customer/new
/customer/info /customer/info
/customer/update /customer/update
/customer/delete /customer/delete

View File

@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
""" """
def _get_team_public_model_name( def _get_team_public_model_name(
model_info: Optional[Union[dict, str]], model_info: Optional[Union[dict, str]]
) -> Optional[str]: ) -> Optional[str]:
if isinstance(model_info, dict): if isinstance(model_info, dict):
value = model_info.get("team_public_model_name") value = model_info.get("team_public_model_name")

View File

@ -7,7 +7,7 @@ variables.
Environment Variables: Environment Variables:
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL - MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL - MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL - MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
If these are not set, the default Microsoft endpoints are used. If these are not set, the default Microsoft endpoints are used.

View File

@ -4347,7 +4347,9 @@ async def list_team(
except Exception as e: except Exception as e:
team_exception = """Invalid team object for team_id: {}. team_object={}. team_exception = """Invalid team object for team_id: {}. team_object={}.
Error: {} Error: {}
""".format(team.team_id, team.model_dump(), str(e)) """.format(
team.team_id, team.model_dump(), str(e)
)
verbose_proxy_logger.exception(team_exception) verbose_proxy_logger.exception(team_exception)
continue continue
# Sort the responses by team_alias # Sort the responses by team_alias

View File

@ -3,7 +3,7 @@ User Agent Analytics Endpoints
This module provides optimized endpoints for tracking user agent activity metrics including: This module provides optimized endpoints for tracking user agent activity metrics including:
- Daily Active Users (DAU) by tags for configurable number of days - Daily Active Users (DAU) by tags for configurable number of days
- Weekly Active Users (WAU) by tags for configurable number of weeks - Weekly Active Users (WAU) by tags for configurable number of weeks
- Monthly Active Users (MAU) by tags for configurable number of months - Monthly Active Users (MAU) by tags for configurable number of months
- Summary analytics by tags - Summary analytics by tags

View File

@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import StandardPassThroughResponseObject from litellm.types.utils import StandardPassThroughResponseObject
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = { CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
"POST /v0/agents": "cursor:agent:create", "POST /v0/agents": "cursor:agent:create",
"GET /v0/agents": "cursor:agent:list", "GET /v0/agents": "cursor:agent:list",

View File

@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
_endpoint_str = ( _endpoint_str = (
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\" f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
) )
curl_command = _endpoint_str + """ curl_command = (
_endpoint_str
+ """
--header 'Content-Type: application/json' \\ --header 'Content-Type: application/json' \\
--data ' { --data ' {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
}' }'
\n \n
""" """
)
print() # noqa print() # noqa
print( # noqa print( # noqa
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n' '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
with open(os.devnull, "w") as devnull: with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull) subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e: except Exception as e:
print(f""" print( # noqa
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa # noqa """
) # noqa
@staticmethod @staticmethod
def _is_port_in_use(port): def _is_port_in_use(port):

View File

@ -2688,9 +2688,11 @@ def run_ollama_serve():
with open(os.devnull, "w") as devnull: with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull) subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug(f""" verbose_proxy_logger.debug(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") """
)
def _get_process_rss_mb() -> Optional[float]: def _get_process_rss_mb() -> Optional[float]:

View File

@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
async def get_spend_by_tags( async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None prisma_client: PrismaClient, start_date=None, end_date=None
): ):
response = await prisma_client.db.query_raw(""" response = await prisma_client.db.query_raw(
"""
SELECT SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag, jsonb_array_elements_text(request_tags) AS individual_request_tag,
COUNT(*) AS log_count, COUNT(*) AS log_count,
SUM(spend) AS total_spend SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" FROM "LiteLLM_SpendLogs"
GROUP BY individual_request_tag; GROUP BY individual_request_tag;
""") """
)
return response return response

View File

@ -2712,7 +2712,8 @@ class PrismaClient:
required_view = "LiteLLM_VerificationTokenView" required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views) expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
pg_schema = os.getenv("DATABASE_SCHEMA", "public") pg_schema = os.getenv("DATABASE_SCHEMA", "public")
ret = await self.db.query_raw(f""" ret = await self.db.query_raw(
f"""
WITH existing_views AS ( WITH existing_views AS (
SELECT viewname SELECT viewname
FROM pg_views FROM pg_views
@ -2724,7 +2725,8 @@ class PrismaClient:
(SELECT COUNT(*) FROM existing_views) AS view_count, (SELECT COUNT(*) FROM existing_views) AS view_count,
ARRAY_AGG(viewname) AS view_names ARRAY_AGG(viewname) AS view_names
FROM existing_views FROM existing_views
""") """
)
expected_total_views = len(expected_views) expected_total_views = len(expected_views)
if ret[0]["view_count"] == expected_total_views: if ret[0]["view_count"] == expected_total_views:
verbose_proxy_logger.info("All necessary views exist!") verbose_proxy_logger.info("All necessary views exist!")
@ -2733,7 +2735,8 @@ class PrismaClient:
## check if required view exists ## ## check if required view exists ##
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]: if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db await self.health_check() # make sure we can connect to db
await self.db.execute_raw(""" await self.db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT SELECT
v.*, v.*,
@ -2743,7 +2746,8 @@ class PrismaClient:
t.rpm_limit AS team_rpm_limit t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""") """
)
verbose_proxy_logger.info( verbose_proxy_logger.info(
"LiteLLM_VerificationTokenView Created in DB!" "LiteLLM_VerificationTokenView Created in DB!"

View File

@ -1,5 +1,5 @@
""" """
What is this? What is this?
Logging Pass-Through Endpoints Logging Pass-Through Endpoints
""" """

View File

@ -826,7 +826,7 @@ class Router:
@staticmethod @staticmethod
def _normalize_strategy( def _normalize_strategy(
strategy: Union[RoutingStrategy, str, None], strategy: Union[RoutingStrategy, str, None]
) -> Optional[str]: ) -> Optional[str]:
if strategy is None: if strategy is None:
return None return None

View File

@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
def _recent_tool_results( def _recent_tool_results(
messages: Optional[List[Dict[str, Any]]], messages: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Extract the current turn's tool result payloads from the request messages. """Extract the current turn's tool result payloads from the request messages.

Some files were not shown because too many files have changed in this diff Show More