Drop dep bumps + black-26 reformat to clear fork CI policy

The PR was blocked by .github/workflows/guard-fork-dependencies.yml, because
fork PRs cannot modify uv.lock. This commit reverts:

- The black bump (24.10.0 -> 26.3.1) in uv.lock and pyproject.toml, together
  with the mechanical Black 26 reformat of 295 files that was coupled to it
- pyproject.toml diskcache extra change (kept the runtime mitigation in
  litellm/caching/disk_cache.py via JSONDisk)

Kept:
- Dockerfile cache narrowing (drops ~660 MB of uv build cache that
  surfaced cached setuptools as CVE findings)
- litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872
- ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json:
  next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard)
- tests/test_litellm/caching/test_disk_cache.py
- tests/code_coverage_tests/liccheck.ini: harmless black authorization

The black, gitpython, and langchain dependency upgrades will need a follow-up
from a maintainer pushing a branch in the canonical BerriAI/litellm repo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
user 2026-05-07 23:04:52 +00:00
parent 63bda3f001
commit 5bafa8b3a2
No known key found for this signature in database
292 changed files with 1814 additions and 2305 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import aiohttp
import json
# Asynchronously fetch data from a given URL
async def fetch_data(url):
try:
@ -16,24 +15,22 @@ async def fetch_data(url):
resp_json = await resp.json()
print("Fetch the data from URL.")
# Return the 'data' field from the JSON response
return resp_json["data"]
return resp_json['data']
except Exception as e:
# Print an error message if fetching data fails
print("Error fetching data from URL:", e)
return None
# Synchronize local data with remote data
def sync_local_data_with_remote(local_data, remote_data):
# Update existing keys in local_data with values from remote_data
for key in set(local_data) & set(remote_data):
for key in (set(local_data) & set(remote_data)):
local_data[key].update(remote_data[key])
# Add new keys from remote_data to local_data
for key in set(remote_data) - set(local_data):
for key in (set(remote_data) - set(local_data)):
local_data[key] = remote_data[key]
# Write data to the json file
def write_to_file(file_path, data):
try:
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
# Print an error message if writing to file fails
print("Error updating JSON file:", e)
# Update the existing models and add the missing models for OpenRouter
def transform_openrouter_data(data):
transformed = {}
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
}
# Add 'max_output_tokens' as a field if it is not None
if (
"top_provider" in row
and "max_completion_tokens" in row["top_provider"]
and row["top_provider"]["max_completion_tokens"] is not None
):
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
# Add the field 'output_cost_per_token'
obj.update(
{
"output_cost_per_token": float(row["pricing"]["completion"]),
}
)
obj.update({
"output_cost_per_token": float(row["pricing"]["completion"]),
})
# Add field 'input_cost_per_image' if it exists and is non-zero
if (
"pricing" in row
and "image" in row["pricing"]
and float(row["pricing"]["image"]) != 0.0
):
obj["input_cost_per_image"] = float(row["pricing"]["image"])
if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
obj['input_cost_per_image'] = float(row["pricing"]["image"])
# Add the fields 'litellm_provider' and 'mode'
obj.update({"litellm_provider": "openrouter", "mode": "chat"})
obj.update({
"litellm_provider": "openrouter",
"mode": "chat"
})
# Add the 'supports_vision' field if the modality is 'multimodal'
if row.get("architecture", {}).get("modality") == "multimodal":
obj["supports_vision"] = True
if row.get('architecture', {}).get('modality') == 'multimodal':
obj['supports_vision'] = True
# Use a composite key to store the transformed object
transformed[f'openrouter/{row["id"]}'] = obj
return transformed
# Update the existing models and add the missing models for Vercel AI Gateway
def transform_vercel_ai_gateway_data(data):
transformed = {}
@ -101,27 +89,17 @@ def transform_vercel_ai_gateway_data(data):
"max_tokens": row["context_window"],
"input_cost_per_token": float(row["pricing"]["input"]),
"output_cost_per_token": float(row["pricing"]["output"]),
"max_output_tokens": row["max_tokens"],
"max_input_tokens": row["context_window"],
'max_output_tokens': row['max_tokens'],
'max_input_tokens': row["context_window"],
}
# Handle cache pricing if available
if "pricing" in row:
if (
"input_cache_read" in row["pricing"]
and row["pricing"]["input_cache_read"] is not None
):
obj["cache_read_input_token_cost"] = float(
f"{float(row['pricing']['input_cache_read']):e}"
)
if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
if (
"input_cache_write" in row["pricing"]
and row["pricing"]["input_cache_write"] is not None
):
obj["cache_creation_input_token_cost"] = float(
f"{float(row['pricing']['input_cache_write']):e}"
)
if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
@ -148,17 +126,10 @@ def load_local_data(file_path):
print("Error decoding JSON:", e)
return None
def main():
local_file_path = (
"model_prices_and_context_window.json" # Path to the local data file
)
openrouter_url = (
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
)
vercel_ai_gateway_url = (
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
)
local_file_path = "model_prices_and_context_window.json" # Path to the local data file
openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
# Load local data from file
local_data = load_local_data(local_file_path)
@ -183,7 +154,6 @@ def main():
else:
print("Failed to fetch model data from either local file or URL.")
# Entry point of the script
if __name__ == "__main__":
main()

View File

@ -16,55 +16,52 @@ from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional
# ANSI color codes for terminal output
class Colors:
GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
PURPLE = "\033[95m"
CYAN = "\033[96m"
RESET = "\033[0m"
BOLD = "\033[1m"
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
PURPLE = '\033[95m'
CYAN = '\033[96m'
RESET = '\033[0m'
BOLD = '\033[1m'
def print_colored(message: str, color: str = Colors.RESET):
"""Print colored message to terminal"""
print(f"{color}{message}{Colors.RESET}")
def get_provider_from_test_file(test_file: str) -> str:
"""Map test file names to provider names"""
provider_mapping = {
"test_anthropic": "Anthropic",
"test_azure": "Azure",
"test_bedrock": "AWS Bedrock",
"test_openai": "OpenAI",
"test_vertex": "Google Vertex AI",
"test_gemini": "Google Vertex AI",
"test_cohere": "Cohere",
"test_databricks": "Databricks",
"test_groq": "Groq",
"test_together": "Together AI",
"test_mistral": "Mistral",
"test_deepseek": "DeepSeek",
"test_replicate": "Replicate",
"test_huggingface": "HuggingFace",
"test_fireworks": "Fireworks AI",
"test_perplexity": "Perplexity",
"test_cloudflare": "Cloudflare",
"test_voyage": "Voyage AI",
"test_xai": "xAI",
"test_nvidia": "NVIDIA",
"test_watsonx": "IBM watsonx",
"test_azure_ai": "Azure AI",
"test_snowflake": "Snowflake",
"test_infinity": "Infinity",
"test_jina": "Jina AI",
"test_deepgram": "Deepgram",
"test_clarifai": "Clarifai",
"test_triton": "Triton",
'test_anthropic': 'Anthropic',
'test_azure': 'Azure',
'test_bedrock': 'AWS Bedrock',
'test_openai': 'OpenAI',
'test_vertex': 'Google Vertex AI',
'test_gemini': 'Google Vertex AI',
'test_cohere': 'Cohere',
'test_databricks': 'Databricks',
'test_groq': 'Groq',
'test_together': 'Together AI',
'test_mistral': 'Mistral',
'test_deepseek': 'DeepSeek',
'test_replicate': 'Replicate',
'test_huggingface': 'HuggingFace',
'test_fireworks': 'Fireworks AI',
'test_perplexity': 'Perplexity',
'test_cloudflare': 'Cloudflare',
'test_voyage': 'Voyage AI',
'test_xai': 'xAI',
'test_nvidia': 'NVIDIA',
'test_watsonx': 'IBM watsonx',
'test_azure_ai': 'Azure AI',
'test_snowflake': 'Snowflake',
'test_infinity': 'Infinity',
'test_jina': 'Jina AI',
'test_deepgram': 'Deepgram',
'test_clarifai': 'Clarifai',
'test_triton': 'Triton',
}
for key, provider in provider_mapping.items():
@ -72,19 +69,11 @@ def get_provider_from_test_file(test_file: str) -> str:
return provider
# For cross-provider test files
if any(
name in test_file
for name in [
"test_optional_params",
"test_prompt_factory",
"test_router",
"test_text_completion",
]
):
return f"Cross-Provider Tests ({test_file})"
return "Other Tests"
if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory',
'test_router', 'test_text_completion']):
return f'Cross-Provider Tests ({test_file})'
return 'Other Tests'
def format_duration(seconds: float) -> str:
"""Format duration in human-readable format"""
@ -100,17 +89,15 @@ def format_duration(seconds: float) -> str:
return f"{hours}h {minutes}m"
def generate_markdown_report(
junit_xml_path: str, output_path: str, tag: str = None, commit: str = None
):
def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None):
"""Generate a beautiful markdown report from JUnit XML"""
try:
tree = ET.parse(junit_xml_path)
root = tree.getroot()
# Handle both testsuite and testsuites root
if root.tag == "testsuites":
suites = root.findall("testsuite")
if root.tag == 'testsuites':
suites = root.findall('testsuite')
else:
suites = [root]
@ -122,87 +109,75 @@ def generate_markdown_report(
total_time = 0.0
# Provider breakdown
provider_stats = defaultdict(
lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
)
provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0})
provider_tests = defaultdict(list)
for suite in suites:
total_tests += int(suite.get("tests", 0))
total_failures += int(suite.get("failures", 0))
total_errors += int(suite.get("errors", 0))
total_skipped += int(suite.get("skipped", 0))
total_time += float(suite.get("time", 0))
total_tests += int(suite.get('tests', 0))
total_failures += int(suite.get('failures', 0))
total_errors += int(suite.get('errors', 0))
total_skipped += int(suite.get('skipped', 0))
total_time += float(suite.get('time', 0))
for testcase in suite.findall("testcase"):
classname = testcase.get("classname", "")
test_name = testcase.get("name", "")
test_time = float(testcase.get("time", 0))
for testcase in suite.findall('testcase'):
classname = testcase.get('classname', '')
test_name = testcase.get('name', '')
test_time = float(testcase.get('time', 0))
# Extract test file name from classname
if "." in classname:
parts = classname.split(".")
test_file = parts[-2] if len(parts) > 1 else "unknown"
if '.' in classname:
parts = classname.split('.')
test_file = parts[-2] if len(parts) > 1 else 'unknown'
else:
test_file = "unknown"
test_file = 'unknown'
provider = get_provider_from_test_file(test_file)
provider_stats[provider]["time"] += test_time
provider_stats[provider]['time'] += test_time
# Check test status
if testcase.find("failure") is not None:
provider_stats[provider]["failed"] += 1
failure = testcase.find("failure")
failure_msg = (
failure.get("message", "") if failure is not None else ""
)
provider_tests[provider].append(
{
"name": test_name,
"status": "FAILED",
"time": test_time,
"message": failure_msg,
}
)
elif testcase.find("error") is not None:
provider_stats[provider]["errors"] += 1
error = testcase.find("error")
error_msg = error.get("message", "") if error is not None else ""
provider_tests[provider].append(
{
"name": test_name,
"status": "ERROR",
"time": test_time,
"message": error_msg,
}
)
elif testcase.find("skipped") is not None:
provider_stats[provider]["skipped"] += 1
skip = testcase.find("skipped")
skip_msg = skip.get("message", "") if skip is not None else ""
provider_tests[provider].append(
{
"name": test_name,
"status": "SKIPPED",
"time": test_time,
"message": skip_msg,
}
)
if testcase.find('failure') is not None:
provider_stats[provider]['failed'] += 1
failure = testcase.find('failure')
failure_msg = failure.get('message', '') if failure is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'FAILED',
'time': test_time,
'message': failure_msg
})
elif testcase.find('error') is not None:
provider_stats[provider]['errors'] += 1
error = testcase.find('error')
error_msg = error.get('message', '') if error is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'ERROR',
'time': test_time,
'message': error_msg
})
elif testcase.find('skipped') is not None:
provider_stats[provider]['skipped'] += 1
skip = testcase.find('skipped')
skip_msg = skip.get('message', '') if skip is not None else ''
provider_tests[provider].append({
'name': test_name,
'status': 'SKIPPED',
'time': test_time,
'message': skip_msg
})
else:
provider_stats[provider]["passed"] += 1
provider_tests[provider].append(
{
"name": test_name,
"status": "PASSED",
"time": test_time,
"message": "",
}
)
provider_stats[provider]['passed'] += 1
provider_tests[provider].append({
'name': test_name,
'status': 'PASSED',
'time': test_time,
'message': ''
})
passed = total_tests - total_failures - total_errors - total_skipped
# Generate the markdown report
with open(output_path, "w") as f:
with open(output_path, 'w') as f:
# Header
f.write("# LLM Translation Test Results\n\n")
@ -211,9 +186,7 @@ def generate_markdown_report(
f.write("| Field | Value |\n")
f.write("|-------|-------|\n")
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
f.write(
f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n"
)
f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n")
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
f.write(f"| **Duration** | {format_duration(total_time)} |\n")
f.write("\n")
@ -224,34 +197,23 @@ def generate_markdown_report(
# Summary box
f.write("```\n")
f.write(f"Total Tests: {total_tests}\n")
f.write(
f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(
f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
)
f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
f.write("```\n\n")
# Provider summary table
f.write("## Results by Provider\n\n")
f.write(
"| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n"
)
f.write(
"|----------|-------|------|------|-------|------|-----------|----------|"
)
f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
f.write("|----------|-------|------|------|-------|------|-----------|----------|")
# Sort providers: specific providers first, then cross-provider tests
sorted_providers = []
cross_provider = []
for p in sorted(provider_stats.keys()):
if "Cross-Provider" in p or p == "Other Tests":
if 'Cross-Provider' in p or p == 'Other Tests':
cross_provider.append(p)
else:
sorted_providers.append(p)
@ -260,17 +222,10 @@ def generate_markdown_report(
for provider in all_providers:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
pass_rate = (stats["passed"] / total * 100) if total > 0 else 0
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
pass_rate = (stats['passed'] / total * 100) if total > 0 else 0
f.write(
f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | "
)
f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
f.write(f"{format_duration(stats['time'])} |")
@ -280,58 +235,47 @@ def generate_markdown_report(
for provider in sorted_providers:
if provider_tests[provider]:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
f.write(f"### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write(
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) "
)
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ")
f.write(f"in {format_duration(stats['time'])}\n\n")
# Group tests by status
tests_by_status = defaultdict(list)
for test in provider_tests[provider]:
tests_by_status[test["status"]].append(test)
tests_by_status[test['status']].append(test)
# Show failed tests first (if any)
if tests_by_status["FAILED"]:
if tests_by_status['FAILED']:
f.write("<details>\n<summary>Failed Tests</summary>\n\n")
for test in tests_by_status["FAILED"]:
for test in tests_by_status['FAILED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
if test["message"]:
if test['message']:
# Truncate long error messages
msg = (
test["message"][:200] + "..."
if len(test["message"]) > 200
else test["message"]
)
msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message']
f.write(f" > {msg}\n")
f.write("\n</details>\n\n")
# Show errors (if any)
if tests_by_status["ERROR"]:
if tests_by_status['ERROR']:
f.write("<details>\n<summary>Error Tests</summary>\n\n")
for test in tests_by_status["ERROR"]:
for test in tests_by_status['ERROR']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n")
# Show passed tests in collapsible section
if tests_by_status["PASSED"]:
if tests_by_status['PASSED']:
f.write("<details>\n<summary>Passed Tests</summary>\n\n")
for test in tests_by_status["PASSED"]:
for test in tests_by_status['PASSED']:
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
f.write("\n</details>\n\n")
# Show skipped tests (if any)
if tests_by_status["SKIPPED"]:
if tests_by_status['SKIPPED']:
f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
for test in tests_by_status["SKIPPED"]:
for test in tests_by_status['SKIPPED']:
f.write(f"- `{test['name']}`\n")
f.write("\n</details>\n\n")
@ -341,43 +285,34 @@ def generate_markdown_report(
for provider in cross_provider:
if provider_tests[provider]:
stats = provider_stats[provider]
total = (
stats["passed"]
+ stats["failed"]
+ stats["errors"]
+ stats["skipped"]
)
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
f.write(f"#### {provider}\n\n")
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
f.write(
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n"
)
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n")
# For cross-provider tests, just show counts
f.write(f"- Passed: {stats['passed']}\n")
if stats["failed"] > 0:
if stats['failed'] > 0:
f.write(f"- Failed: {stats['failed']}\n")
if stats["errors"] > 0:
if stats['errors'] > 0:
f.write(f"- Errors: {stats['errors']}\n")
if stats["skipped"] > 0:
if stats['skipped'] > 0:
f.write(f"- Skipped: {stats['skipped']}\n")
f.write("\n")
print_colored(f"Report generated: {output_path}", Colors.GREEN)
except Exception as e:
print_colored(f"Error generating report: {e}", Colors.RED)
raise
def run_tests(
test_path: str = "tests/llm_translation/",
junit_xml: str = "test-results/junit.xml",
report_path: str = "test-results/llm_translation_report.md",
tag: str = None,
commit: str = None,
) -> int:
def run_tests(test_path: str = "tests/llm_translation/",
junit_xml: str = "test-results/junit.xml",
report_path: str = "test-results/llm_translation_report.md",
tag: str = None,
commit: str = None) -> int:
"""Run the LLM translation tests and generate report"""
# Create test results directory
@ -390,32 +325,21 @@ def run_tests(
# Run pytest
cmd = [
"uv",
"run",
"--no-sync",
"pytest",
test_path,
"uv", "run", "--no-sync", "pytest", test_path,
f"--junitxml={junit_xml}",
"-v",
"--tb=short",
"--maxfail=500",
"-n",
"auto",
"-n", "auto"
]
# Add timeout if pytest-timeout is installed
try:
subprocess.run(
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
capture_output=True,
check=True,
)
subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
capture_output=True, check=True)
cmd.extend(["--timeout=300"])
except:
print_colored(
"Warning: pytest-timeout not installed, skipping timeout option",
Colors.YELLOW,
)
print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
print_colored("Running pytest with command:", Colors.YELLOW)
print(f" {' '.join(cmd)}")
@ -438,15 +362,15 @@ def run_tests(
tree = ET.parse(junit_xml)
root = tree.getroot()
if root.tag == "testsuites":
suites = root.findall("testsuite")
if root.tag == 'testsuites':
suites = root.findall('testsuite')
else:
suites = [root]
total = sum(int(s.get("tests", 0)) for s in suites)
failures = sum(int(s.get("failures", 0)) for s in suites)
errors = sum(int(s.get("errors", 0)) for s in suites)
skipped = sum(int(s.get("skipped", 0)) for s in suites)
total = sum(int(s.get('tests', 0)) for s in suites)
failures = sum(int(s.get('failures', 0)) for s in suites)
errors = sum(int(s.get('errors', 0)) for s in suites)
skipped = sum(int(s.get('skipped', 0)) for s in suites)
passed = total - failures - errors - skipped
print(f" Total: {total}")
@ -460,11 +384,7 @@ def run_tests(
if total > 0:
pass_rate = (passed / total) * 100
color = (
Colors.GREEN
if pass_rate >= 80
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
)
color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
else:
print_colored("No test results found!", Colors.RED)
@ -474,24 +394,16 @@ def run_tests(
return result.returncode
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
parser.add_argument(
"--test-path", default="tests/llm_translation/", help="Path to test directory"
)
parser.add_argument(
"--junit-xml",
default="test-results/junit.xml",
help="Path for JUnit XML output",
)
parser.add_argument(
"--report",
default="test-results/llm_translation_report.md",
help="Path for markdown report",
)
parser.add_argument("--test-path", default="tests/llm_translation/",
help="Path to test directory")
parser.add_argument("--junit-xml", default="test-results/junit.xml",
help="Path for JUnit XML output")
parser.add_argument("--report", default="test-results/llm_translation_report.md",
help="Path for markdown report")
parser.add_argument("--tag", help="Git tag or version")
parser.add_argument("--commit", help="Git commit SHA")
@ -500,9 +412,8 @@ if __name__ == "__main__":
# Get git info if not provided
if not args.commit:
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"], capture_output=True, text=True
)
result = subprocess.run(["git", "rev-parse", "HEAD"],
capture_output=True, text=True)
if result.returncode == 0:
args.commit = result.stdout.strip()
except:
@ -510,11 +421,8 @@ if __name__ == "__main__":
if not args.tag:
try:
result = subprocess.run(
["git", "describe", "--tags", "--abbrev=0"],
capture_output=True,
text=True,
)
result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"],
capture_output=True, text=True)
if result.returncode == 0:
args.tag = result.stdout.strip()
except:
@ -525,7 +433,7 @@ if __name__ == "__main__":
junit_xml=args.junit_xml,
report_path=args.report,
tag=args.tag,
commit=args.commit,
commit=args.commit
)
sys.exit(exit_code)

View File

@ -9,6 +9,7 @@ from pathlib import Path
import testing.postgresql
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
DEFAULT_BASE_BRANCH = "litellm_internal_staging"

View File

@ -6,6 +6,7 @@ from tabulate import tabulate
from termcolor import colored
import os
# Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ["gpt-3.5-turbo", "claude-2"]

View File

@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
from fastapi.responses import JSONResponse
import uvicorn
app = FastAPI(
title="Braintrust Prompt Wrapper",
description="Wrapper server for Braintrust prompts to work with LiteLLM",

View File

@ -518,7 +518,8 @@ if __name__ == "__main__":
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
print("=" * 80)
print("\nExample curl command:")
print(f"""
print(
f"""
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
-H "Authorization: Bearer {bearer_token}" \\
-H "Content-Type: application/json" \\
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
}}
]
}}'
""")
"""
)
print("=" * 80)
uvicorn.run(app, host=host, port=port)

View File

@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception:
# If an error occurs, the view does not exist, so create it
await db.execute_raw("""
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""")
"""
)
print("LiteLLM_VerificationTokenView Created!") # noqa

View File

@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
class EnterpriseCallbackControls:
@staticmethod
def is_callback_disabled_dynamically(
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info
Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info
Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
)
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
litellm_params, standard_callback_dynamic_params
Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
)
verbose_logger.debug(
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
)
verbose_logger.debug(
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
)
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if (
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
):
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = (
CustomLoggerRegistry.get_callback_str_from_class_type(
callback.__class__
)
)
if (
callback_str is not None
and callback_str.lower() in disabled_callbacks
):
verbose_logger.debug(
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
return False
except Exception as e:
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
return False
try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
if callback_str is not None and callback_str.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
return False
except Exception as e:
verbose_logger.debug(
f"Error checking disabled callbacks header: {str(e)}"
)
return False
@staticmethod
def get_disabled_callbacks(
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> Optional[List[str]]:
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
"""
Get the disabled callbacks from the standard callback dynamic params.
"""
@ -92,21 +71,15 @@ class EnterpriseCallbackControls:
request_headers = get_proxy_server_request_headers(litellm_params)
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
if disabled_callbacks is not None:
disabled_callbacks = set(
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
)
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
return list(disabled_callbacks)
#########################################################
# check if disabled via request body
#########################################################
if (
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
is not None
):
return standard_callback_dynamic_params.get(
"litellm_disabled_callbacks", None
)
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
return None
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
# Check if admin has disabled this feature
if litellm.allow_dynamic_callback_disabling is not True:
verbose_logger.debug(
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
)
verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
return False
if premium_user:
return True
verbose_logger.warning(
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
)
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
return False

View File

@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
)
# Calculate percentage and alert threshold
percentage = (
threshold_pct
if threshold_pct is not None
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
percentage = threshold_pct if threshold_pct is not None else int(
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
)
threshold_fraction = percentage / 100.0
alert_threshold_str = (
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
continue
_id = user_info.token or user_info.user_id or "default_id"
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
_cache_key = (
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
)
result = await _cache.async_get_cache(key=_cache_key)
if result is not None:
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
continue
recipient_emails = list(set(emails))
event_message = (
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
)
event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
webhook_event = WebhookEvent(
event="max_budget_alert",
event_message=event_message,

View File

@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
from .base_email import BaseEmailLogger
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"

View File

@ -1,7 +1,6 @@
"""
This is the litellm SMTP email integration
"""
import asyncio
from typing import List

View File

@ -1,7 +1,6 @@
"""
Enterprise specific logging utils
"""
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata

View File

@ -153,11 +153,11 @@ async def get_audit_logs(
# Return paginated response
return PaginatedAuditLogResponse(
audit_logs=(
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs]
if audit_logs
else []
),
audit_logs=[
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
]
if audit_logs
else [],
total=total_count,
page=page,
page_size=page_size,

View File

@ -53,9 +53,7 @@ class CheckBatchCost:
"user_api_key_alias": getattr(user_row, "user_alias", None),
}
except Exception as e:
verbose_proxy_logger.error(
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
)
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
return {}
async def _cleanup_stale_managed_objects(self) -> None:
@ -64,22 +62,11 @@ class CheckBatchCost:
in non-terminal states as 'stale_expired'. These will never complete and
should not be polled.
"""
cutoff = datetime.now(timezone.utc) - timedelta(
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={
"file_purpose": "batch",
"status": {
"not_in": [
"completed",
"complete",
"failed",
"expired",
"cancelled",
"stale_expired",
]
},
"status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
"created_at": {"lt": cutoff},
},
data={"status": "stale_expired"},
@ -133,12 +120,9 @@ class CheckBatchCost:
try:
from litellm.integrations.prometheus import PrometheusLogger
prom_logger = PrometheusLogger.get_instance()
except Exception as e:
verbose_proxy_logger.error(
f"CheckBatchCost: could not get Prometheus logger: {e}"
)
verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
prom_logger = None
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
@ -177,11 +161,7 @@ class CheckBatchCost:
order={"created_at": "asc"},
)
except Exception as query_err:
if (
"batch_processed" not in str(query_err).lower()
and "unknown column" not in str(query_err).lower()
and "does not exist" not in str(query_err).lower()
):
if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
raise
# Permanent schema gap — cache the result so future cycles skip straight to fallback
self._has_batch_processed_column = False
@ -236,13 +216,14 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
)
if prom_logger:
prom_logger.record_check_batch_cost_error(
"provider_retrieval_error"
)
prom_logger.record_check_batch_cost_error("provider_retrieval_error")
continue
## RETRIEVE THE BATCH JOB OUTPUT FILE
if response.status == "completed" and response.output_file_id is not None:
if (
response.status == "completed"
and response.output_file_id is not None
):
verbose_proxy_logger.info(
f"Batch ID: {batch_id} is complete, tracking cost and usage"
)
@ -269,25 +250,20 @@ class CheckBatchCost:
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
if decoded:
try:
raw_output_file_id = decoded.split("llm_output_file_id,")[
1
].split(";")[0]
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
except (IndexError, AttributeError):
pass
credentials = (
self.llm_router.get_deployment_credentials_with_provider(model_id)
or {}
)
credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
_file_content = await afile_content(
file_id=raw_output_file_id,
**credentials,
)
# Access content - handle both direct attribute and method call
if hasattr(_file_content, "content"):
if hasattr(_file_content, 'content'):
content_bytes = _file_content.content # type: ignore[union-attr]
elif hasattr(_file_content, "read"):
elif hasattr(_file_content, 'read'):
content_bytes = await _file_content.read() # type: ignore[misc]
else:
content_bytes = _file_content # type: ignore[assignment]
@ -314,9 +290,7 @@ class CheckBatchCost:
f"Skipping job {unified_object_id} because it is not a valid deployment info"
)
if prom_logger:
prom_logger.record_check_batch_cost_error(
"deployment_not_found"
)
prom_logger.record_check_batch_cost_error("deployment_not_found")
continue
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
litellm_model_name = deployment_info.litellm_params.model
@ -328,11 +302,7 @@ class CheckBatchCost:
# Pass deployment model_info so custom batch pricing
# (input_cost_per_token_batches etc.) is used for cost calc
deployment_model_info = (
deployment_info.model_info.model_dump()
if deployment_info.model_info
else {}
)
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
batch_cost, batch_usage, batch_models = (
await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict,
@ -379,9 +349,7 @@ class CheckBatchCost:
# Record batch duration (completed_at - created_at)
if prom_logger and response.completed_at and response.created_at:
duration_seconds = float(
response.completed_at - response.created_at
)
duration_seconds = float(response.completed_at - response.created_at)
if duration_seconds >= 0:
prom_logger.record_managed_batch_duration(
duration_seconds=duration_seconds,
@ -390,9 +358,7 @@ class CheckBatchCost:
)
# Track this job for the final metrics summary
processed_models.append(
(model_name, str(llm_provider) if llm_provider else None)
)
processed_models.append((model_name, str(llm_provider) if llm_provider else None))
# mark the job as complete
try:

View File

@ -33,7 +33,9 @@ class CheckResponsesCost:
self.prisma_client: PrismaClient = prisma_client
self.llm_router: Router = llm_router
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int:
async def _expire_stale_rows(
self, cutoff: datetime, batch_size: int
) -> int:
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
Isolated so it can be swapped / mocked in tests without touching the
@ -72,9 +74,7 @@ class CheckResponsesCost:
rows per invocation to avoid overwhelming the DB when there is a large
backlog.
"""
cutoff = datetime.now(timezone.utc) - timedelta(
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
)
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
if result > 0:
verbose_proxy_logger.warning(
@ -122,9 +122,7 @@ class CheckResponsesCost:
model_name = stored_response.get("model", None)
# Decrypt the response ID
responses_id_security, _, _ = (
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
)
responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
# Prepare metadata with model information for cost tracking
litellm_metadata = {
@ -134,9 +132,7 @@ class CheckResponsesCost:
# Add model information if available
if model_name:
litellm_metadata["model"] = model_name
litellm_metadata["model_group"] = (
model_name # Use same value for model_group
)
litellm_metadata["model_group"] = model_name # Use same value for model_group
response = await litellm.aget_responses(
response_id=responses_id_security,
@ -175,3 +171,4 @@ class CheckResponsesCost:
verbose_proxy_logger.info(
f"Marked {len(completed_jobs)} response jobs as completed"
)

View File

@ -120,7 +120,9 @@ class _PROXY_LiteLLMManagedVectorStores(
VectorStoreCreateResponse from the provider
"""
# Use the router to create the vector store
response = await llm_router.avector_store_create(model=model, **request_data)
response = await llm_router.avector_store_create(
model=model, **request_data
)
return response
# ============================================================================
@ -274,7 +276,9 @@ class _PROXY_LiteLLMManagedVectorStores(
"""
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
is_unified_id = (
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False
is_base64_encoded_unified_id(vector_store_id)
if vector_store_id
else False
)
if is_unified_id and vector_store_id:

View File

@ -2,6 +2,7 @@
Enterprise internal user management endpoints
"""
from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy._types import UserAPIKeyAuth

View File

@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
soft_budget_crossed = "Soft Budget Crossed"
max_budget_alert = "Max Budget Alert"
class EmailEventSettings(BaseModel):
event: EmailEvent
enabled: bool
class EmailEventSettingsUpdateRequest(BaseModel):
settings: List[EmailEventSettings]
class EmailEventSettingsResponse(BaseModel):
settings: List[EmailEventSettings]
class DefaultEmailSettings(BaseModel):
"""Default settings for email events"""
settings: Dict[EmailEvent, bool] = Field(
default_factory=lambda: {
EmailEvent.virtual_key_created: True, # On by default
@ -65,11 +57,9 @@ class DefaultEmailSettings(BaseModel):
EmailEvent.max_budget_alert: True, # On by default
}
)
def to_dict(self) -> Dict[str, bool]:
"""Convert to dictionary with string keys for storage"""
return {event.value: enabled for event, enabled in self.settings.items()}
@classmethod
def get_defaults(cls) -> Dict[str, bool]:
"""Get the default settings as a dictionary with string keys"""

View File

@ -6,6 +6,7 @@ Always uses fastuuid for performance.
import fastuuid as _uuid # type: ignore
# Expose a module-like alias so callers can use: uuid.uuid4()
uuid = _uuid

View File

@ -9,6 +9,7 @@ from typing import Dict, Optional
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
# HTTP status code -> Anthropic error type
# Source: https://docs.anthropic.com/en/api/errors
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {

View File

@ -2,6 +2,7 @@
from typing_extensions import Literal, Required, TypedDict
# Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors
AnthropicErrorType = Literal[

View File

@ -97,7 +97,7 @@ def _build_reasoning_item(
def _reasoning_item_to_response_input(
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]],
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
) -> Dict[str, Any]:
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
r_input: Dict[str, Any] = {

View File

@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
import json
import re
_CODE_KEYWORDS = re.compile(
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
)

View File

@ -1,5 +1,6 @@
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
FileContentProvider = Literal[
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
]

View File

@ -18,7 +18,7 @@ else:
def process_slack_alerting_variables(
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]],
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
"""
process alert_to_webhook_url

View File

@ -9,6 +9,7 @@ import polars as pl
from .schema import FOCUS_NORMALIZED_SCHEMA
_TAG_KEYS = (
"team_id",
"team_alias",

View File

@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
def get_traces_and_spans_from_payload(
payload: List[Dict[str, Any]],
payload: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Separate traces and spans from payload.

View File

@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = "redacted by litellm. \
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
error=str(e),
)
)
_metadata["raw_request"] = "Unable to Log \
raw request: {}".format(str(e))
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
self.logger_fn(

View File

@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
prompt_str = """Use this JSON schema:
```json
{}
```""".format(response_schema)
```""".format(
response_schema
)
return prompt_str

View File

@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
from litellm._logging import verbose_logger
# ---------------------------------------------------------------------------
# SSE parsing helpers (module-level to keep the class lean)
# ---------------------------------------------------------------------------

View File

@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
from ...openai_like.chat.transformation import OpenAILikeChatConfig
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"

View File

@ -13,6 +13,7 @@ from typing import Tuple
import httpx
# ---------------------------------------------------------------------------
# Pre-built response templates
# ---------------------------------------------------------------------------

View File

@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
from ..common_utils import ElevenLabsException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent

View File

@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
def _usage_video_resolution_from_parameters(
parameters: Dict[str, Any],
parameters: Dict[str, Any]
) -> Optional[str]:
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
res = parameters.get("resolution")

View File

@ -201,7 +201,7 @@ class BaseOpenAILLM:
@staticmethod
def get_openai_client_initialization_param_fields(
client_type: Literal["openai", "azure"],
client_type: Literal["openai", "azure"]
) -> Tuple[str, ...]:
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
if client_type == "openai":

View File

@ -49,6 +49,7 @@ from litellm.types.utils import (
)
from litellm.llms.openrouter.common_utils import OpenRouterException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:

View File

@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
def _parse_service_key_once(
service_key: Optional[Union[str, dict]],
service_key: Optional[Union[str, dict]]
) -> Optional[Dict[str, Any]]:
"""
Pre-parse service_key if it's a string to avoid repeated JSON parsing.

View File

@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
from ..utils import SnowflakeBaseConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

View File

@ -19,7 +19,7 @@ from ..gemini.transformation import (
def get_first_continuous_block_idx(
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message)
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
) -> int:
"""
Find the array index that ends the first continuous sequence of message blocks.

View File

@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
contents.append(ContentType(role="user", parts=tool_call_responses))
if len(contents) == 0:
verbose_logger.warning("""
verbose_logger.warning(
"""
No contents in messages. Contents are required. See
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
If the original request did not comply to OpenAI API requirements it should have failed by now,
but LiteLLM does not check for missing messages.
Setting an empty content to prevent an 400 error.
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
""")
"""
)
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
return contents
except Exception as e:

View File

@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############
####### Send the request ###################
if _is_async is True:
return self.async_audio_speech( # type: ignore
return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request
)
sync_handler = _get_httpx_client()

View File

@ -324,7 +324,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_chat_completion_request_schema(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_responses_api_request_schema(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
@staticmethod
def add_llm_api_request_schema_body(
openapi_schema: Dict[str, Any],
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add LLM API request schema bodies to OpenAPI specification for documentation.

View File

@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
async def convert_upload_files_to_file_data(
form_data: Dict[str, Any],
form_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert FastAPI UploadFile objects to file data tuples for litellm.

View File

@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
# If an error occurs, the view does not exist, so create it
await db.execute_raw("""
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
""")
"""
)
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")

View File

@ -10,6 +10,7 @@ every text fragment.
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
# Call types whose body carries free-form chat / prompt text that
# text-content guardrails (banned keywords, content moderation, secret
# detection, …) should inspect. The proxy ingress passes ``route_type``

View File

@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
from .akto import AktoGuardrail
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams

View File

@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
"""
def _get_team_public_model_name(
model_info: Optional[Union[dict, str]],
model_info: Optional[Union[dict, str]]
) -> Optional[str]:
if isinstance(model_info, dict):
value = model_info.get("team_public_model_name")

View File

@ -4347,7 +4347,9 @@ async def list_team(
except Exception as e:
team_exception = """Invalid team object for team_id: {}. team_object={}.
Error: {}
""".format(team.team_id, team.model_dump(), str(e))
""".format(
team.team_id, team.model_dump(), str(e)
)
verbose_proxy_logger.exception(team_exception)
continue
# Sort the responses by team_alias

View File

@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import StandardPassThroughResponseObject
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
"POST /v0/agents": "cursor:agent:create",
"GET /v0/agents": "cursor:agent:list",

View File

@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
_endpoint_str = (
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
)
curl_command = _endpoint_str + """
curl_command = (
_endpoint_str
+ """
--header 'Content-Type: application/json' \\
--data ' {
"model": "gpt-3.5-turbo",
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
}'
\n
"""
)
print() # noqa
print( # noqa
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
print(f"""
print( # noqa
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa # noqa
"""
) # noqa
@staticmethod
def _is_port_in_use(port):

View File

@ -2688,9 +2688,11 @@ def run_ollama_serve():
with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
verbose_proxy_logger.debug(f"""
verbose_proxy_logger.debug(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""")
"""
)
def _get_process_rss_mb() -> Optional[float]:

View File

@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw("""
response = await prisma_client.db.query_raw(
"""
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
GROUP BY individual_request_tag;
""")
"""
)
return response

View File

@ -2712,7 +2712,8 @@ class PrismaClient:
required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
ret = await self.db.query_raw(f"""
ret = await self.db.query_raw(
f"""
WITH existing_views AS (
SELECT viewname
FROM pg_views
@ -2724,7 +2725,8 @@ class PrismaClient:
(SELECT COUNT(*) FROM existing_views) AS view_count,
ARRAY_AGG(viewname) AS view_names
FROM existing_views
""")
"""
)
expected_total_views = len(expected_views)
if ret[0]["view_count"] == expected_total_views:
verbose_proxy_logger.info("All necessary views exist!")
@ -2733,7 +2735,8 @@ class PrismaClient:
## check if required view exists ##
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw("""
await self.db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -2743,7 +2746,8 @@ class PrismaClient:
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""")
"""
)
verbose_proxy_logger.info(
"LiteLLM_VerificationTokenView Created in DB!"

View File

@ -826,7 +826,7 @@ class Router:
@staticmethod
def _normalize_strategy(
strategy: Union[RoutingStrategy, str, None],
strategy: Union[RoutingStrategy, str, None]
) -> Optional[str]:
if strategy is None:
return None

View File

@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
def _recent_tool_results(
messages: Optional[List[Dict[str, Any]]],
messages: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Extract the current turn's tool result payloads from the request messages.

View File

@ -24,6 +24,7 @@ from litellm.router_strategy.adaptive_router.config import (
TOOL_CALL_HISTORY_MAX,
)
# ---- Public types ---------------------------------------------------------

View File

@ -10,11 +10,11 @@ This means you can use this with weighted-pick, lowest-latency, simple-shuffle,
Example:
```
openai:
budget_limit: 0.000000000001
time_period: 1d
budget_limit: 0.000000000001
time_period: 1d
anthropic:
budget_limit: 100
time_period: 7d
budget_limit: 100
time_period: 7d
```
"""

View File

@ -34,7 +34,7 @@ class PatternUtils:
@staticmethod
def sorted_patterns(
patterns: Dict[str, List[Dict]],
patterns: Dict[str, List[Dict]]
) -> List[Tuple[str, List[Dict]]]:
"""
Cached property for patterns sorted by specificity.

View File

@ -21,7 +21,7 @@ class VectorStoreFileRequestUtils:
@staticmethod
def get_create_request_params(
params: Dict[str, Any],
params: Dict[str, Any]
) -> VectorStoreFileCreateRequest:
filtered = VectorStoreFileRequestUtils._filter_params(
params=params, model=VectorStoreFileCreateRequest
@ -37,7 +37,7 @@ class VectorStoreFileRequestUtils:
@staticmethod
def get_update_request_params(
params: Dict[str, Any],
params: Dict[str, Any]
) -> VectorStoreFileUpdateRequest:
filtered = VectorStoreFileRequestUtils._filter_params(
params=params, model=VectorStoreFileUpdateRequest

View File

@ -76,11 +76,7 @@ utils = [
# Not in Docker or PyPI proxy extra.
"numpydoc==1.8.0",
]
# diskcache intentionally unpinned: CVE-2025-69872 (pickle RCE) has no
# upstream fix. Stub kept so `pip install litellm[caching]` doesn't warn;
# DiskCache loads diskcache lazily and forces JSONDisk for safety. See
# litellm/caching/disk_cache.py.
caching = []
caching = ["diskcache==5.6.3"]
semantic-router = [
"semantic-router==0.1.12; python_version < '3.14'",
"aurelio-sdk==0.0.19; python_version < '3.14'",
@ -130,7 +126,7 @@ litellm-proxy = "litellm.proxy.client.cli:cli"
dev = [
"diff-cover==9.7.2",
"flake8==7.3.0",
"black==26.3.1",
"black==24.10.0",
"mypy==1.19.0",
"pytest==9.0.3",
"pytest-mock==3.15.1",

View File

@ -35,7 +35,7 @@ import httpx
class EvalCase:
category: str
prompt: str
ideal: str # criteria the judge checks the response against
ideal: str # criteria the judge checks the response against
EVAL_CASES: List[EvalCase] = [
@ -177,19 +177,14 @@ async def evaluate(
async with httpx.AsyncClient() as client:
for i, case in enumerate(EVAL_CASES, 1):
print(f"\n[{i}/{len(EVAL_CASES)}] category={case.category}")
print(
f" prompt : {case.prompt[:80]}{'' if len(case.prompt) > 80 else ''}"
)
print(f" prompt : {case.prompt[:80]}{'' if len(case.prompt) > 80 else ''}")
session_id = f"eval-{uuid.uuid4()}"
# Round 1: single-turn real request — get the actual LLM response to judge.
try:
response, chosen = await _chat(
client,
proxy_url,
api_key,
router,
client, proxy_url, api_key, router,
[{"role": "user", "content": case.prompt}],
session_id=session_id,
)
@ -199,25 +194,16 @@ async def evaluate(
continue
print(f" model : {chosen or router}")
print(
f" response : {response[:120].replace(chr(10), ' ')}{'' if len(response) > 120 else ''}"
)
print(f" response : {response[:120].replace(chr(10), ' ')}{'' if len(response) > 120 else ''}")
# Judge the real response.
judge_msgs = [
{"role": "system", "content": JUDGE_SYSTEM},
{
"role": "user",
"content": _judge_user(case.prompt, case.ideal, response),
},
{"role": "user", "content": _judge_user(case.prompt, case.ideal, response)},
]
try:
verdict, _ = await _chat(
client,
proxy_url,
api_key,
judge_model,
judge_msgs,
client, proxy_url, api_key, judge_model, judge_msgs,
)
except Exception as exc: # noqa: BLE001
print(f" ERROR calling judge: {exc}", file=sys.stderr)
@ -241,19 +227,15 @@ async def evaluate(
# On PASS → satisfaction follow-up (+alpha). On FAIL → neutral (no signal).
follow_up = SATISFY_FOLLOWUP if is_pass else NEUTRAL_FOLLOWUP
bandit_msgs = [
{"role": "user", "content": case.prompt},
{"role": "user", "content": case.prompt},
{"role": "assistant", "content": response},
{"role": "user", "content": "ok continue"},
{"role": "user", "content": "ok continue"},
{"role": "assistant", "content": FAB_ASSISTANT},
{"role": "user", "content": follow_up},
{"role": "user", "content": follow_up},
]
try:
await _chat(
client,
proxy_url,
api_key,
router,
bandit_msgs,
client, proxy_url, api_key, router, bandit_msgs,
session_id=session_id,
)
except Exception as exc: # noqa: BLE001
@ -275,17 +257,11 @@ async def evaluate(
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
ap = argparse.ArgumentParser(
description="Evaluate the adaptive router with LLM-as-judge."
)
ap.add_argument("--proxy-url", default="http://localhost:4000")
ap.add_argument("--api-key", required=True, help="proxy API key")
ap.add_argument(
"--router", default="smart-cheap-router", help="adaptive router model name"
)
ap.add_argument(
"--judge-model", default="smart", help="model name for the judge (via proxy)"
)
ap = argparse.ArgumentParser(description="Evaluate the adaptive router with LLM-as-judge.")
ap.add_argument("--proxy-url", default="http://localhost:4000")
ap.add_argument("--api-key", required=True, help="proxy API key")
ap.add_argument("--router", default="smart-cheap-router", help="adaptive router model name")
ap.add_argument("--judge-model", default="smart", help="model name for the judge (via proxy)")
args = ap.parse_args()
asyncio.run(evaluate(args.proxy_url, args.api_key, args.router, args.judge_model))

View File

@ -72,8 +72,8 @@ PROMPTS: Dict[str, List[str]] = {
# so that signals attribute to the right (type, model) bandit cell.
SATISFY: Dict[str, str] = {
"code_generation": "thanks, that works! now write me a python function that does the inverse",
"factual_lookup": "perfect, thanks! who is the current prime minister?",
"writing": "great, thanks! now write a follow-up email confirming attendance",
"factual_lookup": "perfect, thanks! who is the current prime minister?",
"writing": "great, thanks! now write a follow-up email confirming attendance",
}
# Neutral follow-up — does not match any signal regex, does not move the bandit.
@ -83,8 +83,8 @@ NEUTRAL_FOLLOWUP = "ok, noted"
# Defaults: smart dominates code/writing; both are fine for factual_lookup.
ORACLE: Dict[str, Dict[str, float]] = {
"code_generation": {"smart": 0.92, "fast": 0.35},
"factual_lookup": {"smart": 0.90, "fast": 0.85},
"writing": {"smart": 0.85, "fast": 0.55},
"factual_lookup": {"smart": 0.90, "fast": 0.85},
"writing": {"smart": 0.85, "fast": 0.55},
}
# Fabricated assistant turn — content doesn't matter for the hook, only the role.
@ -94,11 +94,11 @@ FAB_ASSISTANT = "Got it. Working on that now."
def _build_messages(prompt: str, last_user: str) -> List[Dict[str, str]]:
"""5-message conversation that passes the SIGNAL_GATE_MIN_MESSAGES=4 gate."""
return [
{"role": "user", "content": prompt},
{"role": "user", "content": prompt},
{"role": "assistant", "content": FAB_ASSISTANT},
{"role": "user", "content": "ok continue"},
{"role": "user", "content": "ok continue"},
{"role": "assistant", "content": FAB_ASSISTANT},
{"role": "user", "content": last_user},
{"role": "user", "content": last_user},
]
@ -155,11 +155,7 @@ async def _drive_one_session(
#
# Round 1: neutral follow-up → no signal fires, but we learn the pick.
ok, chosen = await _send(
client,
proxy_url,
api_key,
router,
session_id,
client, proxy_url, api_key, router, session_id,
_build_messages(prompt, NEUTRAL_FOLLOWUP),
mock_response=FAB_ASSISTANT,
)
@ -175,15 +171,10 @@ async def _drive_one_session(
# follow-up matches satisfaction → +alpha for (request_type, chosen).
history = _build_messages(prompt, NEUTRAL_FOLLOWUP) + [
{"role": "assistant", "content": FAB_ASSISTANT},
{"role": "user", "content": follow_up},
{"role": "user", "content": follow_up},
]
await _send(
client,
proxy_url,
api_key,
router,
session_id,
history,
client, proxy_url, api_key, router, session_id, history,
mock_response=FAB_ASSISTANT,
)
return chosen
@ -192,22 +183,13 @@ async def _drive_one_session(
async def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--proxy-url", default="http://localhost:4000")
ap.add_argument(
"--api-key", required=True, help="proxy key with /v1/chat/completions perms"
)
ap.add_argument("--router", default="smart-cheap-router")
ap.add_argument("--rounds", type=int, default=100)
ap.add_argument(
"--rate",
type=float,
default=0.5,
help="seconds between sessions; lower = faster",
)
ap.add_argument(
"--types",
default="code_generation,factual_lookup,writing",
help="comma-separated subset of request types to drive",
)
ap.add_argument("--api-key", required=True, help="proxy key with /v1/chat/completions perms")
ap.add_argument("--router", default="smart-cheap-router")
ap.add_argument("--rounds", type=int, default=100)
ap.add_argument("--rate", type=float, default=0.5,
help="seconds between sessions; lower = faster")
ap.add_argument("--types", default="code_generation,factual_lookup,writing",
help="comma-separated subset of request types to drive")
args = ap.parse_args()
types = [t.strip() for t in args.types.split(",") if t.strip() in PROMPTS]
@ -225,12 +207,7 @@ async def main() -> None:
rt = random.choice(types)
prompt = random.choice(PROMPTS[rt])
chosen = await _drive_one_session(
client,
args.proxy_url,
args.api_key,
args.router,
rt,
prompt,
client, args.proxy_url, args.api_key, args.router, rt, prompt,
)
if chosen:
counts[(rt, chosen)] = counts.get((rt, chosen), 0) + 1

View File

@ -8,6 +8,7 @@ import statistics
import aiohttp
REQUEST_BODY = {
"model": "db-openai-endpoint",
"messages": [{"role": "user", "content": "hi"}],

View File

@ -42,7 +42,8 @@ from litellm.types.utils import CallTypes
PROBLEMS = [
{
"id": "has_close_elements",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def has_close_elements(numbers: List[float], threshold: float) -> bool:
@ -53,8 +54,10 @@ PROBLEMS = [
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
@ -62,11 +65,13 @@ PROBLEMS = [
assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0], 2.0) == True
assert has_close_elements([], 0.5) == False
print("PASSED")
"""),
"""
),
},
{
"id": "separate_paren_groups",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def separate_paren_groups(paren_string: str) -> List[str]:
@ -77,18 +82,22 @@ PROBLEMS = [
>>> separate_paren_groups('( ) (( )) (( )( ))')
['()', '(())', '(()())']
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert separate_paren_groups('(()()) ((())) () ((())()())') == ['(()())', '((()))', '()', '((())()())']
assert separate_paren_groups('() (()) ((())) (((())))') == ['()', '(())', '((()))', '(((())))']
assert separate_paren_groups('(()(()))') == ['(()(()))']
assert separate_paren_groups('( ) (( )) (( )( ))') == ['()', '(())', '(()())']
print("PASSED")
"""),
"""
),
},
{
"id": "truncate_number",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
def truncate_number(number: float) -> float:
\"\"\"Given a positive floating point number, it can be decomposed into
an integer part (largest integer smaller than given number) and decimals
@ -97,17 +106,21 @@ PROBLEMS = [
>>> truncate_number(3.5)
0.5
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert truncate_number(3.5) == 0.5
assert abs(truncate_number(1.33) - 0.33) < 1e-6
assert abs(truncate_number(123.456) - 0.456) < 1e-6
print("PASSED")
"""),
"""
),
},
{
"id": "below_zero",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def below_zero(operations: List[int]) -> bool:
@ -119,8 +132,10 @@ PROBLEMS = [
>>> below_zero([1, 2, -4, 5])
True
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert below_zero([]) == False
assert below_zero([1, 2, -3, 1, 2, -3]) == False
assert below_zero([1, 2, -4, 5, 6]) == True
@ -128,11 +143,13 @@ PROBLEMS = [
assert below_zero([1, -1, 2, -2, 5, -5, 4, -5]) == True
assert below_zero([1, -2]) == True
print("PASSED")
"""),
"""
),
},
{
"id": "mean_absolute_deviation",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def mean_absolute_deviation(numbers: List[float]) -> float:
@ -144,17 +161,21 @@ PROBLEMS = [
>>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])
1.0
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6
assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0, 5.0]) - 1.2) < 1e-6
assert abs(mean_absolute_deviation([1.0, 1.0, 1.0, 1.0]) - 0.0) < 1e-6
print("PASSED")
"""),
"""
),
},
{
"id": "intersperse",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def intersperse(numbers: List[int], delimiter: int) -> List[int]:
@ -164,17 +185,21 @@ PROBLEMS = [
>>> intersperse([1, 2, 3], 4)
[1, 4, 2, 4, 3]
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert intersperse([], 7) == []
assert intersperse([5, 6, 3, 2], 8) == [5, 8, 6, 8, 3, 8, 2]
assert intersperse([2, 2, 2], 2) == [2, 2, 2, 2, 2]
print("PASSED")
"""),
"""
),
},
{
"id": "parse_nested_parens",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def parse_nested_parens(paren_string: str) -> List[int]:
@ -184,17 +209,21 @@ PROBLEMS = [
>>> parse_nested_parens('(()()) ((())) () ((())())')
[2, 3, 1, 3]
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert parse_nested_parens('(()()) ((())) () ((())())') == [2, 3, 1, 3]
assert parse_nested_parens('() (()) ((())) (((())))') == [1, 2, 3, 4]
assert parse_nested_parens('(()(())((())))') == [4]
print("PASSED")
"""),
"""
),
},
{
"id": "filter_by_substring",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def filter_by_substring(strings: List[str], substring: str) -> List[str]:
@ -204,18 +233,22 @@ PROBLEMS = [
>>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')
['abc', 'bacd', 'array']
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert filter_by_substring([], 'john') == []
assert filter_by_substring(['xxx', 'asd', 'xxy', 'john doe', 'xxxuj', 'xxx'], 'xxx') == ['xxx', 'xxxuj', 'xxx']
assert filter_by_substring(['xxx', 'asd', 'aaber', 'john doe', 'xxxuj', 'xxx'], 'xx') == ['xxx', 'xxxuj', 'xxx']
assert filter_by_substring(['grunt', 'hierarchial', 'abc', 'hierarchial'], 'hi') == ['hierarchial', 'hierarchial']
print("PASSED")
"""),
"""
),
},
{
"id": "sum_product",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List, Tuple
def sum_product(numbers: List[int]) -> Tuple[int, int]:
@ -226,19 +259,23 @@ PROBLEMS = [
>>> sum_product([1, 2, 3, 4])
(10, 24)
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert sum_product([]) == (0, 1)
assert sum_product([1, 1, 1]) == (3, 1)
assert sum_product([100, 0]) == (100, 0)
assert sum_product([3, 5, 7]) == (15, 105)
assert sum_product([10]) == (10, 10)
print("PASSED")
"""),
"""
),
},
{
"id": "max_element",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def max_element(l: List[int]) -> int:
@ -248,17 +285,21 @@ PROBLEMS = [
>>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])
123
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert max_element([1, 2, 3]) == 3
assert max_element([5, 3, -5, 2, -3, 3, 9, 0, 124, 1, -10]) == 124
assert max_element([-1, -2, -3]) == -1
print("PASSED")
"""),
"""
),
},
{
"id": "fizz_buzz",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
def fizz_buzz(n: int) -> int:
\"\"\"Return the number of times the digit 7 appears in integers less than n which are divisible by 11 or 13.
>>> fizz_buzz(50)
@ -268,8 +309,10 @@ PROBLEMS = [
>>> fizz_buzz(79)
3
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert fizz_buzz(50) == 0
assert fizz_buzz(78) == 2
assert fizz_buzz(79) == 3
@ -277,11 +320,13 @@ PROBLEMS = [
assert fizz_buzz(200) == 6
assert fizz_buzz(4000) == 192
print("PASSED")
"""),
"""
),
},
{
"id": "sort_by_binary_len",
"prompt": textwrap.dedent("""\
"prompt": textwrap.dedent(
"""\
from typing import List
def sort_array(arr: List[int]) -> List[int]:
@ -294,8 +339,10 @@ PROBLEMS = [
>>> sort_array([1, 0, 2, 3, 4])
[0, 1, 2, 4, 3]
\"\"\"
"""),
"tests": textwrap.dedent("""\
"""
),
"tests": textwrap.dedent(
"""\
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 4, 3, 5]
assert sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
assert sort_array([1, 0, 2, 3, 4]) == [0, 1, 2, 4, 3]
@ -303,7 +350,8 @@ PROBLEMS = [
assert sort_array([2, 5, 77, 4, 5, 3, 5, 7, 2, 3, 4]) == [2, 2, 4, 4, 3, 3, 5, 5, 5, 7, 77]
assert sort_array([3, 6, 44, 12, 32, 5]) == [32, 3, 5, 6, 12, 44]
print("PASSED")
"""),
"""
),
},
]
@ -312,7 +360,8 @@ PROBLEMS = [
# compressor to identify and drop them.
DISTRACTOR_SNIPPETS = [
# distractor 0 — database connection pool
textwrap.dedent("""\
textwrap.dedent(
"""\
# db_pool.py
import threading
from contextlib import contextmanager
@ -359,9 +408,11 @@ DISTRACTOR_SNIPPETS = [
for conn in self._pool:
conn.close()
self._pool.clear()
"""),
"""
),
# distractor 1 — HTTP retry logic
textwrap.dedent("""\
textwrap.dedent(
"""\
# http_retry.py
import time
import random
@ -405,9 +456,11 @@ DISTRACTOR_SNIPPETS = [
resp = requests.get(url, params=params, timeout=30)
resp.raise_for_status()
return resp.json()
"""),
"""
),
# distractor 2 — LRU cache implementation
textwrap.dedent("""\
textwrap.dedent(
"""\
# lru_cache.py
from collections import OrderedDict
from threading import RLock
@ -456,9 +509,11 @@ DISTRACTOR_SNIPPETS = [
def __contains__(self, key):
return key in self._cache
"""),
"""
),
# distractor 3 — CSV report generator
textwrap.dedent("""\
textwrap.dedent(
"""\
# report_gen.py
import csv
import io
@ -511,9 +566,11 @@ DISTRACTOR_SNIPPETS = [
except (ValueError, KeyError):
return False
return self.filter_rows(in_range)
"""),
"""
),
# distractor 4 — async task queue
textwrap.dedent("""\
textwrap.dedent(
"""\
# task_queue.py
import asyncio
import logging
@ -583,9 +640,11 @@ DISTRACTOR_SNIPPETS = [
async def shutdown(self):
for w in self._workers:
w.cancel()
"""),
"""
),
# distractor 5 — config parser with env var interpolation
textwrap.dedent("""\
textwrap.dedent(
"""\
# config_parser.py
import os
import re
@ -648,7 +707,8 @@ DISTRACTOR_SNIPPETS = [
if val is None:
raise ConfigError(f"Required config key missing: {key}")
return val
"""),
"""
),
]

View File

@ -18,6 +18,7 @@ from litellm.batches.batch_utils import (
from litellm.cost_calculator import batch_cost_calculator
from litellm.types.utils import Usage
# --- helpers ---

View File

@ -20,6 +20,7 @@ sys.path.insert(0, os.path.abspath("../.."))
import litellm
SERVER_URL = "https://exampleopenaiendpoint-production-0ee2.up.railway.app/v1"

View File

@ -12,6 +12,7 @@ import litellm
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.token_counter import token_counter
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

View File

@ -2,6 +2,7 @@ import ast
import os
from typing import List, Dict, Any
ALLOWED_FILE = os.path.normpath("litellm/_uuid.py")

View File

@ -23,6 +23,8 @@ def event_loop():
loop.close()
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""

View File

@ -957,18 +957,14 @@ async def test_langfuse_callback_failure_metric(prometheus_logger):
# Get initial value
initial_value = 0
try:
initial_value = (
prometheus_logger.litellm_callback_logging_failures_metric.labels(
callback_name="langfuse"
)._value.get()
)
initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels(
callback_name="langfuse"
)._value.get()
except Exception:
initial_value = 0
# Create Langfuse logger with mocked initialization
with patch(
"litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"
):
with patch("litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"):
langfuse_logger = LangfusePromptManagement()
# Mock the log_event_on_langfuse to raise an exception
@ -980,12 +976,10 @@ async def test_langfuse_callback_failure_metric(prometheus_logger):
mock_get_logger.return_value = mock_logger
# Mock handle_callback_failure to track calls
with patch.object(
prometheus_logger, "increment_callback_logging_failure"
) as mock_increment:
with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment:
# Inject prometheus logger into the langfuse logger
langfuse_logger.handle_callback_failure = (
lambda callback_name: mock_increment(callback_name=callback_name)
langfuse_logger.handle_callback_failure = lambda callback_name: mock_increment(
callback_name=callback_name
)
# Call async_log_success_event - should catch exception and increment metric
@ -1017,51 +1011,43 @@ async def test_langfuse_otel_callback_failure_metric(prometheus_logger):
# Get initial value
initial_value = 0
try:
initial_value = (
prometheus_logger.litellm_callback_logging_failures_metric.labels(
callback_name="langfuse_otel"
)._value.get()
)
initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels(
callback_name="langfuse_otel"
)._value.get()
except Exception:
initial_value = 0
# Create Langfuse OTEL logger with mocked initialization
with patch(
"litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None
):
with patch("litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None):
langfuse_otel_logger = LangfuseOtelLogger(callback_name="langfuse_otel")
langfuse_otel_logger.callback_name = "langfuse_otel"
# Mock handle_callback_failure to track calls
with patch.object(
prometheus_logger, "increment_callback_logging_failure"
) as mock_increment:
with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment:
# Inject prometheus logger into the langfuse otel logger
langfuse_otel_logger.handle_callback_failure = (
lambda callback_name: mock_increment(callback_name=callback_name)
langfuse_otel_logger.handle_callback_failure = lambda callback_name: mock_increment(
callback_name=callback_name
)
# Test that the OpenTelemetry base class set_attributes exception handler works
# This is where langfuse_otel failures are caught and tracked
with patch.object(
langfuse_otel_logger, "set_attributes"
) as mock_set_attributes:
with patch.object(langfuse_otel_logger, "set_attributes") as mock_set_attributes:
# Simulate the exception handling in set_attributes
def set_attributes_with_error(*args, **kwargs):
# This simulates what happens in the real set_attributes method
try:
raise Exception("Attribute error")
except Exception as e:
langfuse_otel_logger.handle_callback_failure(
callback_name=langfuse_otel_logger.callback_name
)
langfuse_otel_logger.handle_callback_failure(callback_name=langfuse_otel_logger.callback_name)
mock_set_attributes.side_effect = set_attributes_with_error
# Call set_attributes
try:
langfuse_otel_logger.set_attributes(
span=MagicMock(), kwargs={}, response_obj={}
span=MagicMock(),
kwargs={},
response_obj={}
)
except Exception:
pass

View File

@ -49,13 +49,10 @@ async def test_enterprise_custom_auth_returns_string():
mock_user_auth = AsyncMock(return_value="sk-test-key")
request = MagicMock(spec=Request)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth",
mock_user_auth,
),
patch("litellm.proxy.proxy_server.master_key", "sk-1234"),
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
with patch(
"litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth", mock_user_auth
), patch("litellm.proxy.proxy_server.master_key", "sk-1234"), patch(
"litellm.proxy.proxy_server.prisma_client", MagicMock()
):
# Verify the key is correctly handled in _user_api_key_auth_builder
with patch(

View File

@ -105,9 +105,7 @@ async def test_apply_guardrail_endpoint_with_presidio_guardrail():
mock_guardrail = Mock(spec=CustomGuardrail)
# Simulate masking PII entities - returns GenericGuardrailAPIInputs (dict with texts key)
mock_guardrail.apply_guardrail = AsyncMock(
return_value={
"texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"]
}
return_value={"texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"]}
)
# Configure the registry to return our mock guardrail

View File

@ -95,9 +95,9 @@ async def test_async_pre_call_deployment_hook_resolves_model_id_from_litellm_met
kwargs=kwargs, call_type=CallTypes.acreate_batch
)
assert (
result["input_file_id"] == provider_file_id
), f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'"
assert result["input_file_id"] == provider_file_id, (
f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'"
)
@pytest.mark.asyncio
@ -134,9 +134,9 @@ async def test_async_pre_call_deployment_hook_prefers_top_level_model_info():
kwargs=kwargs, call_type=CallTypes.acreate_batch
)
assert (
result["input_file_id"] == top_level_provider_file
), "Should prefer top-level model_info over litellm_metadata"
assert result["input_file_id"] == top_level_provider_file, (
"Should prefer top-level model_info over litellm_metadata"
)
@pytest.mark.asyncio
@ -162,9 +162,9 @@ async def test_async_pre_call_deployment_hook_no_model_info_leaves_file_id_uncha
kwargs=kwargs, call_type=CallTypes.acreate_batch
)
assert (
result["input_file_id"] == managed_file_id
), "File ID should remain unchanged when model_info is not available"
assert result["input_file_id"] == managed_file_id, (
"File ID should remain unchanged when model_info is not available"
)
# def test_list_managed_files():
@ -341,9 +341,7 @@ async def test_async_pre_call_hook_for_unified_finetuning_job():
@pytest.mark.asyncio
@pytest.mark.parametrize(
"call_type", ["afile_content", "afile_delete", "afile_retrieve"]
)
@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete", "afile_retrieve"])
async def test_can_user_call_unified_file_id(call_type):
"""
Test that on file retrieve, delete, and content we check if the user has access to the file
@ -624,7 +622,8 @@ async def test_error_file_id_for_failed_batch():
mock_retrieve.return_value = error_file_object
user_api_key_dict = UserAPIKeyAuth(
user_id="test-user-123", parent_otel_span=MagicMock()
user_id="test-user-123",
parent_otel_span=MagicMock()
)
response = await proxy_managed_files.async_post_call_success_hook(
@ -637,9 +636,7 @@ async def test_error_file_id_for_failed_batch():
assert cast(LiteLLMBatch, response).error_file_id is not None
assert not cast(LiteLLMBatch, response).error_file_id.startswith("error-")
# Verify it's a base64 encoded managed file ID
assert _is_base64_encoded_unified_file_id(
cast(LiteLLMBatch, response).error_file_id
)
assert _is_base64_encoded_unified_file_id(cast(LiteLLMBatch, response).error_file_id)
@pytest.mark.asyncio
@ -681,10 +678,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation():
# first retrieve batch
tasks = []
first_create_task = asyncio.create_task
with patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = (
lambda coro: tasks.append(first_create_task(coro)) or tasks[-1]
)
with patch('asyncio.create_task') as mock_create_task:
mock_create_task.side_effect = lambda coro: tasks.append(first_create_task(coro)) or tasks[-1]
response = await proxy_managed_files.async_post_call_success_hook(
data={},
@ -705,10 +700,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation():
# second retrieve batch
tasks = []
second_create_task = asyncio.create_task
with patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = (
lambda coro: tasks.append(second_create_task(coro)) or tasks[-1]
)
with patch('asyncio.create_task') as mock_create_task:
mock_create_task.side_effect = lambda coro: tasks.append(second_create_task(coro)) or tasks[-1]
await proxy_managed_files.async_post_call_success_hook(
data={},
@ -760,10 +753,7 @@ def test_update_responses_input_with_unified_file_id():
assert updated_input[0]["content"][0]["type"] == "input_file"
assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM"
assert updated_input[0]["content"][1]["type"] == "input_text"
assert (
updated_input[0]["content"][1]["text"]
== "What is the first dragon in the book?"
)
assert updated_input[0]["content"][1]["text"] == "What is the first dragon in the book?"
def test_update_responses_input_with_regular_file_id():
@ -965,10 +955,7 @@ def test_update_responses_tools_with_model_file_id_mapping():
# Verify the file IDs were mapped to provider-specific file IDs
assert updated_tools[0]["type"] == "code_interpreter"
assert updated_tools[0]["container"]["file_ids"] == [
"openai_file_abc",
"openai_file_def",
]
assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", "openai_file_def"]
def test_update_responses_tools_without_mapping():
@ -1039,10 +1026,7 @@ def test_update_responses_tools_with_mixed_file_ids():
)
# Verify managed file ID was mapped and regular file ID was kept
assert updated_tools[0]["container"]["file_ids"] == [
"openai_file_abc",
regular_file_id,
]
assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", regular_file_id]
def test_get_file_ids_from_responses_tools():
@ -1371,9 +1355,7 @@ async def test_store_unified_file_id_with_none_file_object():
from litellm.proxy._types import UserAPIKeyAuth
prisma_client = AsyncMock()
prisma_client.db.litellm_managedfiletable.create = AsyncMock(
return_value=MagicMock()
)
prisma_client.db.litellm_managedfiletable.create = AsyncMock(return_value=MagicMock())
internal_usage_cache = MagicMock()
internal_usage_cache.async_set_cache = AsyncMock()
@ -1411,22 +1393,18 @@ async def test_afile_delete_returns_provider_response_when_stored_file_object_no
prisma_client = AsyncMock()
db_record = MagicMock()
db_record.model_mappings = '{"model-123": "file-provider-xyz"}'
prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(
return_value=db_record
)
prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=db_record)
prisma_client.db.litellm_managedfiletable.delete = AsyncMock()
internal_usage_cache = MagicMock()
internal_usage_cache.async_get_cache = AsyncMock(
return_value={
"unified_file_id": unified_file_id,
"model_mappings": {"model-123": "file-provider-xyz"},
"flat_model_file_ids": ["file-provider-xyz"],
"file_object": None,
"created_by": "test-user",
"updated_by": "test-user",
}
)
internal_usage_cache.async_get_cache = AsyncMock(return_value={
"unified_file_id": unified_file_id,
"model_mappings": {"model-123": "file-provider-xyz"},
"flat_model_file_ids": ["file-provider-xyz"],
"file_object": None,
"created_by": "test-user",
"updated_by": "test-user",
})
internal_usage_cache.async_set_cache = AsyncMock()
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
@ -1494,12 +1472,10 @@ async def test_afile_retrieve_fetches_from_provider_when_file_object_none():
)
mock_router = MagicMock()
mock_router.get_deployment_credentials_with_provider = MagicMock(
return_value={
"api_key": "test-key",
"api_base": "https://api.openai.com",
}
)
mock_router.get_deployment_credentials_with_provider = MagicMock(return_value={
"api_key": "test-key",
"api_base": "https://api.openai.com",
})
with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_afile_retrieve:
mock_afile_retrieve.return_value = provider_file_response
@ -1624,33 +1600,29 @@ async def test_list_batches_from_managed_objects_table():
batch_record_1 = MagicMock()
batch_record_1.unified_object_id = "unified-batch-id-1"
batch_record_1.file_object = json.dumps(
{
"id": "batch_abc123",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-input-1",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
}
)
batch_record_1.file_object = json.dumps({
"id": "batch_abc123",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-input-1",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
})
batch_record_2 = MagicMock()
batch_record_2.unified_object_id = "unified-batch-id-2"
batch_record_2.file_object = json.dumps(
{
"id": "batch_xyz789",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "in_progress",
"created_at": 1234567891,
"input_file_id": "file-input-2",
"request_counts": {"total": 5, "completed": 2, "failed": 0},
}
)
batch_record_2.file_object = json.dumps({
"id": "batch_xyz789",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "in_progress",
"created_at": 1234567891,
"input_file_id": "file-input-2",
"request_counts": {"total": 5, "completed": 2, "failed": 0},
})
prisma_client.db.litellm_managedobjecttable.find_many.return_value = [
batch_record_1,
@ -1713,7 +1685,6 @@ async def test_list_batches_from_managed_objects_table_empty_list():
def _create_unified_batch_id(model_id: str, batch_id: str) -> str:
import base64
unified_str = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}"
return base64.urlsafe_b64encode(unified_str.encode()).decode().rstrip("=")
@ -1769,7 +1740,6 @@ async def test_list_batches_from_managed_objects_table_target_model_name_filter_
# Verify find_many was NOT called since exception is raised before database query
prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called()
@pytest.mark.asyncio
async def test_list_batches_from_managed_objects_table_filters_by_created_by():
from litellm.proxy._types import UserAPIKeyAuth
@ -1779,34 +1749,30 @@ async def test_list_batches_from_managed_objects_table_filters_by_created_by():
# Create batch for user1
batch_user1 = MagicMock()
batch_user1.unified_object_id = "unified-batch-user1"
batch_user1.file_object = json.dumps(
{
"id": "batch_user1_abc",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-input-user1",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
}
)
batch_user1.file_object = json.dumps({
"id": "batch_user1_abc",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-input-user1",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
})
# Create batch for user2
batch_user2 = MagicMock()
batch_user2.unified_object_id = "unified-batch-user2"
batch_user2.file_object = json.dumps(
{
"id": "batch_user2_xyz",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567891,
"input_file_id": "file-input-user2",
"request_counts": {"total": 2, "completed": 2, "failed": 0},
}
)
batch_user2.file_object = json.dumps({
"id": "batch_user2_xyz",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567891,
"input_file_id": "file-input-user2",
"request_counts": {"total": 2, "completed": 2, "failed": 0},
})
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
DualCache(), prisma_client=prisma_client
@ -1913,9 +1879,7 @@ async def test_user_b_cannot_retrieve_user_a_batch():
)
# User B tries to retrieve User A's batch
unified_batch_id = (
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
)
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
with pytest.raises(HTTPException) as exc_info:
await proxy_managed_files.async_pre_call_hook(
@ -1950,9 +1914,7 @@ async def test_user_b_cannot_cancel_user_a_batch():
)
# User B tries to cancel User A's batch
unified_batch_id = (
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
)
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
with pytest.raises(HTTPException) as exc_info:
await proxy_managed_files.async_pre_call_hook(
@ -1990,9 +1952,7 @@ async def test_user_a_can_retrieve_own_batch():
)
# User A retrieves their own batch
unified_batch_id = (
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
)
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
# Should not raise an exception
result = await proxy_managed_files.async_pre_call_hook(
@ -2028,9 +1988,7 @@ async def test_user_b_cannot_retrieve_user_a_file():
)
# User B tries to retrieve User A's file
unified_file_id = (
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
)
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
with pytest.raises(HTTPException) as exc_info:
await proxy_managed_files.async_pre_call_hook(
@ -2065,9 +2023,7 @@ async def test_user_b_cannot_download_user_a_file_content():
)
# User B tries to download User A's file content
unified_file_id = (
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
)
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
with pytest.raises(HTTPException) as exc_info:
await proxy_managed_files.async_pre_call_hook(
@ -2102,9 +2058,7 @@ async def test_user_b_cannot_delete_user_a_file():
)
# User B tries to delete User A's file
unified_file_id = (
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
)
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
with pytest.raises(HTTPException) as exc_info:
await proxy_managed_files.async_pre_call_hook(
@ -2135,16 +2089,14 @@ async def test_user_a_can_retrieve_own_file():
file_record = MagicMock()
file_record.created_by = "user_a_id"
file_record.model_mappings = '{"model-123": "file-abc123"}'
file_record.file_object = json.dumps(
{
"id": "file-abc123",
"object": "file",
"bytes": 1234,
"created_at": 1234567890,
"filename": "test.jsonl",
"purpose": "batch",
}
)
file_record.file_object = json.dumps({
"id": "file-abc123",
"object": "file",
"bytes": 1234,
"created_at": 1234567890,
"filename": "test.jsonl",
"purpose": "batch",
})
prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
@ -2152,9 +2104,7 @@ async def test_user_a_can_retrieve_own_file():
)
# User A retrieves their own file
unified_file_id = (
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
)
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
# Should not raise an exception
result = await proxy_managed_files.async_pre_call_hook(
@ -2184,18 +2134,16 @@ async def test_list_batches_only_returns_user_own_batches():
# Create batches for User A
batch_user_a = MagicMock()
batch_user_a.unified_object_id = "batch-user-a"
batch_user_a.file_object = json.dumps(
{
"id": "batch_a",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-a",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
}
)
batch_user_a.file_object = json.dumps({
"id": "batch_a",
"object": "batch",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"status": "completed",
"created_at": 1234567890,
"input_file_id": "file-a",
"request_counts": {"total": 1, "completed": 1, "failed": 0},
})
# Mock database to only return User A's batches
prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user_a]
@ -2243,14 +2191,14 @@ async def test_same_user_different_keys_can_access_batch():
DualCache(), prisma_client=prisma_client
)
unified_batch_id = (
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
)
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
# First API key for User A retrieves the batch
result1 = await proxy_managed_files.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(
user_id="user_a_id", api_key="key-1", parent_otel_span=MagicMock()
user_id="user_a_id",
api_key="key-1",
parent_otel_span=MagicMock()
),
cache=MagicMock(),
data={"batch_id": unified_batch_id},
@ -2262,7 +2210,9 @@ async def test_same_user_different_keys_can_access_batch():
# Second API key for the same User A retrieves the batch
result2 = await proxy_managed_files.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(
user_id="user_a_id", api_key="key-2", parent_otel_span=MagicMock()
user_id="user_a_id",
api_key="key-2",
parent_otel_span=MagicMock()
),
cache=MagicMock(),
data={"batch_id": unified_batch_id},

View File

@ -31,16 +31,12 @@ class TestAvailableEnterpriseUsers:
self, client, mock_user_api_key_auth
):
"""Test when max_users is set and user count is within limit"""
with (
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
patch(
"litellm.proxy.proxy_server.premium_user",
True,
),
patch(
"litellm.proxy.proxy_server.premium_user_data",
{"max_users": 10},
),
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.proxy_server.premium_user",
True,
), patch(
"litellm.proxy.proxy_server.premium_user_data",
{"max_users": 10},
):
# Mock database count
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=5)
@ -70,16 +66,12 @@ class TestAvailableEnterpriseUsers:
self, client, mock_user_api_key_auth
):
"""Test when max_users is not set (premium_user_data is None or doesn't contain max_users)"""
with (
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
patch(
"litellm.proxy.proxy_server.premium_user",
True,
),
patch(
"litellm.proxy.proxy_server.premium_user_data",
None,
),
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.proxy_server.premium_user",
True,
), patch(
"litellm.proxy.proxy_server.premium_user_data",
None,
):
# Mock database count
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=3)
@ -107,16 +99,12 @@ class TestAvailableEnterpriseUsers:
self, client, mock_user_api_key_auth
):
"""Test the current bug where total_users_remaining can be negative"""
with (
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
patch(
"litellm.proxy.proxy_server.premium_user",
True,
),
patch(
"litellm.proxy.proxy_server.premium_user_data",
{"key": "value"},
),
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.proxy_server.premium_user",
True,
), patch(
"litellm.proxy.proxy_server.premium_user_data",
{"key": "value"},
):
# Mock database count higher than max_users to trigger the bug
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=8)
@ -152,15 +140,12 @@ class TestAvailableEnterpriseUsers:
"""Test when prisma_client is None (no database connection)"""
from litellm.proxy._types import CommonProxyErrors
with (
patch(
"litellm.proxy.proxy_server.prisma_client",
None,
),
patch(
"litellm.proxy.proxy_server.premium_user",
True,
),
with patch(
"litellm.proxy.proxy_server.prisma_client",
None,
), patch(
"litellm.proxy.proxy_server.premium_user",
True,
):
# Override the dependency
client.app.dependency_overrides[mock_user_api_key_auth] = lambda: {

View File

@ -801,9 +801,7 @@ async def test_list_projects_returns_timestamps():
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from litellm_enterprise.proxy.management_endpoints.project_endpoints import (
list_projects,
)
from litellm_enterprise.proxy.management_endpoints.project_endpoints import list_projects
from litellm.proxy._types import LiteLLM_ProjectTable
now = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)

View File

@ -14,6 +14,7 @@ from litellm.proxy.guardrails.guardrail_registry import (
)
from litellm.proxy.guardrails.guardrail_hooks.akto.akto import AktoGuardrail
# ---------------------------------------------------------------------------
# Registry tests
# ---------------------------------------------------------------------------

View File

@ -6,6 +6,7 @@ import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio

View File

@ -21,6 +21,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
ContentFilterCategoryConfig,
)
# Test cases: (sentence, expected_result, reason)
TEST_CASES = [
# ALWAYS BLOCK - Explicit prohibited practices (1-10)

View File

@ -23,6 +23,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
ContentFilterCategoryConfig,
)
# ── helpers ──────────────────────────────────────────────────────────────
POLICY_DIR = os.path.abspath(

View File

@ -28,6 +28,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
ContentFilterCategoryConfig,
)
# ── helpers ──────────────────────────────────────────────────────────────
POLICY_DIR = os.path.abspath(

View File

@ -7,6 +7,7 @@ import sys
import traceback
from unittest.mock import AsyncMock, MagicMock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

View File

@ -6,6 +6,7 @@ import os
import sys
import traceback
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

View File

@ -11,6 +11,7 @@ import pytest
from litellm.types.utils import Embedding
from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage
_mock_model_id = (
"arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123"
)

View File

@ -11,6 +11,7 @@ from litellm.llms.bedrock.common_utils import (
)
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
NOVA_ARN = "arn:aws:bedrock:us-east-1:123456789012:custom-model-deployment/a1b2c3d4e5f6"
NOVA_MODEL = f"bedrock/nova/{NOVA_ARN}"
NOVA2_MODEL = f"bedrock/nova-2/{NOVA_ARN}"

View File

@ -581,21 +581,17 @@ def test_vertex_ai_text_only_embedding_uses_embed_content():
def test_filter_embed_params_drops_unsupported():
"""Unsupported params like max_tokens should be filtered out."""
result = _filter_embed_params(
{"dimensions": 768, "max_tokens": 256, "temperature": 0.5}
)
result = _filter_embed_params({"dimensions": 768, "max_tokens": 256, "temperature": 0.5})
assert result == {"outputDimensionality": 768}
def test_filter_embed_params_keeps_supported():
"""All supported Gemini embedding params should pass through."""
result = _filter_embed_params(
{
"dimensions": 768,
"task_type": "RETRIEVAL_DOCUMENT",
"title": "My doc",
}
)
result = _filter_embed_params({
"dimensions": 768,
"task_type": "RETRIEVAL_DOCUMENT",
"title": "My doc",
})
assert result == {
"outputDimensionality": 768,
"taskType": "RETRIEVAL_DOCUMENT",

View File

@ -11,6 +11,7 @@ import pytest
from litellm import get_model_info
MODEL_NAME = "nvidia.nemotron-super-3-120b"

View File

@ -12,6 +12,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
BedrockConverseMessagesProcessor,
)
MODEL = "anthropic.claude-v2"
PROVIDER = "bedrock_converse"
@ -545,9 +546,9 @@ def test_bedrock_converse_sorts_text_before_tooluse_sync():
tool_indices = [i for i, b in enumerate(content) if "toolUse" in b]
# All text blocks must come before all toolUse blocks
assert max(text_indices) < min(
tool_indices
), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
assert max(text_indices) < min(tool_indices), (
f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
)
@pytest.mark.asyncio
@ -566,9 +567,9 @@ async def test_bedrock_converse_sorts_text_before_tooluse_async():
text_indices = [i for i, b in enumerate(content) if "text" in b]
tool_indices = [i for i, b in enumerate(content) if "toolUse" in b]
assert max(text_indices) < min(
tool_indices
), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
assert max(text_indices) < min(tool_indices), (
f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
)
@pytest.mark.asyncio
@ -576,9 +577,7 @@ async def test_bedrock_converse_content_ordering_sync_async_parity():
"""Sync and async paths should produce identical content block ordering."""
messages = _make_tooluse_before_text_messages()
sync_result = _bedrock_converse_messages_pt(messages, MODEL, PROVIDER)
async_result = (
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
messages, MODEL, PROVIDER
)
async_result = await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
messages, MODEL, PROVIDER
)
assert sync_result == async_result

View File

@ -10,6 +10,7 @@ from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io

View File

@ -14,6 +14,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
# Fake Vertex AI Gemini response for mocking
FAKE_VERTEX_GEMINI_RESPONSE = {
"candidates": [

View File

@ -12,6 +12,7 @@ from litellm.responses.litellm_completion_transformation.transformation import (
)
from litellm.types.utils import ModelResponse
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.integrations.custom_logger import CustomLogger

View File

@ -387,7 +387,9 @@ def test_process_chunk_completed_response_updates_id_and_usage_cost(monkeypatch)
# Chunk must include a top-level "response" key so BaseResponsesAPIStreamingIterator
# runs _update_responses_api_response_id_with_model_id (see streaming_iterator.py).
event = iterator._process_chunk(
json.dumps({"type": "response.completed", "response": {"id": "resp_live"}})
json.dumps(
{"type": "response.completed", "response": {"id": "resp_live"}}
)
)
finally:
litellm.include_cost_in_streaming_usage = original_include_cost

View File

@ -22,8 +22,10 @@ sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm import completion
# Large document for caching tests (needs 1024+ tokens for Claude models)
LARGE_DOCUMENT_FOR_CACHING = """
LARGE_DOCUMENT_FOR_CACHING = (
"""
This is a comprehensive legal agreement between Party A and Party B.
ARTICLE 1: DEFINITIONS
@ -75,7 +77,9 @@ ARTICLE 9: GENERAL PROVISIONS
9.5 Waiver of any provision shall not constitute ongoing waiver.
IN WITNESS WHEREOF, the parties have executed this Agreement.
""" * 8 # Repeat to ensure we have enough tokens (need 1024+ for Claude models)
"""
* 8
) # Repeat to ensure we have enough tokens (need 1024+ for Claude models)
class TestBedrockAnthropicPromptCachingRegression:

View File

@ -3,6 +3,7 @@ import pytest
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

View File

@ -1,7 +1,6 @@
"""
Tests for Crusoe provider integration
"""
import os
from unittest import mock
@ -97,9 +96,9 @@ def test_crusoe_models_configuration():
for model in crusoe_models:
model_info = get_model_info(model)
assert model_info is not None, f"Model info not found for {model}"
assert (
model_info.get("litellm_provider") == "crusoe"
), f"{model} should have crusoe as provider"
assert model_info.get("litellm_provider") == "crusoe", (
f"{model} should have crusoe as provider"
)
assert model_info.get("mode") == "chat", f"{model} should be in chat mode"
finally:
litellm.model_cost = original_model_cost

View File

@ -1362,12 +1362,8 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults():
)
# For Gemini 3, should not force thinkingLevel by default
assert (
"thinkingLevel" not in result
), "Should not force thinkingLevel for Gemini 3"
assert (
"thinkingBudget" not in result
), "Should NOT have thinkingBudget for Gemini 3"
assert "thinkingLevel" not in result, "Should not force thinkingLevel for Gemini 3"
assert "thinkingBudget" not in result, "Should NOT have thinkingBudget for Gemini 3"
assert result["includeThoughts"] is True
# Test 2: Anthropic thinking disabled for Gemini 3
@ -1399,10 +1395,7 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults():
)
assert result_zero["includeThoughts"] is False
assert (
"thinkingLevel" not in result_zero
or result_zero.get("thinkingLevel") is None
)
assert "thinkingLevel" not in result_zero or result_zero.get("thinkingLevel") is None
# Test 4: Gemini 3 flash-preview should also follow provider defaults by default
result_gemini3flashpreview = VertexGeminiConfig._map_thinking_param(
@ -1532,12 +1525,8 @@ def test_anthropic_thinking_param_via_map_openai_params():
# Check that thinkingConfig was created without forced thinkingLevel
assert "thinkingConfig" in result, "Should have thinkingConfig in optional_params"
thinking_config = result["thinkingConfig"]
assert (
"thinkingLevel" not in thinking_config
), "Should not force thinkingLevel for Gemini 3 by default"
assert (
"thinkingBudget" not in thinking_config
), "Should NOT have thinkingBudget for Gemini 3"
assert "thinkingLevel" not in thinking_config, "Should not force thinkingLevel for Gemini 3 by default"
assert "thinkingBudget" not in thinking_config, "Should NOT have thinkingBudget for Gemini 3"
assert thinking_config["includeThoughts"] is True
# Test with Gemini 2 model

View File

@ -2,6 +2,7 @@ import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio

View File

@ -38,6 +38,7 @@ from litellm.llms.vertex_ai.gemini.transformation import (
)
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
@ -1103,7 +1104,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
{
"content": {
"role": "model",
"parts": [{"text": """{
"parts": [
{
"text": """{
"recipes": [
{"recipe_name": "Chocolate Chip Cookies"},
{"recipe_name": "Oatmeal Raisin Cookies"},
@ -1111,7 +1114,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
{"recipe_name": "Sugar Cookies"},
{"recipe_name": "Snickerdoodles"}
]
}"""}],
}"""
}
],
},
"finishReason": "STOP",
"safetyRatings": [

View File

@ -450,9 +450,12 @@ def c():
litellm.enable_caching_on_provider_specific_optional_params = False
embedding_large_text = """
embedding_large_text = (
"""
small text
""" * 5
"""
* 5
)
# # test_caching_with_models()

View File

@ -1638,13 +1638,15 @@ def custom_callback(
#################################################
print(f"""
print(
f"""
Model: {model},
Messages: {messages},
User: {user},
Seed: {kwargs["seed"]},
temperature: {kwargs["temperature"]},
""")
"""
)
assert kwargs["user"] == "ishaans app"
assert kwargs["model"] == "gpt-3.5-turbo-1106"

View File

@ -93,7 +93,8 @@ class testCustomCallbackProxy(CustomLogger):
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print(f"""
print(
f"""
Model: {model},
Messages: {messages},
User: {user},
@ -101,7 +102,8 @@ class testCustomCallbackProxy(CustomLogger):
Cost: {cost},
Response: {response}
Proxy Metadata: {metadata}
""")
"""
)
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):

Some files were not shown because too many files have changed in this diff Show More