Drop dep bumps + black-26 reformat to clear fork CI policy
PR was blocked by .github/workflows/guard-fork-dependencies.yml: fork PRs cannot modify uv.lock.

Reverting:
- uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295 files of mechanical Black 26 reformat coupled to it
- pyproject.toml diskcache extra change (kept the runtime mitigation in litellm/caching/disk_cache.py via JSONDisk)

Kept:
- Dockerfile cache narrowing (drops ~660 MB of uv build cache that surfaced cached setuptools as CVE findings)
- litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872
- ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json: next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard)
- tests/test_litellm/caching/test_disk_cache.py
- tests/code_coverage_tests/liccheck.ini: harmless black authorization

Black + gitpython + langchain dep upgrades will need a follow-up from a maintainer pushing a branch in the canonical BerriAI/litellm repo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
63bda3f001
commit
5bafa8b3a2
@ -2,7 +2,6 @@ import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
# Asynchronously fetch data from a given URL
|
||||
async def fetch_data(url):
|
||||
try:
|
||||
@ -16,24 +15,22 @@ async def fetch_data(url):
|
||||
resp_json = await resp.json()
|
||||
print("Fetch the data from URL.")
|
||||
# Return the 'data' field from the JSON response
|
||||
return resp_json["data"]
|
||||
return resp_json['data']
|
||||
except Exception as e:
|
||||
# Print an error message if fetching data fails
|
||||
print("Error fetching data from URL:", e)
|
||||
return None
|
||||
|
||||
|
||||
# Synchronize local data with remote data
|
||||
def sync_local_data_with_remote(local_data, remote_data):
|
||||
# Update existing keys in local_data with values from remote_data
|
||||
for key in set(local_data) & set(remote_data):
|
||||
for key in (set(local_data) & set(remote_data)):
|
||||
local_data[key].update(remote_data[key])
|
||||
|
||||
# Add new keys from remote_data to local_data
|
||||
for key in set(remote_data) - set(local_data):
|
||||
for key in (set(remote_data) - set(local_data)):
|
||||
local_data[key] = remote_data[key]
|
||||
|
||||
|
||||
# Write data to the json file
|
||||
def write_to_file(file_path, data):
|
||||
try:
|
||||
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
|
||||
# Print an error message if writing to file fails
|
||||
print("Error updating JSON file:", e)
|
||||
|
||||
|
||||
# Update the existing models and add the missing models for OpenRouter
|
||||
def transform_openrouter_data(data):
|
||||
transformed = {}
|
||||
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
|
||||
}
|
||||
|
||||
# Add 'max_output_tokens' as a field if it is not None
|
||||
if (
|
||||
"top_provider" in row
|
||||
and "max_completion_tokens" in row["top_provider"]
|
||||
and row["top_provider"]["max_completion_tokens"] is not None
|
||||
):
|
||||
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
|
||||
if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
|
||||
obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
|
||||
|
||||
# Add the field 'output_cost_per_token'
|
||||
obj.update(
|
||||
{
|
||||
"output_cost_per_token": float(row["pricing"]["completion"]),
|
||||
}
|
||||
)
|
||||
obj.update({
|
||||
"output_cost_per_token": float(row["pricing"]["completion"]),
|
||||
})
|
||||
|
||||
# Add field 'input_cost_per_image' if it exists and is non-zero
|
||||
if (
|
||||
"pricing" in row
|
||||
and "image" in row["pricing"]
|
||||
and float(row["pricing"]["image"]) != 0.0
|
||||
):
|
||||
obj["input_cost_per_image"] = float(row["pricing"]["image"])
|
||||
if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
|
||||
obj['input_cost_per_image'] = float(row["pricing"]["image"])
|
||||
|
||||
# Add the fields 'litellm_provider' and 'mode'
|
||||
obj.update({"litellm_provider": "openrouter", "mode": "chat"})
|
||||
obj.update({
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
})
|
||||
|
||||
# Add the 'supports_vision' field if the modality is 'multimodal'
|
||||
if row.get("architecture", {}).get("modality") == "multimodal":
|
||||
obj["supports_vision"] = True
|
||||
if row.get('architecture', {}).get('modality') == 'multimodal':
|
||||
obj['supports_vision'] = True
|
||||
|
||||
# Use a composite key to store the transformed object
|
||||
transformed[f'openrouter/{row["id"]}'] = obj
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
# Update the existing models and add the missing models for Vercel AI Gateway
|
||||
def transform_vercel_ai_gateway_data(data):
|
||||
transformed = {}
|
||||
@ -101,27 +89,17 @@ def transform_vercel_ai_gateway_data(data):
|
||||
"max_tokens": row["context_window"],
|
||||
"input_cost_per_token": float(row["pricing"]["input"]),
|
||||
"output_cost_per_token": float(row["pricing"]["output"]),
|
||||
"max_output_tokens": row["max_tokens"],
|
||||
"max_input_tokens": row["context_window"],
|
||||
'max_output_tokens': row['max_tokens'],
|
||||
'max_input_tokens': row["context_window"],
|
||||
}
|
||||
|
||||
# Handle cache pricing if available
|
||||
if "pricing" in row:
|
||||
if (
|
||||
"input_cache_read" in row["pricing"]
|
||||
and row["pricing"]["input_cache_read"] is not None
|
||||
):
|
||||
obj["cache_read_input_token_cost"] = float(
|
||||
f"{float(row['pricing']['input_cache_read']):e}"
|
||||
)
|
||||
if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
|
||||
obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
|
||||
|
||||
if (
|
||||
"input_cache_write" in row["pricing"]
|
||||
and row["pricing"]["input_cache_write"] is not None
|
||||
):
|
||||
obj["cache_creation_input_token_cost"] = float(
|
||||
f"{float(row['pricing']['input_cache_write']):e}"
|
||||
)
|
||||
if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
|
||||
obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
|
||||
|
||||
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
|
||||
|
||||
@ -148,17 +126,10 @@ def load_local_data(file_path):
|
||||
print("Error decoding JSON:", e)
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
local_file_path = (
|
||||
"model_prices_and_context_window.json" # Path to the local data file
|
||||
)
|
||||
openrouter_url = (
|
||||
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
||||
)
|
||||
vercel_ai_gateway_url = (
|
||||
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
||||
)
|
||||
local_file_path = "model_prices_and_context_window.json" # Path to the local data file
|
||||
openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
||||
vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
||||
|
||||
# Load local data from file
|
||||
local_data = load_local_data(local_file_path)
|
||||
@ -183,7 +154,6 @@ def main():
|
||||
else:
|
||||
print("Failed to fetch model data from either local file or URL.")
|
||||
|
||||
|
||||
# Entry point of the script
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
402
.github/workflows/run_llm_translation_tests.py
vendored
402
.github/workflows/run_llm_translation_tests.py
vendored
@ -16,55 +16,52 @@ from pathlib import Path
|
||||
import json
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
class Colors:
|
||||
GREEN = "\033[92m"
|
||||
RED = "\033[91m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
PURPLE = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
PURPLE = '\033[95m'
|
||||
CYAN = '\033[96m'
|
||||
RESET = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
|
||||
def print_colored(message: str, color: str = Colors.RESET):
|
||||
"""Print colored message to terminal"""
|
||||
print(f"{color}{message}{Colors.RESET}")
|
||||
|
||||
|
||||
def get_provider_from_test_file(test_file: str) -> str:
|
||||
"""Map test file names to provider names"""
|
||||
provider_mapping = {
|
||||
"test_anthropic": "Anthropic",
|
||||
"test_azure": "Azure",
|
||||
"test_bedrock": "AWS Bedrock",
|
||||
"test_openai": "OpenAI",
|
||||
"test_vertex": "Google Vertex AI",
|
||||
"test_gemini": "Google Vertex AI",
|
||||
"test_cohere": "Cohere",
|
||||
"test_databricks": "Databricks",
|
||||
"test_groq": "Groq",
|
||||
"test_together": "Together AI",
|
||||
"test_mistral": "Mistral",
|
||||
"test_deepseek": "DeepSeek",
|
||||
"test_replicate": "Replicate",
|
||||
"test_huggingface": "HuggingFace",
|
||||
"test_fireworks": "Fireworks AI",
|
||||
"test_perplexity": "Perplexity",
|
||||
"test_cloudflare": "Cloudflare",
|
||||
"test_voyage": "Voyage AI",
|
||||
"test_xai": "xAI",
|
||||
"test_nvidia": "NVIDIA",
|
||||
"test_watsonx": "IBM watsonx",
|
||||
"test_azure_ai": "Azure AI",
|
||||
"test_snowflake": "Snowflake",
|
||||
"test_infinity": "Infinity",
|
||||
"test_jina": "Jina AI",
|
||||
"test_deepgram": "Deepgram",
|
||||
"test_clarifai": "Clarifai",
|
||||
"test_triton": "Triton",
|
||||
'test_anthropic': 'Anthropic',
|
||||
'test_azure': 'Azure',
|
||||
'test_bedrock': 'AWS Bedrock',
|
||||
'test_openai': 'OpenAI',
|
||||
'test_vertex': 'Google Vertex AI',
|
||||
'test_gemini': 'Google Vertex AI',
|
||||
'test_cohere': 'Cohere',
|
||||
'test_databricks': 'Databricks',
|
||||
'test_groq': 'Groq',
|
||||
'test_together': 'Together AI',
|
||||
'test_mistral': 'Mistral',
|
||||
'test_deepseek': 'DeepSeek',
|
||||
'test_replicate': 'Replicate',
|
||||
'test_huggingface': 'HuggingFace',
|
||||
'test_fireworks': 'Fireworks AI',
|
||||
'test_perplexity': 'Perplexity',
|
||||
'test_cloudflare': 'Cloudflare',
|
||||
'test_voyage': 'Voyage AI',
|
||||
'test_xai': 'xAI',
|
||||
'test_nvidia': 'NVIDIA',
|
||||
'test_watsonx': 'IBM watsonx',
|
||||
'test_azure_ai': 'Azure AI',
|
||||
'test_snowflake': 'Snowflake',
|
||||
'test_infinity': 'Infinity',
|
||||
'test_jina': 'Jina AI',
|
||||
'test_deepgram': 'Deepgram',
|
||||
'test_clarifai': 'Clarifai',
|
||||
'test_triton': 'Triton',
|
||||
}
|
||||
|
||||
for key, provider in provider_mapping.items():
|
||||
@ -72,19 +69,11 @@ def get_provider_from_test_file(test_file: str) -> str:
|
||||
return provider
|
||||
|
||||
# For cross-provider test files
|
||||
if any(
|
||||
name in test_file
|
||||
for name in [
|
||||
"test_optional_params",
|
||||
"test_prompt_factory",
|
||||
"test_router",
|
||||
"test_text_completion",
|
||||
]
|
||||
):
|
||||
return f"Cross-Provider Tests ({test_file})"
|
||||
|
||||
return "Other Tests"
|
||||
if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory',
|
||||
'test_router', 'test_text_completion']):
|
||||
return f'Cross-Provider Tests ({test_file})'
|
||||
|
||||
return 'Other Tests'
|
||||
|
||||
def format_duration(seconds: float) -> str:
|
||||
"""Format duration in human-readable format"""
|
||||
@ -100,17 +89,15 @@ def format_duration(seconds: float) -> str:
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
def generate_markdown_report(
|
||||
junit_xml_path: str, output_path: str, tag: str = None, commit: str = None
|
||||
):
|
||||
def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None):
|
||||
"""Generate a beautiful markdown report from JUnit XML"""
|
||||
try:
|
||||
tree = ET.parse(junit_xml_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# Handle both testsuite and testsuites root
|
||||
if root.tag == "testsuites":
|
||||
suites = root.findall("testsuite")
|
||||
if root.tag == 'testsuites':
|
||||
suites = root.findall('testsuite')
|
||||
else:
|
||||
suites = [root]
|
||||
|
||||
@ -122,87 +109,75 @@ def generate_markdown_report(
|
||||
total_time = 0.0
|
||||
|
||||
# Provider breakdown
|
||||
provider_stats = defaultdict(
|
||||
lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
|
||||
)
|
||||
provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0})
|
||||
provider_tests = defaultdict(list)
|
||||
|
||||
for suite in suites:
|
||||
total_tests += int(suite.get("tests", 0))
|
||||
total_failures += int(suite.get("failures", 0))
|
||||
total_errors += int(suite.get("errors", 0))
|
||||
total_skipped += int(suite.get("skipped", 0))
|
||||
total_time += float(suite.get("time", 0))
|
||||
total_tests += int(suite.get('tests', 0))
|
||||
total_failures += int(suite.get('failures', 0))
|
||||
total_errors += int(suite.get('errors', 0))
|
||||
total_skipped += int(suite.get('skipped', 0))
|
||||
total_time += float(suite.get('time', 0))
|
||||
|
||||
for testcase in suite.findall("testcase"):
|
||||
classname = testcase.get("classname", "")
|
||||
test_name = testcase.get("name", "")
|
||||
test_time = float(testcase.get("time", 0))
|
||||
for testcase in suite.findall('testcase'):
|
||||
classname = testcase.get('classname', '')
|
||||
test_name = testcase.get('name', '')
|
||||
test_time = float(testcase.get('time', 0))
|
||||
|
||||
# Extract test file name from classname
|
||||
if "." in classname:
|
||||
parts = classname.split(".")
|
||||
test_file = parts[-2] if len(parts) > 1 else "unknown"
|
||||
if '.' in classname:
|
||||
parts = classname.split('.')
|
||||
test_file = parts[-2] if len(parts) > 1 else 'unknown'
|
||||
else:
|
||||
test_file = "unknown"
|
||||
test_file = 'unknown'
|
||||
|
||||
provider = get_provider_from_test_file(test_file)
|
||||
provider_stats[provider]["time"] += test_time
|
||||
provider_stats[provider]['time'] += test_time
|
||||
|
||||
# Check test status
|
||||
if testcase.find("failure") is not None:
|
||||
provider_stats[provider]["failed"] += 1
|
||||
failure = testcase.find("failure")
|
||||
failure_msg = (
|
||||
failure.get("message", "") if failure is not None else ""
|
||||
)
|
||||
provider_tests[provider].append(
|
||||
{
|
||||
"name": test_name,
|
||||
"status": "FAILED",
|
||||
"time": test_time,
|
||||
"message": failure_msg,
|
||||
}
|
||||
)
|
||||
elif testcase.find("error") is not None:
|
||||
provider_stats[provider]["errors"] += 1
|
||||
error = testcase.find("error")
|
||||
error_msg = error.get("message", "") if error is not None else ""
|
||||
provider_tests[provider].append(
|
||||
{
|
||||
"name": test_name,
|
||||
"status": "ERROR",
|
||||
"time": test_time,
|
||||
"message": error_msg,
|
||||
}
|
||||
)
|
||||
elif testcase.find("skipped") is not None:
|
||||
provider_stats[provider]["skipped"] += 1
|
||||
skip = testcase.find("skipped")
|
||||
skip_msg = skip.get("message", "") if skip is not None else ""
|
||||
provider_tests[provider].append(
|
||||
{
|
||||
"name": test_name,
|
||||
"status": "SKIPPED",
|
||||
"time": test_time,
|
||||
"message": skip_msg,
|
||||
}
|
||||
)
|
||||
if testcase.find('failure') is not None:
|
||||
provider_stats[provider]['failed'] += 1
|
||||
failure = testcase.find('failure')
|
||||
failure_msg = failure.get('message', '') if failure is not None else ''
|
||||
provider_tests[provider].append({
|
||||
'name': test_name,
|
||||
'status': 'FAILED',
|
||||
'time': test_time,
|
||||
'message': failure_msg
|
||||
})
|
||||
elif testcase.find('error') is not None:
|
||||
provider_stats[provider]['errors'] += 1
|
||||
error = testcase.find('error')
|
||||
error_msg = error.get('message', '') if error is not None else ''
|
||||
provider_tests[provider].append({
|
||||
'name': test_name,
|
||||
'status': 'ERROR',
|
||||
'time': test_time,
|
||||
'message': error_msg
|
||||
})
|
||||
elif testcase.find('skipped') is not None:
|
||||
provider_stats[provider]['skipped'] += 1
|
||||
skip = testcase.find('skipped')
|
||||
skip_msg = skip.get('message', '') if skip is not None else ''
|
||||
provider_tests[provider].append({
|
||||
'name': test_name,
|
||||
'status': 'SKIPPED',
|
||||
'time': test_time,
|
||||
'message': skip_msg
|
||||
})
|
||||
else:
|
||||
provider_stats[provider]["passed"] += 1
|
||||
provider_tests[provider].append(
|
||||
{
|
||||
"name": test_name,
|
||||
"status": "PASSED",
|
||||
"time": test_time,
|
||||
"message": "",
|
||||
}
|
||||
)
|
||||
provider_stats[provider]['passed'] += 1
|
||||
provider_tests[provider].append({
|
||||
'name': test_name,
|
||||
'status': 'PASSED',
|
||||
'time': test_time,
|
||||
'message': ''
|
||||
})
|
||||
|
||||
passed = total_tests - total_failures - total_errors - total_skipped
|
||||
|
||||
# Generate the markdown report
|
||||
with open(output_path, "w") as f:
|
||||
with open(output_path, 'w') as f:
|
||||
# Header
|
||||
f.write("# LLM Translation Test Results\n\n")
|
||||
|
||||
@ -211,9 +186,7 @@ def generate_markdown_report(
|
||||
f.write("| Field | Value |\n")
|
||||
f.write("|-------|-------|\n")
|
||||
f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
|
||||
f.write(
|
||||
f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n"
|
||||
)
|
||||
f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n")
|
||||
f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
|
||||
f.write(f"| **Duration** | {format_duration(total_time)} |\n")
|
||||
f.write("\n")
|
||||
@ -224,34 +197,23 @@ def generate_markdown_report(
|
||||
# Summary box
|
||||
f.write("```\n")
|
||||
f.write(f"Total Tests: {total_tests}\n")
|
||||
f.write(
|
||||
f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
||||
)
|
||||
f.write(
|
||||
f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
||||
)
|
||||
f.write(
|
||||
f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
||||
)
|
||||
f.write(
|
||||
f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n"
|
||||
)
|
||||
f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||
f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||
f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||
f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n")
|
||||
f.write("```\n\n")
|
||||
|
||||
|
||||
# Provider summary table
|
||||
f.write("## Results by Provider\n\n")
|
||||
f.write(
|
||||
"| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n"
|
||||
)
|
||||
f.write(
|
||||
"|----------|-------|------|------|-------|------|-----------|----------|"
|
||||
)
|
||||
f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
|
||||
f.write("|----------|-------|------|------|-------|------|-----------|----------|")
|
||||
|
||||
# Sort providers: specific providers first, then cross-provider tests
|
||||
sorted_providers = []
|
||||
cross_provider = []
|
||||
for p in sorted(provider_stats.keys()):
|
||||
if "Cross-Provider" in p or p == "Other Tests":
|
||||
if 'Cross-Provider' in p or p == 'Other Tests':
|
||||
cross_provider.append(p)
|
||||
else:
|
||||
sorted_providers.append(p)
|
||||
@ -260,17 +222,10 @@ def generate_markdown_report(
|
||||
|
||||
for provider in all_providers:
|
||||
stats = provider_stats[provider]
|
||||
total = (
|
||||
stats["passed"]
|
||||
+ stats["failed"]
|
||||
+ stats["errors"]
|
||||
+ stats["skipped"]
|
||||
)
|
||||
pass_rate = (stats["passed"] / total * 100) if total > 0 else 0
|
||||
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||
pass_rate = (stats['passed'] / total * 100) if total > 0 else 0
|
||||
|
||||
f.write(
|
||||
f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | "
|
||||
)
|
||||
f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
|
||||
f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
|
||||
f.write(f"{format_duration(stats['time'])} |")
|
||||
|
||||
@ -280,58 +235,47 @@ def generate_markdown_report(
|
||||
for provider in sorted_providers:
|
||||
if provider_tests[provider]:
|
||||
stats = provider_stats[provider]
|
||||
total = (
|
||||
stats["passed"]
|
||||
+ stats["failed"]
|
||||
+ stats["errors"]
|
||||
+ stats["skipped"]
|
||||
)
|
||||
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||
|
||||
f.write(f"### {provider}\n\n")
|
||||
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
||||
f.write(
|
||||
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) "
|
||||
)
|
||||
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ")
|
||||
f.write(f"in {format_duration(stats['time'])}\n\n")
|
||||
|
||||
# Group tests by status
|
||||
tests_by_status = defaultdict(list)
|
||||
for test in provider_tests[provider]:
|
||||
tests_by_status[test["status"]].append(test)
|
||||
tests_by_status[test['status']].append(test)
|
||||
|
||||
# Show failed tests first (if any)
|
||||
if tests_by_status["FAILED"]:
|
||||
if tests_by_status['FAILED']:
|
||||
f.write("<details>\n<summary>Failed Tests</summary>\n\n")
|
||||
for test in tests_by_status["FAILED"]:
|
||||
for test in tests_by_status['FAILED']:
|
||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||
if test["message"]:
|
||||
if test['message']:
|
||||
# Truncate long error messages
|
||||
msg = (
|
||||
test["message"][:200] + "..."
|
||||
if len(test["message"]) > 200
|
||||
else test["message"]
|
||||
)
|
||||
msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message']
|
||||
f.write(f" > {msg}\n")
|
||||
f.write("\n</details>\n\n")
|
||||
|
||||
# Show errors (if any)
|
||||
if tests_by_status["ERROR"]:
|
||||
if tests_by_status['ERROR']:
|
||||
f.write("<details>\n<summary>Error Tests</summary>\n\n")
|
||||
for test in tests_by_status["ERROR"]:
|
||||
for test in tests_by_status['ERROR']:
|
||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||
f.write("\n</details>\n\n")
|
||||
|
||||
# Show passed tests in collapsible section
|
||||
if tests_by_status["PASSED"]:
|
||||
if tests_by_status['PASSED']:
|
||||
f.write("<details>\n<summary>Passed Tests</summary>\n\n")
|
||||
for test in tests_by_status["PASSED"]:
|
||||
for test in tests_by_status['PASSED']:
|
||||
f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
|
||||
f.write("\n</details>\n\n")
|
||||
|
||||
# Show skipped tests (if any)
|
||||
if tests_by_status["SKIPPED"]:
|
||||
if tests_by_status['SKIPPED']:
|
||||
f.write("<details>\n<summary>Skipped Tests</summary>\n\n")
|
||||
for test in tests_by_status["SKIPPED"]:
|
||||
for test in tests_by_status['SKIPPED']:
|
||||
f.write(f"- `{test['name']}`\n")
|
||||
f.write("\n</details>\n\n")
|
||||
|
||||
@ -341,43 +285,34 @@ def generate_markdown_report(
|
||||
for provider in cross_provider:
|
||||
if provider_tests[provider]:
|
||||
stats = provider_stats[provider]
|
||||
total = (
|
||||
stats["passed"]
|
||||
+ stats["failed"]
|
||||
+ stats["errors"]
|
||||
+ stats["skipped"]
|
||||
)
|
||||
total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped']
|
||||
|
||||
f.write(f"#### {provider}\n\n")
|
||||
f.write(f"**Summary:** {stats['passed']}/{total} passed ")
|
||||
f.write(
|
||||
f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n"
|
||||
)
|
||||
f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n")
|
||||
|
||||
# For cross-provider tests, just show counts
|
||||
f.write(f"- Passed: {stats['passed']}\n")
|
||||
if stats["failed"] > 0:
|
||||
if stats['failed'] > 0:
|
||||
f.write(f"- Failed: {stats['failed']}\n")
|
||||
if stats["errors"] > 0:
|
||||
if stats['errors'] > 0:
|
||||
f.write(f"- Errors: {stats['errors']}\n")
|
||||
if stats["skipped"] > 0:
|
||||
if stats['skipped'] > 0:
|
||||
f.write(f"- Skipped: {stats['skipped']}\n")
|
||||
f.write("\n")
|
||||
|
||||
|
||||
print_colored(f"Report generated: {output_path}", Colors.GREEN)
|
||||
|
||||
except Exception as e:
|
||||
print_colored(f"Error generating report: {e}", Colors.RED)
|
||||
raise
|
||||
|
||||
|
||||
def run_tests(
|
||||
test_path: str = "tests/llm_translation/",
|
||||
junit_xml: str = "test-results/junit.xml",
|
||||
report_path: str = "test-results/llm_translation_report.md",
|
||||
tag: str = None,
|
||||
commit: str = None,
|
||||
) -> int:
|
||||
def run_tests(test_path: str = "tests/llm_translation/",
|
||||
junit_xml: str = "test-results/junit.xml",
|
||||
report_path: str = "test-results/llm_translation_report.md",
|
||||
tag: str = None,
|
||||
commit: str = None) -> int:
|
||||
"""Run the LLM translation tests and generate report"""
|
||||
|
||||
# Create test results directory
|
||||
@ -390,32 +325,21 @@ def run_tests(
|
||||
|
||||
# Run pytest
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--no-sync",
|
||||
"pytest",
|
||||
test_path,
|
||||
"uv", "run", "--no-sync", "pytest", test_path,
|
||||
f"--junitxml={junit_xml}",
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"--maxfail=500",
|
||||
"-n",
|
||||
"auto",
|
||||
"-n", "auto"
|
||||
]
|
||||
|
||||
# Add timeout if pytest-timeout is installed
|
||||
try:
|
||||
subprocess.run(
|
||||
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
||||
capture_output=True, check=True)
|
||||
cmd.extend(["--timeout=300"])
|
||||
except:
|
||||
print_colored(
|
||||
"Warning: pytest-timeout not installed, skipping timeout option",
|
||||
Colors.YELLOW,
|
||||
)
|
||||
print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
|
||||
|
||||
print_colored("Running pytest with command:", Colors.YELLOW)
|
||||
print(f" {' '.join(cmd)}")
|
||||
@ -438,15 +362,15 @@ def run_tests(
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
if root.tag == "testsuites":
|
||||
suites = root.findall("testsuite")
|
||||
if root.tag == 'testsuites':
|
||||
suites = root.findall('testsuite')
|
||||
else:
|
||||
suites = [root]
|
||||
|
||||
total = sum(int(s.get("tests", 0)) for s in suites)
|
||||
failures = sum(int(s.get("failures", 0)) for s in suites)
|
||||
errors = sum(int(s.get("errors", 0)) for s in suites)
|
||||
skipped = sum(int(s.get("skipped", 0)) for s in suites)
|
||||
total = sum(int(s.get('tests', 0)) for s in suites)
|
||||
failures = sum(int(s.get('failures', 0)) for s in suites)
|
||||
errors = sum(int(s.get('errors', 0)) for s in suites)
|
||||
skipped = sum(int(s.get('skipped', 0)) for s in suites)
|
||||
passed = total - failures - errors - skipped
|
||||
|
||||
print(f" Total: {total}")
|
||||
@ -460,11 +384,7 @@ def run_tests(
|
||||
|
||||
if total > 0:
|
||||
pass_rate = (passed / total) * 100
|
||||
color = (
|
||||
Colors.GREEN
|
||||
if pass_rate >= 80
|
||||
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
||||
)
|
||||
color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
||||
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
|
||||
else:
|
||||
print_colored("No test results found!", Colors.RED)
|
||||
@ -474,24 +394,16 @@ def run_tests(
|
||||
|
||||
return result.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
|
||||
parser.add_argument(
|
||||
"--test-path", default="tests/llm_translation/", help="Path to test directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--junit-xml",
|
||||
default="test-results/junit.xml",
|
||||
help="Path for JUnit XML output",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report",
|
||||
default="test-results/llm_translation_report.md",
|
||||
help="Path for markdown report",
|
||||
)
|
||||
parser.add_argument("--test-path", default="tests/llm_translation/",
|
||||
help="Path to test directory")
|
||||
parser.add_argument("--junit-xml", default="test-results/junit.xml",
|
||||
help="Path for JUnit XML output")
|
||||
parser.add_argument("--report", default="test-results/llm_translation_report.md",
|
||||
help="Path for markdown report")
|
||||
parser.add_argument("--tag", help="Git tag or version")
|
||||
parser.add_argument("--commit", help="Git commit SHA")
|
||||
|
||||
@ -500,9 +412,8 @@ if __name__ == "__main__":
|
||||
# Get git info if not provided
|
||||
if not args.commit:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "HEAD"], capture_output=True, text=True
|
||||
)
|
||||
result = subprocess.run(["git", "rev-parse", "HEAD"],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
args.commit = result.stdout.strip()
|
||||
except:
|
||||
@ -510,11 +421,8 @@ if __name__ == "__main__":
|
||||
|
||||
if not args.tag:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "describe", "--tags", "--abbrev=0"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
args.tag = result.stdout.strip()
|
||||
except:
|
||||
@ -525,7 +433,7 @@ if __name__ == "__main__":
|
||||
junit_xml=args.junit_xml,
|
||||
report_path=args.report,
|
||||
tag=args.tag,
|
||||
commit=args.commit,
|
||||
commit=args.commit
|
||||
)
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
import testing.postgresql
|
||||
|
||||
|
||||
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
|
||||
DEFAULT_BASE_BRANCH = "litellm_internal_staging"
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from tabulate import tabulate
|
||||
from termcolor import colored
|
||||
import os
|
||||
|
||||
|
||||
# Define the list of models to benchmark
|
||||
# select any LLM listed here: https://docs.litellm.ai/docs/providers
|
||||
models = ["gpt-3.5-turbo", "claude-2"]
|
||||
|
||||
@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Braintrust Prompt Wrapper",
|
||||
description="Wrapper server for Braintrust prompts to work with LiteLLM",
|
||||
|
||||
@ -518,7 +518,8 @@ if __name__ == "__main__":
|
||||
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
|
||||
print("=" * 80)
|
||||
print("\nExample curl command:")
|
||||
print(f"""
|
||||
print(
|
||||
f"""
|
||||
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
||||
-H "Authorization: Bearer {bearer_token}" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
||||
}}
|
||||
]
|
||||
}}'
|
||||
""")
|
||||
"""
|
||||
)
|
||||
print("=" * 80)
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
|
||||
print("LiteLLM_VerificationTokenView Exists!") # noqa
|
||||
except Exception:
|
||||
# If an error occurs, the view does not exist, so create it
|
||||
await db.execute_raw("""
|
||||
await db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
print("LiteLLM_VerificationTokenView Created!") # noqa
|
||||
|
||||
|
||||
@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
|
||||
class EnterpriseCallbackControls:
|
||||
@staticmethod
|
||||
def is_callback_disabled_dynamically(
|
||||
callback: litellm.CALLBACK_TYPES,
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||
callback: litellm.CALLBACK_TYPES,
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||
|
||||
Args:
|
||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||
litellm_params: Parameters containing proxy server request info
|
||||
Args:
|
||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||
litellm_params: Parameters containing proxy server request info
|
||||
|
||||
Returns:
|
||||
bool: True if the callback should be disabled, False otherwise
|
||||
"""
|
||||
from litellm.litellm_core_utils.custom_logger_registry import (
|
||||
CustomLoggerRegistry,
|
||||
)
|
||||
|
||||
try:
|
||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
|
||||
litellm_params, standard_callback_dynamic_params
|
||||
Returns:
|
||||
bool: True if the callback should be disabled, False otherwise
|
||||
"""
|
||||
from litellm.litellm_core_utils.custom_logger_registry import (
|
||||
CustomLoggerRegistry,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
|
||||
)
|
||||
if disabled_callbacks is not None:
|
||||
#########################################################
|
||||
# premium user check
|
||||
#########################################################
|
||||
if (
|
||||
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
|
||||
):
|
||||
return False
|
||||
#########################################################
|
||||
if isinstance(callback, str):
|
||||
if callback.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(
|
||||
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
||||
)
|
||||
return True
|
||||
elif isinstance(callback, CustomLogger):
|
||||
# get the string name of the callback
|
||||
callback_str = (
|
||||
CustomLoggerRegistry.get_callback_str_from_class_type(
|
||||
callback.__class__
|
||||
)
|
||||
)
|
||||
if (
|
||||
callback_str is not None
|
||||
and callback_str.lower() in disabled_callbacks
|
||||
):
|
||||
verbose_logger.debug(
|
||||
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
|
||||
return False
|
||||
|
||||
try:
|
||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
|
||||
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
|
||||
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
|
||||
if disabled_callbacks is not None:
|
||||
#########################################################
|
||||
# premium user check
|
||||
#########################################################
|
||||
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
|
||||
return False
|
||||
#########################################################
|
||||
if isinstance(callback, str):
|
||||
if callback.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||
return True
|
||||
elif isinstance(callback, CustomLogger):
|
||||
# get the string name of the callback
|
||||
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
|
||||
if callback_str is not None and callback_str.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error checking disabled callbacks header: {str(e)}"
|
||||
)
|
||||
return False
|
||||
@staticmethod
|
||||
def get_disabled_callbacks(
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
) -> Optional[List[str]]:
|
||||
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
|
||||
"""
|
||||
Get the disabled callbacks from the standard callback dynamic params.
|
||||
"""
|
||||
@ -92,21 +71,15 @@ class EnterpriseCallbackControls:
|
||||
request_headers = get_proxy_server_request_headers(litellm_params)
|
||||
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
||||
if disabled_callbacks is not None:
|
||||
disabled_callbacks = set(
|
||||
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
|
||||
)
|
||||
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
|
||||
return list(disabled_callbacks)
|
||||
|
||||
|
||||
#########################################################
|
||||
# check if disabled via request body
|
||||
#########################################################
|
||||
if (
|
||||
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||
is not None
|
||||
):
|
||||
return standard_callback_dynamic_params.get(
|
||||
"litellm_disabled_callbacks", None
|
||||
)
|
||||
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
|
||||
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||
|
||||
return None
|
||||
|
||||
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
|
||||
|
||||
# Check if admin has disabled this feature
|
||||
if litellm.allow_dynamic_callback_disabling is not True:
|
||||
verbose_logger.debug(
|
||||
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
|
||||
)
|
||||
verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
|
||||
return False
|
||||
|
||||
if premium_user:
|
||||
return True
|
||||
verbose_logger.warning(
|
||||
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
|
||||
return False
|
||||
@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
|
||||
)
|
||||
|
||||
# Calculate percentage and alert threshold
|
||||
percentage = (
|
||||
threshold_pct
|
||||
if threshold_pct is not None
|
||||
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
|
||||
percentage = threshold_pct if threshold_pct is not None else int(
|
||||
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
|
||||
)
|
||||
threshold_fraction = percentage / 100.0
|
||||
alert_threshold_str = (
|
||||
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
|
||||
continue
|
||||
|
||||
_id = user_info.token or user_info.user_id or "default_id"
|
||||
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
||||
_cache_key = (
|
||||
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
||||
)
|
||||
|
||||
result = await _cache.async_get_cache(key=_cache_key)
|
||||
if result is not None:
|
||||
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
|
||||
continue
|
||||
recipient_emails = list(set(emails))
|
||||
|
||||
event_message = (
|
||||
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
||||
)
|
||||
event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
||||
webhook_event = WebhookEvent(
|
||||
event="max_budget_alert",
|
||||
event_message=event_message,
|
||||
|
||||
@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||
|
||||
from .base_email import BaseEmailLogger
|
||||
|
||||
|
||||
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
This is the litellm SMTP email integration
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Enterprise specific logging utils
|
||||
"""
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata
|
||||
|
||||
|
||||
|
||||
@ -153,11 +153,11 @@ async def get_audit_logs(
|
||||
|
||||
# Return paginated response
|
||||
return PaginatedAuditLogResponse(
|
||||
audit_logs=(
|
||||
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs]
|
||||
if audit_logs
|
||||
else []
|
||||
),
|
||||
audit_logs=[
|
||||
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
|
||||
]
|
||||
if audit_logs
|
||||
else [],
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
||||
@ -53,9 +53,7 @@ class CheckBatchCost:
|
||||
"user_api_key_alias": getattr(user_row, "user_alias", None),
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
|
||||
return {}
|
||||
|
||||
async def _cleanup_stale_managed_objects(self) -> None:
|
||||
@ -64,22 +62,11 @@ class CheckBatchCost:
|
||||
in non-terminal states as 'stale_expired'. These will never complete and
|
||||
should not be polled.
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
||||
)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
|
||||
where={
|
||||
"file_purpose": "batch",
|
||||
"status": {
|
||||
"not_in": [
|
||||
"completed",
|
||||
"complete",
|
||||
"failed",
|
||||
"expired",
|
||||
"cancelled",
|
||||
"stale_expired",
|
||||
]
|
||||
},
|
||||
"status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
|
||||
"created_at": {"lt": cutoff},
|
||||
},
|
||||
data={"status": "stale_expired"},
|
||||
@ -133,12 +120,9 @@ class CheckBatchCost:
|
||||
|
||||
try:
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
prom_logger = PrometheusLogger.get_instance()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"CheckBatchCost: could not get Prometheus logger: {e}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
|
||||
prom_logger = None
|
||||
|
||||
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
|
||||
@ -177,11 +161,7 @@ class CheckBatchCost:
|
||||
order={"created_at": "asc"},
|
||||
)
|
||||
except Exception as query_err:
|
||||
if (
|
||||
"batch_processed" not in str(query_err).lower()
|
||||
and "unknown column" not in str(query_err).lower()
|
||||
and "does not exist" not in str(query_err).lower()
|
||||
):
|
||||
if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
|
||||
raise
|
||||
# Permanent schema gap — cache the result so future cycles skip straight to fallback
|
||||
self._has_batch_processed_column = False
|
||||
@ -236,13 +216,14 @@ class CheckBatchCost:
|
||||
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
|
||||
)
|
||||
if prom_logger:
|
||||
prom_logger.record_check_batch_cost_error(
|
||||
"provider_retrieval_error"
|
||||
)
|
||||
prom_logger.record_check_batch_cost_error("provider_retrieval_error")
|
||||
continue
|
||||
|
||||
## RETRIEVE THE BATCH JOB OUTPUT FILE
|
||||
if response.status == "completed" and response.output_file_id is not None:
|
||||
if (
|
||||
response.status == "completed"
|
||||
and response.output_file_id is not None
|
||||
):
|
||||
verbose_proxy_logger.info(
|
||||
f"Batch ID: {batch_id} is complete, tracking cost and usage"
|
||||
)
|
||||
@ -269,25 +250,20 @@ class CheckBatchCost:
|
||||
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
|
||||
if decoded:
|
||||
try:
|
||||
raw_output_file_id = decoded.split("llm_output_file_id,")[
|
||||
1
|
||||
].split(";")[0]
|
||||
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
|
||||
except (IndexError, AttributeError):
|
||||
pass
|
||||
|
||||
credentials = (
|
||||
self.llm_router.get_deployment_credentials_with_provider(model_id)
|
||||
or {}
|
||||
)
|
||||
credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
|
||||
_file_content = await afile_content(
|
||||
file_id=raw_output_file_id,
|
||||
**credentials,
|
||||
)
|
||||
|
||||
# Access content - handle both direct attribute and method call
|
||||
if hasattr(_file_content, "content"):
|
||||
if hasattr(_file_content, 'content'):
|
||||
content_bytes = _file_content.content # type: ignore[union-attr]
|
||||
elif hasattr(_file_content, "read"):
|
||||
elif hasattr(_file_content, 'read'):
|
||||
content_bytes = await _file_content.read() # type: ignore[misc]
|
||||
else:
|
||||
content_bytes = _file_content # type: ignore[assignment]
|
||||
@ -314,9 +290,7 @@ class CheckBatchCost:
|
||||
f"Skipping job {unified_object_id} because it is not a valid deployment info"
|
||||
)
|
||||
if prom_logger:
|
||||
prom_logger.record_check_batch_cost_error(
|
||||
"deployment_not_found"
|
||||
)
|
||||
prom_logger.record_check_batch_cost_error("deployment_not_found")
|
||||
continue
|
||||
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
|
||||
litellm_model_name = deployment_info.litellm_params.model
|
||||
@ -328,11 +302,7 @@ class CheckBatchCost:
|
||||
|
||||
# Pass deployment model_info so custom batch pricing
|
||||
# (input_cost_per_token_batches etc.) is used for cost calc
|
||||
deployment_model_info = (
|
||||
deployment_info.model_info.model_dump()
|
||||
if deployment_info.model_info
|
||||
else {}
|
||||
)
|
||||
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
|
||||
batch_cost, batch_usage, batch_models = (
|
||||
await calculate_batch_cost_and_usage(
|
||||
file_content_dictionary=file_content_as_dict,
|
||||
@ -379,9 +349,7 @@ class CheckBatchCost:
|
||||
|
||||
# Record batch duration (completed_at - created_at)
|
||||
if prom_logger and response.completed_at and response.created_at:
|
||||
duration_seconds = float(
|
||||
response.completed_at - response.created_at
|
||||
)
|
||||
duration_seconds = float(response.completed_at - response.created_at)
|
||||
if duration_seconds >= 0:
|
||||
prom_logger.record_managed_batch_duration(
|
||||
duration_seconds=duration_seconds,
|
||||
@ -390,9 +358,7 @@ class CheckBatchCost:
|
||||
)
|
||||
|
||||
# Track this job for the final metrics summary
|
||||
processed_models.append(
|
||||
(model_name, str(llm_provider) if llm_provider else None)
|
||||
)
|
||||
processed_models.append((model_name, str(llm_provider) if llm_provider else None))
|
||||
|
||||
# mark the job as complete
|
||||
try:
|
||||
|
||||
@ -33,7 +33,9 @@ class CheckResponsesCost:
|
||||
self.prisma_client: PrismaClient = prisma_client
|
||||
self.llm_router: Router = llm_router
|
||||
|
||||
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int:
|
||||
async def _expire_stale_rows(
|
||||
self, cutoff: datetime, batch_size: int
|
||||
) -> int:
|
||||
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
|
||||
|
||||
Isolated so it can be swapped / mocked in tests without touching the
|
||||
@ -72,9 +74,7 @@ class CheckResponsesCost:
|
||||
rows per invocation to avoid overwhelming the DB when there is a large
|
||||
backlog.
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
||||
)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
|
||||
if result > 0:
|
||||
verbose_proxy_logger.warning(
|
||||
@ -122,9 +122,7 @@ class CheckResponsesCost:
|
||||
model_name = stored_response.get("model", None)
|
||||
|
||||
# Decrypt the response ID
|
||||
responses_id_security, _, _ = (
|
||||
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
||||
)
|
||||
responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
||||
|
||||
# Prepare metadata with model information for cost tracking
|
||||
litellm_metadata = {
|
||||
@ -134,9 +132,7 @@ class CheckResponsesCost:
|
||||
# Add model information if available
|
||||
if model_name:
|
||||
litellm_metadata["model"] = model_name
|
||||
litellm_metadata["model_group"] = (
|
||||
model_name # Use same value for model_group
|
||||
)
|
||||
litellm_metadata["model_group"] = model_name # Use same value for model_group
|
||||
|
||||
response = await litellm.aget_responses(
|
||||
response_id=responses_id_security,
|
||||
@ -175,3 +171,4 @@ class CheckResponsesCost:
|
||||
verbose_proxy_logger.info(
|
||||
f"Marked {len(completed_jobs)} response jobs as completed"
|
||||
)
|
||||
|
||||
|
||||
@ -120,7 +120,9 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
VectorStoreCreateResponse from the provider
|
||||
"""
|
||||
# Use the router to create the vector store
|
||||
response = await llm_router.avector_store_create(model=model, **request_data)
|
||||
response = await llm_router.avector_store_create(
|
||||
model=model, **request_data
|
||||
)
|
||||
return response
|
||||
|
||||
# ============================================================================
|
||||
@ -274,7 +276,9 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
"""
|
||||
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
|
||||
is_unified_id = (
|
||||
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False
|
||||
is_base64_encoded_unified_id(vector_store_id)
|
||||
if vector_store_id
|
||||
else False
|
||||
)
|
||||
|
||||
if is_unified_id and vector_store_id:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Enterprise internal user management endpoints
|
||||
"""
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
|
||||
soft_budget_crossed = "Soft Budget Crossed"
|
||||
max_budget_alert = "Max Budget Alert"
|
||||
|
||||
|
||||
class EmailEventSettings(BaseModel):
|
||||
event: EmailEvent
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EmailEventSettingsUpdateRequest(BaseModel):
|
||||
settings: List[EmailEventSettings]
|
||||
|
||||
|
||||
class EmailEventSettingsResponse(BaseModel):
|
||||
settings: List[EmailEventSettings]
|
||||
|
||||
|
||||
class DefaultEmailSettings(BaseModel):
|
||||
"""Default settings for email events"""
|
||||
|
||||
settings: Dict[EmailEvent, bool] = Field(
|
||||
default_factory=lambda: {
|
||||
EmailEvent.virtual_key_created: True, # On by default
|
||||
@ -65,11 +57,9 @@ class DefaultEmailSettings(BaseModel):
|
||||
EmailEvent.max_budget_alert: True, # On by default
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, bool]:
|
||||
"""Convert to dictionary with string keys for storage"""
|
||||
return {event.value: enabled for event, enabled in self.settings.items()}
|
||||
|
||||
@classmethod
|
||||
def get_defaults(cls) -> Dict[str, bool]:
|
||||
"""Get the default settings as a dictionary with string keys"""
|
||||
|
||||
@ -6,6 +6,7 @@ Always uses fastuuid for performance.
|
||||
|
||||
import fastuuid as _uuid # type: ignore
|
||||
|
||||
|
||||
# Expose a module-like alias so callers can use: uuid.uuid4()
|
||||
uuid = _uuid
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Dict, Optional
|
||||
|
||||
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
||||
|
||||
|
||||
# HTTP status code -> Anthropic error type
|
||||
# Source: https://docs.anthropic.com/en/api/errors
|
||||
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
|
||||
# Known Anthropic error types
|
||||
# Source: https://docs.anthropic.com/en/api/errors
|
||||
AnthropicErrorType = Literal[
|
||||
|
||||
@ -97,7 +97,7 @@ def _build_reasoning_item(
|
||||
|
||||
|
||||
def _reasoning_item_to_response_input(
|
||||
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]],
|
||||
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
|
||||
r_input: Dict[str, Any] = {
|
||||
|
||||
@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
_CODE_KEYWORDS = re.compile(
|
||||
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
|
||||
|
||||
|
||||
FileContentProvider = Literal[
|
||||
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
|
||||
]
|
||||
|
||||
@ -18,7 +18,7 @@ else:
|
||||
|
||||
|
||||
def process_slack_alerting_variables(
|
||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]],
|
||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
|
||||
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
||||
"""
|
||||
process alert_to_webhook_url
|
||||
|
||||
@ -9,6 +9,7 @@ import polars as pl
|
||||
|
||||
from .schema import FOCUS_NORMALIZED_SCHEMA
|
||||
|
||||
|
||||
_TAG_KEYS = (
|
||||
"team_id",
|
||||
"team_alias",
|
||||
|
||||
@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_traces_and_spans_from_payload(
|
||||
payload: List[Dict[str, Any]],
|
||||
payload: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Separate traces and spans from payload.
|
||||
|
||||
@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
_metadata["raw_request"] = "redacted by litellm. \
|
||||
_metadata["raw_request"] = (
|
||||
"redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
)
|
||||
else:
|
||||
curl_command = self._get_request_curl_command(
|
||||
api_base=additional_args.get("api_base", ""),
|
||||
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
_metadata["raw_request"] = "Unable to Log \
|
||||
raw request: {}".format(str(e))
|
||||
_metadata["raw_request"] = (
|
||||
"Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
||||
try:
|
||||
self.logger_fn(
|
||||
|
||||
@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
|
||||
prompt_str = """Use this JSON schema:
|
||||
```json
|
||||
{}
|
||||
```""".format(response_schema)
|
||||
```""".format(
|
||||
response_schema
|
||||
)
|
||||
return prompt_str
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE parsing helpers (module-level to keep the class lean)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-built response templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
|
||||
|
||||
from ..common_utils import ElevenLabsException
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
|
||||
|
||||
|
||||
def _usage_video_resolution_from_parameters(
|
||||
parameters: Dict[str, Any],
|
||||
parameters: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
|
||||
res = parameters.get("resolution")
|
||||
|
||||
@ -201,7 +201,7 @@ class BaseOpenAILLM:
|
||||
|
||||
@staticmethod
|
||||
def get_openai_client_initialization_param_fields(
|
||||
client_type: Literal["openai", "azure"],
|
||||
client_type: Literal["openai", "azure"]
|
||||
) -> Tuple[str, ...]:
|
||||
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
|
||||
if client_type == "openai":
|
||||
|
||||
@ -49,6 +49,7 @@ from litellm.types.utils import (
|
||||
)
|
||||
from litellm.llms.openrouter.common_utils import OpenRouterException
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
|
||||
@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
|
||||
|
||||
|
||||
def _parse_service_key_once(
|
||||
service_key: Optional[Union[str, dict]],
|
||||
service_key: Optional[Union[str, dict]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Pre-parse service_key if it's a string to avoid repeated JSON parsing.
|
||||
|
||||
@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
|
||||
|
||||
from ..utils import SnowflakeBaseConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from ..gemini.transformation import (
|
||||
|
||||
|
||||
def get_first_continuous_block_idx(
|
||||
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message)
|
||||
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
|
||||
) -> int:
|
||||
"""
|
||||
Find the array index that ends the first continuous sequence of message blocks.
|
||||
|
||||
@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||
contents.append(ContentType(role="user", parts=tool_call_responses))
|
||||
|
||||
if len(contents) == 0:
|
||||
verbose_logger.warning("""
|
||||
verbose_logger.warning(
|
||||
"""
|
||||
No contents in messages. Contents are required. See
|
||||
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
|
||||
If the original request did not comply to OpenAI API requirements it should have failed by now,
|
||||
but LiteLLM does not check for missing messages.
|
||||
Setting an empty content to prevent an 400 error.
|
||||
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
|
||||
""")
|
||||
"""
|
||||
)
|
||||
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
|
||||
return contents
|
||||
except Exception as e:
|
||||
|
||||
@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
|
||||
########## End of logging ############
|
||||
####### Send the request ###################
|
||||
if _is_async is True:
|
||||
return self.async_audio_speech( # type: ignore
|
||||
return self.async_audio_speech( # type:ignore
|
||||
logging_obj=logging_obj, url=url, headers=headers, request=request
|
||||
)
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
@ -324,7 +324,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_chat_completion_request_schema(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
|
||||
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_responses_api_request_schema(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
|
||||
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_llm_api_request_schema_body(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add LLM API request schema bodies to OpenAPI specification for documentation.
|
||||
|
||||
@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def convert_upload_files_to_file_data(
|
||||
form_data: Dict[str, Any],
|
||||
form_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert FastAPI UploadFile objects to file data tuples for litellm.
|
||||
|
||||
@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
||||
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
|
||||
raise
|
||||
# If an error occurs, the view does not exist, so create it
|
||||
await db.execute_raw("""
|
||||
await db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
|
||||
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ every text fragment.
|
||||
|
||||
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
|
||||
|
||||
|
||||
# Call types whose body carries free-form chat / prompt text that
|
||||
# text-content guardrails (banned keywords, content moderation, secret
|
||||
# detection, …) should inspect. The proxy ingress passes ``route_type``
|
||||
|
||||
@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .akto import AktoGuardrail
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
|
||||
"""
|
||||
|
||||
def _get_team_public_model_name(
|
||||
model_info: Optional[Union[dict, str]],
|
||||
model_info: Optional[Union[dict, str]]
|
||||
) -> Optional[str]:
|
||||
if isinstance(model_info, dict):
|
||||
value = model_info.get("team_public_model_name")
|
||||
|
||||
@ -4347,7 +4347,9 @@ async def list_team(
|
||||
except Exception as e:
|
||||
team_exception = """Invalid team object for team_id: {}. team_object={}.
|
||||
Error: {}
|
||||
""".format(team.team_id, team.model_dump(), str(e))
|
||||
""".format(
|
||||
team.team_id, team.model_dump(), str(e)
|
||||
)
|
||||
verbose_proxy_logger.exception(team_exception)
|
||||
continue
|
||||
# Sort the responses by team_alias
|
||||
|
||||
@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
|
||||
|
||||
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
|
||||
"POST /v0/agents": "cursor:agent:create",
|
||||
"GET /v0/agents": "cursor:agent:list",
|
||||
|
||||
@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
|
||||
_endpoint_str = (
|
||||
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
|
||||
)
|
||||
curl_command = _endpoint_str + """
|
||||
curl_command = (
|
||||
_endpoint_str
|
||||
+ """
|
||||
--header 'Content-Type: application/json' \\
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
|
||||
}'
|
||||
\n
|
||||
"""
|
||||
)
|
||||
print() # noqa
|
||||
print( # noqa
|
||||
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
|
||||
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
|
||||
with open(os.devnull, "w") as devnull:
|
||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
print(f"""
|
||||
print( # noqa
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""") # noqa # noqa
|
||||
"""
|
||||
) # noqa
|
||||
|
||||
@staticmethod
|
||||
def _is_port_in_use(port):
|
||||
|
||||
@ -2688,9 +2688,11 @@ def run_ollama_serve():
|
||||
with open(os.devnull, "w") as devnull:
|
||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _get_process_rss_mb() -> Optional[float]:
|
||||
|
||||
@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
|
||||
async def get_spend_by_tags(
|
||||
prisma_client: PrismaClient, start_date=None, end_date=None
|
||||
):
|
||||
response = await prisma_client.db.query_raw("""
|
||||
response = await prisma_client.db.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||
COUNT(*) AS log_count,
|
||||
SUM(spend) AS total_spend
|
||||
FROM "LiteLLM_SpendLogs"
|
||||
GROUP BY individual_request_tag;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -2712,7 +2712,8 @@ class PrismaClient:
|
||||
required_view = "LiteLLM_VerificationTokenView"
|
||||
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
|
||||
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
|
||||
ret = await self.db.query_raw(f"""
|
||||
ret = await self.db.query_raw(
|
||||
f"""
|
||||
WITH existing_views AS (
|
||||
SELECT viewname
|
||||
FROM pg_views
|
||||
@ -2724,7 +2725,8 @@ class PrismaClient:
|
||||
(SELECT COUNT(*) FROM existing_views) AS view_count,
|
||||
ARRAY_AGG(viewname) AS view_names
|
||||
FROM existing_views
|
||||
""")
|
||||
"""
|
||||
)
|
||||
expected_total_views = len(expected_views)
|
||||
if ret[0]["view_count"] == expected_total_views:
|
||||
verbose_proxy_logger.info("All necessary views exist!")
|
||||
@ -2733,7 +2735,8 @@ class PrismaClient:
|
||||
## check if required view exists ##
|
||||
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
|
||||
await self.health_check() # make sure we can connect to db
|
||||
await self.db.execute_raw("""
|
||||
await self.db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -2743,7 +2746,8 @@ class PrismaClient:
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
"LiteLLM_VerificationTokenView Created in DB!"
|
||||
|
||||
@ -826,7 +826,7 @@ class Router:
|
||||
|
||||
@staticmethod
|
||||
def _normalize_strategy(
|
||||
strategy: Union[RoutingStrategy, str, None],
|
||||
strategy: Union[RoutingStrategy, str, None]
|
||||
) -> Optional[str]:
|
||||
if strategy is None:
|
||||
return None
|
||||
|
||||
@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
|
||||
|
||||
|
||||
def _recent_tool_results(
|
||||
messages: Optional[List[Dict[str, Any]]],
|
||||
messages: Optional[List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract the current turn's tool result payloads from the request messages.
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from litellm.router_strategy.adaptive_router.config import (
|
||||
TOOL_CALL_HISTORY_MAX,
|
||||
)
|
||||
|
||||
|
||||
# ---- Public types ---------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@ -10,11 +10,11 @@ This means you can use this with weighted-pick, lowest-latency, simple-shuffle,
|
||||
Example:
|
||||
```
|
||||
openai:
|
||||
budget_limit: 0.000000000001
|
||||
time_period: 1d
|
||||
budget_limit: 0.000000000001
|
||||
time_period: 1d
|
||||
anthropic:
|
||||
budget_limit: 100
|
||||
time_period: 7d
|
||||
budget_limit: 100
|
||||
time_period: 7d
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ class PatternUtils:
|
||||
|
||||
@staticmethod
|
||||
def sorted_patterns(
|
||||
patterns: Dict[str, List[Dict]],
|
||||
patterns: Dict[str, List[Dict]]
|
||||
) -> List[Tuple[str, List[Dict]]]:
|
||||
"""
|
||||
Cached property for patterns sorted by specificity.
|
||||
|
||||
@ -21,7 +21,7 @@ class VectorStoreFileRequestUtils:
|
||||
|
||||
@staticmethod
|
||||
def get_create_request_params(
|
||||
params: Dict[str, Any],
|
||||
params: Dict[str, Any]
|
||||
) -> VectorStoreFileCreateRequest:
|
||||
filtered = VectorStoreFileRequestUtils._filter_params(
|
||||
params=params, model=VectorStoreFileCreateRequest
|
||||
@ -37,7 +37,7 @@ class VectorStoreFileRequestUtils:
|
||||
|
||||
@staticmethod
|
||||
def get_update_request_params(
|
||||
params: Dict[str, Any],
|
||||
params: Dict[str, Any]
|
||||
) -> VectorStoreFileUpdateRequest:
|
||||
filtered = VectorStoreFileRequestUtils._filter_params(
|
||||
params=params, model=VectorStoreFileUpdateRequest
|
||||
|
||||
@ -76,11 +76,7 @@ utils = [
|
||||
# Not in Docker or PyPI proxy extra.
|
||||
"numpydoc==1.8.0",
|
||||
]
|
||||
# diskcache intentionally unpinned: CVE-2025-69872 (pickle RCE) has no
|
||||
# upstream fix. Stub kept so `pip install litellm[caching]` doesn't warn;
|
||||
# DiskCache loads diskcache lazily and forces JSONDisk for safety. See
|
||||
# litellm/caching/disk_cache.py.
|
||||
caching = []
|
||||
caching = ["diskcache==5.6.3"]
|
||||
semantic-router = [
|
||||
"semantic-router==0.1.12; python_version < '3.14'",
|
||||
"aurelio-sdk==0.0.19; python_version < '3.14'",
|
||||
@ -130,7 +126,7 @@ litellm-proxy = "litellm.proxy.client.cli:cli"
|
||||
dev = [
|
||||
"diff-cover==9.7.2",
|
||||
"flake8==7.3.0",
|
||||
"black==26.3.1",
|
||||
"black==24.10.0",
|
||||
"mypy==1.19.0",
|
||||
"pytest==9.0.3",
|
||||
"pytest-mock==3.15.1",
|
||||
|
||||
@ -35,7 +35,7 @@ import httpx
|
||||
class EvalCase:
|
||||
category: str
|
||||
prompt: str
|
||||
ideal: str # criteria the judge checks the response against
|
||||
ideal: str # criteria the judge checks the response against
|
||||
|
||||
|
||||
EVAL_CASES: List[EvalCase] = [
|
||||
@ -177,19 +177,14 @@ async def evaluate(
|
||||
async with httpx.AsyncClient() as client:
|
||||
for i, case in enumerate(EVAL_CASES, 1):
|
||||
print(f"\n[{i}/{len(EVAL_CASES)}] category={case.category}")
|
||||
print(
|
||||
f" prompt : {case.prompt[:80]}{'…' if len(case.prompt) > 80 else ''}"
|
||||
)
|
||||
print(f" prompt : {case.prompt[:80]}{'…' if len(case.prompt) > 80 else ''}")
|
||||
|
||||
session_id = f"eval-{uuid.uuid4()}"
|
||||
|
||||
# Round 1: single-turn real request — get the actual LLM response to judge.
|
||||
try:
|
||||
response, chosen = await _chat(
|
||||
client,
|
||||
proxy_url,
|
||||
api_key,
|
||||
router,
|
||||
client, proxy_url, api_key, router,
|
||||
[{"role": "user", "content": case.prompt}],
|
||||
session_id=session_id,
|
||||
)
|
||||
@ -199,25 +194,16 @@ async def evaluate(
|
||||
continue
|
||||
|
||||
print(f" model : {chosen or router}")
|
||||
print(
|
||||
f" response : {response[:120].replace(chr(10), ' ')}{'…' if len(response) > 120 else ''}"
|
||||
)
|
||||
print(f" response : {response[:120].replace(chr(10), ' ')}{'…' if len(response) > 120 else ''}")
|
||||
|
||||
# Judge the real response.
|
||||
judge_msgs = [
|
||||
{"role": "system", "content": JUDGE_SYSTEM},
|
||||
{
|
||||
"role": "user",
|
||||
"content": _judge_user(case.prompt, case.ideal, response),
|
||||
},
|
||||
{"role": "user", "content": _judge_user(case.prompt, case.ideal, response)},
|
||||
]
|
||||
try:
|
||||
verdict, _ = await _chat(
|
||||
client,
|
||||
proxy_url,
|
||||
api_key,
|
||||
judge_model,
|
||||
judge_msgs,
|
||||
client, proxy_url, api_key, judge_model, judge_msgs,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f" ERROR calling judge: {exc}", file=sys.stderr)
|
||||
@ -241,19 +227,15 @@ async def evaluate(
|
||||
# On PASS → satisfaction follow-up (+alpha). On FAIL → neutral (no signal).
|
||||
follow_up = SATISFY_FOLLOWUP if is_pass else NEUTRAL_FOLLOWUP
|
||||
bandit_msgs = [
|
||||
{"role": "user", "content": case.prompt},
|
||||
{"role": "user", "content": case.prompt},
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": "ok continue"},
|
||||
{"role": "user", "content": "ok continue"},
|
||||
{"role": "assistant", "content": FAB_ASSISTANT},
|
||||
{"role": "user", "content": follow_up},
|
||||
{"role": "user", "content": follow_up},
|
||||
]
|
||||
try:
|
||||
await _chat(
|
||||
client,
|
||||
proxy_url,
|
||||
api_key,
|
||||
router,
|
||||
bandit_msgs,
|
||||
client, proxy_url, api_key, router, bandit_msgs,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@ -275,17 +257,11 @@ async def evaluate(
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Evaluate the adaptive router with LLM-as-judge."
|
||||
)
|
||||
ap.add_argument("--proxy-url", default="http://localhost:4000")
|
||||
ap.add_argument("--api-key", required=True, help="proxy API key")
|
||||
ap.add_argument(
|
||||
"--router", default="smart-cheap-router", help="adaptive router model name"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--judge-model", default="smart", help="model name for the judge (via proxy)"
|
||||
)
|
||||
ap = argparse.ArgumentParser(description="Evaluate the adaptive router with LLM-as-judge.")
|
||||
ap.add_argument("--proxy-url", default="http://localhost:4000")
|
||||
ap.add_argument("--api-key", required=True, help="proxy API key")
|
||||
ap.add_argument("--router", default="smart-cheap-router", help="adaptive router model name")
|
||||
ap.add_argument("--judge-model", default="smart", help="model name for the judge (via proxy)")
|
||||
args = ap.parse_args()
|
||||
|
||||
asyncio.run(evaluate(args.proxy_url, args.api_key, args.router, args.judge_model))
|
||||
|
||||
@ -72,8 +72,8 @@ PROMPTS: Dict[str, List[str]] = {
|
||||
# so that signals attribute to the right (type, model) bandit cell.
|
||||
SATISFY: Dict[str, str] = {
|
||||
"code_generation": "thanks, that works! now write me a python function that does the inverse",
|
||||
"factual_lookup": "perfect, thanks! who is the current prime minister?",
|
||||
"writing": "great, thanks! now write a follow-up email confirming attendance",
|
||||
"factual_lookup": "perfect, thanks! who is the current prime minister?",
|
||||
"writing": "great, thanks! now write a follow-up email confirming attendance",
|
||||
}
|
||||
|
||||
# Neutral follow-up — does not match any signal regex, does not move the bandit.
|
||||
@ -83,8 +83,8 @@ NEUTRAL_FOLLOWUP = "ok, noted"
|
||||
# Defaults: smart dominates code/writing; both are fine for factual_lookup.
|
||||
ORACLE: Dict[str, Dict[str, float]] = {
|
||||
"code_generation": {"smart": 0.92, "fast": 0.35},
|
||||
"factual_lookup": {"smart": 0.90, "fast": 0.85},
|
||||
"writing": {"smart": 0.85, "fast": 0.55},
|
||||
"factual_lookup": {"smart": 0.90, "fast": 0.85},
|
||||
"writing": {"smart": 0.85, "fast": 0.55},
|
||||
}
|
||||
|
||||
# Fabricated assistant turn — content doesn't matter for the hook, only the role.
|
||||
@ -94,11 +94,11 @@ FAB_ASSISTANT = "Got it. Working on that now."
|
||||
def _build_messages(prompt: str, last_user: str) -> List[Dict[str, str]]:
|
||||
"""5-message conversation that passes the SIGNAL_GATE_MIN_MESSAGES=4 gate."""
|
||||
return [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": FAB_ASSISTANT},
|
||||
{"role": "user", "content": "ok continue"},
|
||||
{"role": "user", "content": "ok continue"},
|
||||
{"role": "assistant", "content": FAB_ASSISTANT},
|
||||
{"role": "user", "content": last_user},
|
||||
{"role": "user", "content": last_user},
|
||||
]
|
||||
|
||||
|
||||
@ -155,11 +155,7 @@ async def _drive_one_session(
|
||||
#
|
||||
# Round 1: neutral follow-up → no signal fires, but we learn the pick.
|
||||
ok, chosen = await _send(
|
||||
client,
|
||||
proxy_url,
|
||||
api_key,
|
||||
router,
|
||||
session_id,
|
||||
client, proxy_url, api_key, router, session_id,
|
||||
_build_messages(prompt, NEUTRAL_FOLLOWUP),
|
||||
mock_response=FAB_ASSISTANT,
|
||||
)
|
||||
@ -175,15 +171,10 @@ async def _drive_one_session(
|
||||
# follow-up matches satisfaction → +alpha for (request_type, chosen).
|
||||
history = _build_messages(prompt, NEUTRAL_FOLLOWUP) + [
|
||||
{"role": "assistant", "content": FAB_ASSISTANT},
|
||||
{"role": "user", "content": follow_up},
|
||||
{"role": "user", "content": follow_up},
|
||||
]
|
||||
await _send(
|
||||
client,
|
||||
proxy_url,
|
||||
api_key,
|
||||
router,
|
||||
session_id,
|
||||
history,
|
||||
client, proxy_url, api_key, router, session_id, history,
|
||||
mock_response=FAB_ASSISTANT,
|
||||
)
|
||||
return chosen
|
||||
@ -192,22 +183,13 @@ async def _drive_one_session(
|
||||
async def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--proxy-url", default="http://localhost:4000")
|
||||
ap.add_argument(
|
||||
"--api-key", required=True, help="proxy key with /v1/chat/completions perms"
|
||||
)
|
||||
ap.add_argument("--router", default="smart-cheap-router")
|
||||
ap.add_argument("--rounds", type=int, default=100)
|
||||
ap.add_argument(
|
||||
"--rate",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="seconds between sessions; lower = faster",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--types",
|
||||
default="code_generation,factual_lookup,writing",
|
||||
help="comma-separated subset of request types to drive",
|
||||
)
|
||||
ap.add_argument("--api-key", required=True, help="proxy key with /v1/chat/completions perms")
|
||||
ap.add_argument("--router", default="smart-cheap-router")
|
||||
ap.add_argument("--rounds", type=int, default=100)
|
||||
ap.add_argument("--rate", type=float, default=0.5,
|
||||
help="seconds between sessions; lower = faster")
|
||||
ap.add_argument("--types", default="code_generation,factual_lookup,writing",
|
||||
help="comma-separated subset of request types to drive")
|
||||
args = ap.parse_args()
|
||||
|
||||
types = [t.strip() for t in args.types.split(",") if t.strip() in PROMPTS]
|
||||
@ -225,12 +207,7 @@ async def main() -> None:
|
||||
rt = random.choice(types)
|
||||
prompt = random.choice(PROMPTS[rt])
|
||||
chosen = await _drive_one_session(
|
||||
client,
|
||||
args.proxy_url,
|
||||
args.api_key,
|
||||
args.router,
|
||||
rt,
|
||||
prompt,
|
||||
client, args.proxy_url, args.api_key, args.router, rt, prompt,
|
||||
)
|
||||
if chosen:
|
||||
counts[(rt, chosen)] = counts.get((rt, chosen), 0) + 1
|
||||
|
||||
@ -8,6 +8,7 @@ import statistics
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
REQUEST_BODY = {
|
||||
"model": "db-openai-endpoint",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
|
||||
@ -42,7 +42,8 @@ from litellm.types.utils import CallTypes
|
||||
PROBLEMS = [
|
||||
{
|
||||
"id": "has_close_elements",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def has_close_elements(numbers: List[float], threshold: float) -> bool:
|
||||
@ -53,8 +54,10 @@ PROBLEMS = [
|
||||
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
|
||||
True
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
|
||||
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
|
||||
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
|
||||
@ -62,11 +65,13 @@ PROBLEMS = [
|
||||
assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0], 2.0) == True
|
||||
assert has_close_elements([], 0.5) == False
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "separate_paren_groups",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def separate_paren_groups(paren_string: str) -> List[str]:
|
||||
@ -77,18 +82,22 @@ PROBLEMS = [
|
||||
>>> separate_paren_groups('( ) (( )) (( )( ))')
|
||||
['()', '(())', '(()())']
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert separate_paren_groups('(()()) ((())) () ((())()())') == ['(()())', '((()))', '()', '((())()())']
|
||||
assert separate_paren_groups('() (()) ((())) (((())))') == ['()', '(())', '((()))', '(((())))']
|
||||
assert separate_paren_groups('(()(()))') == ['(()(()))']
|
||||
assert separate_paren_groups('( ) (( )) (( )( ))') == ['()', '(())', '(()())']
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "truncate_number",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
def truncate_number(number: float) -> float:
|
||||
\"\"\"Given a positive floating point number, it can be decomposed into
|
||||
an integer part (largest integer smaller than given number) and decimals
|
||||
@ -97,17 +106,21 @@ PROBLEMS = [
|
||||
>>> truncate_number(3.5)
|
||||
0.5
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert truncate_number(3.5) == 0.5
|
||||
assert abs(truncate_number(1.33) - 0.33) < 1e-6
|
||||
assert abs(truncate_number(123.456) - 0.456) < 1e-6
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "below_zero",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def below_zero(operations: List[int]) -> bool:
|
||||
@ -119,8 +132,10 @@ PROBLEMS = [
|
||||
>>> below_zero([1, 2, -4, 5])
|
||||
True
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert below_zero([]) == False
|
||||
assert below_zero([1, 2, -3, 1, 2, -3]) == False
|
||||
assert below_zero([1, 2, -4, 5, 6]) == True
|
||||
@ -128,11 +143,13 @@ PROBLEMS = [
|
||||
assert below_zero([1, -1, 2, -2, 5, -5, 4, -5]) == True
|
||||
assert below_zero([1, -2]) == True
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "mean_absolute_deviation",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def mean_absolute_deviation(numbers: List[float]) -> float:
|
||||
@ -144,17 +161,21 @@ PROBLEMS = [
|
||||
>>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])
|
||||
1.0
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6
|
||||
assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0, 5.0]) - 1.2) < 1e-6
|
||||
assert abs(mean_absolute_deviation([1.0, 1.0, 1.0, 1.0]) - 0.0) < 1e-6
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "intersperse",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def intersperse(numbers: List[int], delimiter: int) -> List[int]:
|
||||
@ -164,17 +185,21 @@ PROBLEMS = [
|
||||
>>> intersperse([1, 2, 3], 4)
|
||||
[1, 4, 2, 4, 3]
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert intersperse([], 7) == []
|
||||
assert intersperse([5, 6, 3, 2], 8) == [5, 8, 6, 8, 3, 8, 2]
|
||||
assert intersperse([2, 2, 2], 2) == [2, 2, 2, 2, 2]
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "parse_nested_parens",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def parse_nested_parens(paren_string: str) -> List[int]:
|
||||
@ -184,17 +209,21 @@ PROBLEMS = [
|
||||
>>> parse_nested_parens('(()()) ((())) () ((())())')
|
||||
[2, 3, 1, 3]
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert parse_nested_parens('(()()) ((())) () ((())())') == [2, 3, 1, 3]
|
||||
assert parse_nested_parens('() (()) ((())) (((())))') == [1, 2, 3, 4]
|
||||
assert parse_nested_parens('(()(())((())))') == [4]
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "filter_by_substring",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def filter_by_substring(strings: List[str], substring: str) -> List[str]:
|
||||
@ -204,18 +233,22 @@ PROBLEMS = [
|
||||
>>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')
|
||||
['abc', 'bacd', 'array']
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert filter_by_substring([], 'john') == []
|
||||
assert filter_by_substring(['xxx', 'asd', 'xxy', 'john doe', 'xxxuj', 'xxx'], 'xxx') == ['xxx', 'xxxuj', 'xxx']
|
||||
assert filter_by_substring(['xxx', 'asd', 'aaber', 'john doe', 'xxxuj', 'xxx'], 'xx') == ['xxx', 'xxxuj', 'xxx']
|
||||
assert filter_by_substring(['grunt', 'hierarchial', 'abc', 'hierarchial'], 'hi') == ['hierarchial', 'hierarchial']
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "sum_product",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List, Tuple
|
||||
|
||||
def sum_product(numbers: List[int]) -> Tuple[int, int]:
|
||||
@ -226,19 +259,23 @@ PROBLEMS = [
|
||||
>>> sum_product([1, 2, 3, 4])
|
||||
(10, 24)
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert sum_product([]) == (0, 1)
|
||||
assert sum_product([1, 1, 1]) == (3, 1)
|
||||
assert sum_product([100, 0]) == (100, 0)
|
||||
assert sum_product([3, 5, 7]) == (15, 105)
|
||||
assert sum_product([10]) == (10, 10)
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "max_element",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def max_element(l: List[int]) -> int:
|
||||
@ -248,17 +285,21 @@ PROBLEMS = [
|
||||
>>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])
|
||||
123
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert max_element([1, 2, 3]) == 3
|
||||
assert max_element([5, 3, -5, 2, -3, 3, 9, 0, 124, 1, -10]) == 124
|
||||
assert max_element([-1, -2, -3]) == -1
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "fizz_buzz",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
def fizz_buzz(n: int) -> int:
|
||||
\"\"\"Return the number of times the digit 7 appears in integers less than n which are divisible by 11 or 13.
|
||||
>>> fizz_buzz(50)
|
||||
@ -268,8 +309,10 @@ PROBLEMS = [
|
||||
>>> fizz_buzz(79)
|
||||
3
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert fizz_buzz(50) == 0
|
||||
assert fizz_buzz(78) == 2
|
||||
assert fizz_buzz(79) == 3
|
||||
@ -277,11 +320,13 @@ PROBLEMS = [
|
||||
assert fizz_buzz(200) == 6
|
||||
assert fizz_buzz(4000) == 192
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "sort_by_binary_len",
|
||||
"prompt": textwrap.dedent("""\
|
||||
"prompt": textwrap.dedent(
|
||||
"""\
|
||||
from typing import List
|
||||
|
||||
def sort_array(arr: List[int]) -> List[int]:
|
||||
@ -294,8 +339,10 @@ PROBLEMS = [
|
||||
>>> sort_array([1, 0, 2, 3, 4])
|
||||
[0, 1, 2, 4, 3]
|
||||
\"\"\"
|
||||
"""),
|
||||
"tests": textwrap.dedent("""\
|
||||
"""
|
||||
),
|
||||
"tests": textwrap.dedent(
|
||||
"""\
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 4, 3, 5]
|
||||
assert sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
assert sort_array([1, 0, 2, 3, 4]) == [0, 1, 2, 4, 3]
|
||||
@ -303,7 +350,8 @@ PROBLEMS = [
|
||||
assert sort_array([2, 5, 77, 4, 5, 3, 5, 7, 2, 3, 4]) == [2, 2, 4, 4, 3, 3, 5, 5, 5, 7, 77]
|
||||
assert sort_array([3, 6, 44, 12, 32, 5]) == [32, 3, 5, 6, 12, 44]
|
||||
print("PASSED")
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
@ -312,7 +360,8 @@ PROBLEMS = [
|
||||
# compressor to identify and drop them.
|
||||
DISTRACTOR_SNIPPETS = [
|
||||
# distractor 0 — database connection pool
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# db_pool.py
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
@ -359,9 +408,11 @@ DISTRACTOR_SNIPPETS = [
|
||||
for conn in self._pool:
|
||||
conn.close()
|
||||
self._pool.clear()
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
# distractor 1 — HTTP retry logic
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# http_retry.py
|
||||
import time
|
||||
import random
|
||||
@ -405,9 +456,11 @@ DISTRACTOR_SNIPPETS = [
|
||||
resp = requests.get(url, params=params, timeout=30)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
# distractor 2 — LRU cache implementation
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# lru_cache.py
|
||||
from collections import OrderedDict
|
||||
from threading import RLock
|
||||
@ -456,9 +509,11 @@ DISTRACTOR_SNIPPETS = [
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._cache
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
# distractor 3 — CSV report generator
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# report_gen.py
|
||||
import csv
|
||||
import io
|
||||
@ -511,9 +566,11 @@ DISTRACTOR_SNIPPETS = [
|
||||
except (ValueError, KeyError):
|
||||
return False
|
||||
return self.filter_rows(in_range)
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
# distractor 4 — async task queue
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# task_queue.py
|
||||
import asyncio
|
||||
import logging
|
||||
@ -583,9 +640,11 @@ DISTRACTOR_SNIPPETS = [
|
||||
async def shutdown(self):
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
# distractor 5 — config parser with env var interpolation
|
||||
textwrap.dedent("""\
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
# config_parser.py
|
||||
import os
|
||||
import re
|
||||
@ -648,7 +707,8 @@ DISTRACTOR_SNIPPETS = [
|
||||
if val is None:
|
||||
raise ConfigError(f"Required config key missing: {key}")
|
||||
return val
|
||||
"""),
|
||||
"""
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from litellm.batches.batch_utils import (
|
||||
from litellm.cost_calculator import batch_cost_calculator
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
|
||||
# --- helpers ---
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
SERVER_URL = "https://exampleopenaiendpoint-production-0ee2.up.railway.app/v1"
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ import litellm
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.litellm_core_utils.token_counter import token_counter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -2,6 +2,7 @@ import ast
|
||||
import os
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
ALLOWED_FILE = os.path.normpath("litellm/_uuid.py")
|
||||
|
||||
|
||||
|
||||
@ -23,6 +23,8 @@ def event_loop():
|
||||
loop.close()
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown():
|
||||
"""
|
||||
|
||||
@ -957,18 +957,14 @@ async def test_langfuse_callback_failure_metric(prometheus_logger):
|
||||
# Get initial value
|
||||
initial_value = 0
|
||||
try:
|
||||
initial_value = (
|
||||
prometheus_logger.litellm_callback_logging_failures_metric.labels(
|
||||
callback_name="langfuse"
|
||||
)._value.get()
|
||||
)
|
||||
initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels(
|
||||
callback_name="langfuse"
|
||||
)._value.get()
|
||||
except Exception:
|
||||
initial_value = 0
|
||||
|
||||
# Create Langfuse logger with mocked initialization
|
||||
with patch(
|
||||
"litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"
|
||||
):
|
||||
with patch("litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"):
|
||||
langfuse_logger = LangfusePromptManagement()
|
||||
|
||||
# Mock the log_event_on_langfuse to raise an exception
|
||||
@ -980,12 +976,10 @@ async def test_langfuse_callback_failure_metric(prometheus_logger):
|
||||
mock_get_logger.return_value = mock_logger
|
||||
|
||||
# Mock handle_callback_failure to track calls
|
||||
with patch.object(
|
||||
prometheus_logger, "increment_callback_logging_failure"
|
||||
) as mock_increment:
|
||||
with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment:
|
||||
# Inject prometheus logger into the langfuse logger
|
||||
langfuse_logger.handle_callback_failure = (
|
||||
lambda callback_name: mock_increment(callback_name=callback_name)
|
||||
langfuse_logger.handle_callback_failure = lambda callback_name: mock_increment(
|
||||
callback_name=callback_name
|
||||
)
|
||||
|
||||
# Call async_log_success_event - should catch exception and increment metric
|
||||
@ -1017,51 +1011,43 @@ async def test_langfuse_otel_callback_failure_metric(prometheus_logger):
|
||||
# Get initial value
|
||||
initial_value = 0
|
||||
try:
|
||||
initial_value = (
|
||||
prometheus_logger.litellm_callback_logging_failures_metric.labels(
|
||||
callback_name="langfuse_otel"
|
||||
)._value.get()
|
||||
)
|
||||
initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels(
|
||||
callback_name="langfuse_otel"
|
||||
)._value.get()
|
||||
except Exception:
|
||||
initial_value = 0
|
||||
|
||||
# Create Langfuse OTEL logger with mocked initialization
|
||||
with patch(
|
||||
"litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None
|
||||
):
|
||||
with patch("litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None):
|
||||
langfuse_otel_logger = LangfuseOtelLogger(callback_name="langfuse_otel")
|
||||
langfuse_otel_logger.callback_name = "langfuse_otel"
|
||||
|
||||
# Mock handle_callback_failure to track calls
|
||||
with patch.object(
|
||||
prometheus_logger, "increment_callback_logging_failure"
|
||||
) as mock_increment:
|
||||
with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment:
|
||||
# Inject prometheus logger into the langfuse otel logger
|
||||
langfuse_otel_logger.handle_callback_failure = (
|
||||
lambda callback_name: mock_increment(callback_name=callback_name)
|
||||
langfuse_otel_logger.handle_callback_failure = lambda callback_name: mock_increment(
|
||||
callback_name=callback_name
|
||||
)
|
||||
|
||||
# Test that the OpenTelemetry base class set_attributes exception handler works
|
||||
# This is where langfuse_otel failures are caught and tracked
|
||||
with patch.object(
|
||||
langfuse_otel_logger, "set_attributes"
|
||||
) as mock_set_attributes:
|
||||
with patch.object(langfuse_otel_logger, "set_attributes") as mock_set_attributes:
|
||||
# Simulate the exception handling in set_attributes
|
||||
def set_attributes_with_error(*args, **kwargs):
|
||||
# This simulates what happens in the real set_attributes method
|
||||
try:
|
||||
raise Exception("Attribute error")
|
||||
except Exception as e:
|
||||
langfuse_otel_logger.handle_callback_failure(
|
||||
callback_name=langfuse_otel_logger.callback_name
|
||||
)
|
||||
langfuse_otel_logger.handle_callback_failure(callback_name=langfuse_otel_logger.callback_name)
|
||||
|
||||
mock_set_attributes.side_effect = set_attributes_with_error
|
||||
|
||||
# Call set_attributes
|
||||
try:
|
||||
langfuse_otel_logger.set_attributes(
|
||||
span=MagicMock(), kwargs={}, response_obj={}
|
||||
span=MagicMock(),
|
||||
kwargs={},
|
||||
response_obj={}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -49,13 +49,10 @@ async def test_enterprise_custom_auth_returns_string():
|
||||
mock_user_auth = AsyncMock(return_value="sk-test-key")
|
||||
request = MagicMock(spec=Request)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth",
|
||||
mock_user_auth,
|
||||
),
|
||||
patch("litellm.proxy.proxy_server.master_key", "sk-1234"),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
|
||||
with patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth", mock_user_auth
|
||||
), patch("litellm.proxy.proxy_server.master_key", "sk-1234"), patch(
|
||||
"litellm.proxy.proxy_server.prisma_client", MagicMock()
|
||||
):
|
||||
# Verify the key is correctly handled in _user_api_key_auth_builder
|
||||
with patch(
|
||||
|
||||
@ -105,9 +105,7 @@ async def test_apply_guardrail_endpoint_with_presidio_guardrail():
|
||||
mock_guardrail = Mock(spec=CustomGuardrail)
|
||||
# Simulate masking PII entities - returns GenericGuardrailAPIInputs (dict with texts key)
|
||||
mock_guardrail.apply_guardrail = AsyncMock(
|
||||
return_value={
|
||||
"texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"]
|
||||
}
|
||||
return_value={"texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"]}
|
||||
)
|
||||
|
||||
# Configure the registry to return our mock guardrail
|
||||
|
||||
@ -95,9 +95,9 @@ async def test_async_pre_call_deployment_hook_resolves_model_id_from_litellm_met
|
||||
kwargs=kwargs, call_type=CallTypes.acreate_batch
|
||||
)
|
||||
|
||||
assert (
|
||||
result["input_file_id"] == provider_file_id
|
||||
), f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'"
|
||||
assert result["input_file_id"] == provider_file_id, (
|
||||
f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -134,9 +134,9 @@ async def test_async_pre_call_deployment_hook_prefers_top_level_model_info():
|
||||
kwargs=kwargs, call_type=CallTypes.acreate_batch
|
||||
)
|
||||
|
||||
assert (
|
||||
result["input_file_id"] == top_level_provider_file
|
||||
), "Should prefer top-level model_info over litellm_metadata"
|
||||
assert result["input_file_id"] == top_level_provider_file, (
|
||||
"Should prefer top-level model_info over litellm_metadata"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -162,9 +162,9 @@ async def test_async_pre_call_deployment_hook_no_model_info_leaves_file_id_uncha
|
||||
kwargs=kwargs, call_type=CallTypes.acreate_batch
|
||||
)
|
||||
|
||||
assert (
|
||||
result["input_file_id"] == managed_file_id
|
||||
), "File ID should remain unchanged when model_info is not available"
|
||||
assert result["input_file_id"] == managed_file_id, (
|
||||
"File ID should remain unchanged when model_info is not available"
|
||||
)
|
||||
|
||||
|
||||
# def test_list_managed_files():
|
||||
@ -341,9 +341,7 @@ async def test_async_pre_call_hook_for_unified_finetuning_job():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"call_type", ["afile_content", "afile_delete", "afile_retrieve"]
|
||||
)
|
||||
@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete", "afile_retrieve"])
|
||||
async def test_can_user_call_unified_file_id(call_type):
|
||||
"""
|
||||
Test that on file retrieve, delete, and content we check if the user has access to the file
|
||||
@ -624,7 +622,8 @@ async def test_error_file_id_for_failed_batch():
|
||||
mock_retrieve.return_value = error_file_object
|
||||
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
user_id="test-user-123", parent_otel_span=MagicMock()
|
||||
user_id="test-user-123",
|
||||
parent_otel_span=MagicMock()
|
||||
)
|
||||
|
||||
response = await proxy_managed_files.async_post_call_success_hook(
|
||||
@ -637,9 +636,7 @@ async def test_error_file_id_for_failed_batch():
|
||||
assert cast(LiteLLMBatch, response).error_file_id is not None
|
||||
assert not cast(LiteLLMBatch, response).error_file_id.startswith("error-")
|
||||
# Verify it's a base64 encoded managed file ID
|
||||
assert _is_base64_encoded_unified_file_id(
|
||||
cast(LiteLLMBatch, response).error_file_id
|
||||
)
|
||||
assert _is_base64_encoded_unified_file_id(cast(LiteLLMBatch, response).error_file_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -681,10 +678,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation():
|
||||
# first retrieve batch
|
||||
tasks = []
|
||||
first_create_task = asyncio.create_task
|
||||
with patch("asyncio.create_task") as mock_create_task:
|
||||
mock_create_task.side_effect = (
|
||||
lambda coro: tasks.append(first_create_task(coro)) or tasks[-1]
|
||||
)
|
||||
with patch('asyncio.create_task') as mock_create_task:
|
||||
mock_create_task.side_effect = lambda coro: tasks.append(first_create_task(coro)) or tasks[-1]
|
||||
|
||||
response = await proxy_managed_files.async_post_call_success_hook(
|
||||
data={},
|
||||
@ -705,10 +700,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation():
|
||||
# second retrieve batch
|
||||
tasks = []
|
||||
second_create_task = asyncio.create_task
|
||||
with patch("asyncio.create_task") as mock_create_task:
|
||||
mock_create_task.side_effect = (
|
||||
lambda coro: tasks.append(second_create_task(coro)) or tasks[-1]
|
||||
)
|
||||
with patch('asyncio.create_task') as mock_create_task:
|
||||
mock_create_task.side_effect = lambda coro: tasks.append(second_create_task(coro)) or tasks[-1]
|
||||
|
||||
await proxy_managed_files.async_post_call_success_hook(
|
||||
data={},
|
||||
@ -760,10 +753,7 @@ def test_update_responses_input_with_unified_file_id():
|
||||
assert updated_input[0]["content"][0]["type"] == "input_file"
|
||||
assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM"
|
||||
assert updated_input[0]["content"][1]["type"] == "input_text"
|
||||
assert (
|
||||
updated_input[0]["content"][1]["text"]
|
||||
== "What is the first dragon in the book?"
|
||||
)
|
||||
assert updated_input[0]["content"][1]["text"] == "What is the first dragon in the book?"
|
||||
|
||||
|
||||
def test_update_responses_input_with_regular_file_id():
|
||||
@ -965,10 +955,7 @@ def test_update_responses_tools_with_model_file_id_mapping():
|
||||
|
||||
# Verify the file IDs were mapped to provider-specific file IDs
|
||||
assert updated_tools[0]["type"] == "code_interpreter"
|
||||
assert updated_tools[0]["container"]["file_ids"] == [
|
||||
"openai_file_abc",
|
||||
"openai_file_def",
|
||||
]
|
||||
assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", "openai_file_def"]
|
||||
|
||||
|
||||
def test_update_responses_tools_without_mapping():
|
||||
@ -1039,10 +1026,7 @@ def test_update_responses_tools_with_mixed_file_ids():
|
||||
)
|
||||
|
||||
# Verify managed file ID was mapped and regular file ID was kept
|
||||
assert updated_tools[0]["container"]["file_ids"] == [
|
||||
"openai_file_abc",
|
||||
regular_file_id,
|
||||
]
|
||||
assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", regular_file_id]
|
||||
|
||||
|
||||
def test_get_file_ids_from_responses_tools():
|
||||
@ -1371,9 +1355,7 @@ async def test_store_unified_file_id_with_none_file_object():
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
prisma_client = AsyncMock()
|
||||
prisma_client.db.litellm_managedfiletable.create = AsyncMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
prisma_client.db.litellm_managedfiletable.create = AsyncMock(return_value=MagicMock())
|
||||
internal_usage_cache = MagicMock()
|
||||
internal_usage_cache.async_set_cache = AsyncMock()
|
||||
|
||||
@ -1411,22 +1393,18 @@ async def test_afile_delete_returns_provider_response_when_stored_file_object_no
|
||||
prisma_client = AsyncMock()
|
||||
db_record = MagicMock()
|
||||
db_record.model_mappings = '{"model-123": "file-provider-xyz"}'
|
||||
prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(
|
||||
return_value=db_record
|
||||
)
|
||||
prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=db_record)
|
||||
prisma_client.db.litellm_managedfiletable.delete = AsyncMock()
|
||||
|
||||
internal_usage_cache = MagicMock()
|
||||
internal_usage_cache.async_get_cache = AsyncMock(
|
||||
return_value={
|
||||
"unified_file_id": unified_file_id,
|
||||
"model_mappings": {"model-123": "file-provider-xyz"},
|
||||
"flat_model_file_ids": ["file-provider-xyz"],
|
||||
"file_object": None,
|
||||
"created_by": "test-user",
|
||||
"updated_by": "test-user",
|
||||
}
|
||||
)
|
||||
internal_usage_cache.async_get_cache = AsyncMock(return_value={
|
||||
"unified_file_id": unified_file_id,
|
||||
"model_mappings": {"model-123": "file-provider-xyz"},
|
||||
"flat_model_file_ids": ["file-provider-xyz"],
|
||||
"file_object": None,
|
||||
"created_by": "test-user",
|
||||
"updated_by": "test-user",
|
||||
})
|
||||
internal_usage_cache.async_set_cache = AsyncMock()
|
||||
|
||||
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
|
||||
@ -1494,12 +1472,10 @@ async def test_afile_retrieve_fetches_from_provider_when_file_object_none():
|
||||
)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.get_deployment_credentials_with_provider = MagicMock(
|
||||
return_value={
|
||||
"api_key": "test-key",
|
||||
"api_base": "https://api.openai.com",
|
||||
}
|
||||
)
|
||||
mock_router.get_deployment_credentials_with_provider = MagicMock(return_value={
|
||||
"api_key": "test-key",
|
||||
"api_base": "https://api.openai.com",
|
||||
})
|
||||
|
||||
with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_afile_retrieve:
|
||||
mock_afile_retrieve.return_value = provider_file_response
|
||||
@ -1624,33 +1600,29 @@ async def test_list_batches_from_managed_objects_table():
|
||||
|
||||
batch_record_1 = MagicMock()
|
||||
batch_record_1.unified_object_id = "unified-batch-id-1"
|
||||
batch_record_1.file_object = json.dumps(
|
||||
{
|
||||
"id": "batch_abc123",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-input-1",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
}
|
||||
)
|
||||
batch_record_1.file_object = json.dumps({
|
||||
"id": "batch_abc123",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-input-1",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
})
|
||||
|
||||
batch_record_2 = MagicMock()
|
||||
batch_record_2.unified_object_id = "unified-batch-id-2"
|
||||
batch_record_2.file_object = json.dumps(
|
||||
{
|
||||
"id": "batch_xyz789",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "in_progress",
|
||||
"created_at": 1234567891,
|
||||
"input_file_id": "file-input-2",
|
||||
"request_counts": {"total": 5, "completed": 2, "failed": 0},
|
||||
}
|
||||
)
|
||||
batch_record_2.file_object = json.dumps({
|
||||
"id": "batch_xyz789",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "in_progress",
|
||||
"created_at": 1234567891,
|
||||
"input_file_id": "file-input-2",
|
||||
"request_counts": {"total": 5, "completed": 2, "failed": 0},
|
||||
})
|
||||
|
||||
prisma_client.db.litellm_managedobjecttable.find_many.return_value = [
|
||||
batch_record_1,
|
||||
@ -1713,7 +1685,6 @@ async def test_list_batches_from_managed_objects_table_empty_list():
|
||||
|
||||
def _create_unified_batch_id(model_id: str, batch_id: str) -> str:
|
||||
import base64
|
||||
|
||||
unified_str = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}"
|
||||
return base64.urlsafe_b64encode(unified_str.encode()).decode().rstrip("=")
|
||||
|
||||
@ -1769,7 +1740,6 @@ async def test_list_batches_from_managed_objects_table_target_model_name_filter_
|
||||
# Verify find_many was NOT called since exception is raised before database query
|
||||
prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_batches_from_managed_objects_table_filters_by_created_by():
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
@ -1779,34 +1749,30 @@ async def test_list_batches_from_managed_objects_table_filters_by_created_by():
|
||||
# Create batch for user1
|
||||
batch_user1 = MagicMock()
|
||||
batch_user1.unified_object_id = "unified-batch-user1"
|
||||
batch_user1.file_object = json.dumps(
|
||||
{
|
||||
"id": "batch_user1_abc",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-input-user1",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
}
|
||||
)
|
||||
batch_user1.file_object = json.dumps({
|
||||
"id": "batch_user1_abc",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-input-user1",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
})
|
||||
|
||||
# Create batch for user2
|
||||
batch_user2 = MagicMock()
|
||||
batch_user2.unified_object_id = "unified-batch-user2"
|
||||
batch_user2.file_object = json.dumps(
|
||||
{
|
||||
"id": "batch_user2_xyz",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567891,
|
||||
"input_file_id": "file-input-user2",
|
||||
"request_counts": {"total": 2, "completed": 2, "failed": 0},
|
||||
}
|
||||
)
|
||||
batch_user2.file_object = json.dumps({
|
||||
"id": "batch_user2_xyz",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567891,
|
||||
"input_file_id": "file-input-user2",
|
||||
"request_counts": {"total": 2, "completed": 2, "failed": 0},
|
||||
})
|
||||
|
||||
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
|
||||
DualCache(), prisma_client=prisma_client
|
||||
@ -1913,9 +1879,7 @@ async def test_user_b_cannot_retrieve_user_a_batch():
|
||||
)
|
||||
|
||||
# User B tries to retrieve User A's batch
|
||||
unified_batch_id = (
|
||||
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
)
|
||||
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await proxy_managed_files.async_pre_call_hook(
|
||||
@ -1950,9 +1914,7 @@ async def test_user_b_cannot_cancel_user_a_batch():
|
||||
)
|
||||
|
||||
# User B tries to cancel User A's batch
|
||||
unified_batch_id = (
|
||||
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
)
|
||||
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await proxy_managed_files.async_pre_call_hook(
|
||||
@ -1990,9 +1952,7 @@ async def test_user_a_can_retrieve_own_batch():
|
||||
)
|
||||
|
||||
# User A retrieves their own batch
|
||||
unified_batch_id = (
|
||||
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
)
|
||||
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
|
||||
# Should not raise an exception
|
||||
result = await proxy_managed_files.async_pre_call_hook(
|
||||
@ -2028,9 +1988,7 @@ async def test_user_b_cannot_retrieve_user_a_file():
|
||||
)
|
||||
|
||||
# User B tries to retrieve User A's file
|
||||
unified_file_id = (
|
||||
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
)
|
||||
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await proxy_managed_files.async_pre_call_hook(
|
||||
@ -2065,9 +2023,7 @@ async def test_user_b_cannot_download_user_a_file_content():
|
||||
)
|
||||
|
||||
# User B tries to download User A's file content
|
||||
unified_file_id = (
|
||||
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
)
|
||||
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await proxy_managed_files.async_pre_call_hook(
|
||||
@ -2102,9 +2058,7 @@ async def test_user_b_cannot_delete_user_a_file():
|
||||
)
|
||||
|
||||
# User B tries to delete User A's file
|
||||
unified_file_id = (
|
||||
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
)
|
||||
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await proxy_managed_files.async_pre_call_hook(
|
||||
@ -2135,16 +2089,14 @@ async def test_user_a_can_retrieve_own_file():
|
||||
file_record = MagicMock()
|
||||
file_record.created_by = "user_a_id"
|
||||
file_record.model_mappings = '{"model-123": "file-abc123"}'
|
||||
file_record.file_object = json.dumps(
|
||||
{
|
||||
"id": "file-abc123",
|
||||
"object": "file",
|
||||
"bytes": 1234,
|
||||
"created_at": 1234567890,
|
||||
"filename": "test.jsonl",
|
||||
"purpose": "batch",
|
||||
}
|
||||
)
|
||||
file_record.file_object = json.dumps({
|
||||
"id": "file-abc123",
|
||||
"object": "file",
|
||||
"bytes": 1234,
|
||||
"created_at": 1234567890,
|
||||
"filename": "test.jsonl",
|
||||
"purpose": "batch",
|
||||
})
|
||||
prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record
|
||||
|
||||
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
|
||||
@ -2152,9 +2104,7 @@ async def test_user_a_can_retrieve_own_file():
|
||||
)
|
||||
|
||||
# User A retrieves their own file
|
||||
unified_file_id = (
|
||||
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
)
|
||||
unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM"
|
||||
|
||||
# Should not raise an exception
|
||||
result = await proxy_managed_files.async_pre_call_hook(
|
||||
@ -2184,18 +2134,16 @@ async def test_list_batches_only_returns_user_own_batches():
|
||||
# Create batches for User A
|
||||
batch_user_a = MagicMock()
|
||||
batch_user_a.unified_object_id = "batch-user-a"
|
||||
batch_user_a.file_object = json.dumps(
|
||||
{
|
||||
"id": "batch_a",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-a",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
}
|
||||
)
|
||||
batch_user_a.file_object = json.dumps({
|
||||
"id": "batch_a",
|
||||
"object": "batch",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"status": "completed",
|
||||
"created_at": 1234567890,
|
||||
"input_file_id": "file-a",
|
||||
"request_counts": {"total": 1, "completed": 1, "failed": 0},
|
||||
})
|
||||
|
||||
# Mock database to only return User A's batches
|
||||
prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user_a]
|
||||
@ -2243,14 +2191,14 @@ async def test_same_user_different_keys_can_access_batch():
|
||||
DualCache(), prisma_client=prisma_client
|
||||
)
|
||||
|
||||
unified_batch_id = (
|
||||
"bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
)
|
||||
unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz"
|
||||
|
||||
# First API key for User A retrieves the batch
|
||||
result1 = await proxy_managed_files.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_id="user_a_id", api_key="key-1", parent_otel_span=MagicMock()
|
||||
user_id="user_a_id",
|
||||
api_key="key-1",
|
||||
parent_otel_span=MagicMock()
|
||||
),
|
||||
cache=MagicMock(),
|
||||
data={"batch_id": unified_batch_id},
|
||||
@ -2262,7 +2210,9 @@ async def test_same_user_different_keys_can_access_batch():
|
||||
# Second API key for the same User A retrieves the batch
|
||||
result2 = await proxy_managed_files.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_id="user_a_id", api_key="key-2", parent_otel_span=MagicMock()
|
||||
user_id="user_a_id",
|
||||
api_key="key-2",
|
||||
parent_otel_span=MagicMock()
|
||||
),
|
||||
cache=MagicMock(),
|
||||
data={"batch_id": unified_batch_id},
|
||||
|
||||
@ -31,16 +31,12 @@ class TestAvailableEnterpriseUsers:
|
||||
self, client, mock_user_api_key_auth
|
||||
):
|
||||
"""Test when max_users is set and user count is within limit"""
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
{"max_users": 10},
|
||||
),
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
), patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
{"max_users": 10},
|
||||
):
|
||||
# Mock database count
|
||||
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=5)
|
||||
@ -70,16 +66,12 @@ class TestAvailableEnterpriseUsers:
|
||||
self, client, mock_user_api_key_auth
|
||||
):
|
||||
"""Test when max_users is not set (premium_user_data is None or doesn't contain max_users)"""
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
None,
|
||||
),
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
), patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
None,
|
||||
):
|
||||
# Mock database count
|
||||
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=3)
|
||||
@ -107,16 +99,12 @@ class TestAvailableEnterpriseUsers:
|
||||
self, client, mock_user_api_key_auth
|
||||
):
|
||||
"""Test the current bug where total_users_remaining can be negative"""
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
{"key": "value"},
|
||||
),
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
), patch(
|
||||
"litellm.proxy.proxy_server.premium_user_data",
|
||||
{"key": "value"},
|
||||
):
|
||||
# Mock database count higher than max_users to trigger the bug
|
||||
mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=8)
|
||||
@ -152,15 +140,12 @@ class TestAvailableEnterpriseUsers:
|
||||
"""Test when prisma_client is None (no database connection)"""
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
None,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
),
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
None,
|
||||
), patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
True,
|
||||
):
|
||||
# Override the dependency
|
||||
client.app.dependency_overrides[mock_user_api_key_auth] = lambda: {
|
||||
|
||||
@ -801,9 +801,7 @@ async def test_list_projects_returns_timestamps():
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from litellm_enterprise.proxy.management_endpoints.project_endpoints import (
|
||||
list_projects,
|
||||
)
|
||||
from litellm_enterprise.proxy.management_endpoints.project_endpoints import list_projects
|
||||
from litellm.proxy._types import LiteLLM_ProjectTable
|
||||
|
||||
now = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
@ -14,6 +14,7 @@ from litellm.proxy.guardrails.guardrail_registry import (
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_hooks.akto.akto import AktoGuardrail
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -6,6 +6,7 @@ import io
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import asyncio
|
||||
|
||||
@ -21,6 +21,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
|
||||
ContentFilterCategoryConfig,
|
||||
)
|
||||
|
||||
|
||||
# Test cases: (sentence, expected_result, reason)
|
||||
TEST_CASES = [
|
||||
# ALWAYS BLOCK - Explicit prohibited practices (1-10)
|
||||
|
||||
@ -23,6 +23,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
|
||||
ContentFilterCategoryConfig,
|
||||
)
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
POLICY_DIR = os.path.abspath(
|
||||
|
||||
@ -28,6 +28,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor
|
||||
ContentFilterCategoryConfig,
|
||||
)
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
POLICY_DIR = os.path.abspath(
|
||||
|
||||
@ -7,6 +7,7 @@ import sys
|
||||
import traceback
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
@ -6,6 +6,7 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
from litellm.types.utils import Embedding
|
||||
from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
_mock_model_id = (
|
||||
"arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123"
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ from litellm.llms.bedrock.common_utils import (
|
||||
)
|
||||
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||
|
||||
|
||||
NOVA_ARN = "arn:aws:bedrock:us-east-1:123456789012:custom-model-deployment/a1b2c3d4e5f6"
|
||||
NOVA_MODEL = f"bedrock/nova/{NOVA_ARN}"
|
||||
NOVA2_MODEL = f"bedrock/nova-2/{NOVA_ARN}"
|
||||
|
||||
@ -581,21 +581,17 @@ def test_vertex_ai_text_only_embedding_uses_embed_content():
|
||||
|
||||
def test_filter_embed_params_drops_unsupported():
|
||||
"""Unsupported params like max_tokens should be filtered out."""
|
||||
result = _filter_embed_params(
|
||||
{"dimensions": 768, "max_tokens": 256, "temperature": 0.5}
|
||||
)
|
||||
result = _filter_embed_params({"dimensions": 768, "max_tokens": 256, "temperature": 0.5})
|
||||
assert result == {"outputDimensionality": 768}
|
||||
|
||||
|
||||
def test_filter_embed_params_keeps_supported():
|
||||
"""All supported Gemini embedding params should pass through."""
|
||||
result = _filter_embed_params(
|
||||
{
|
||||
"dimensions": 768,
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
"title": "My doc",
|
||||
}
|
||||
)
|
||||
result = _filter_embed_params({
|
||||
"dimensions": 768,
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
"title": "My doc",
|
||||
})
|
||||
assert result == {
|
||||
"outputDimensionality": 768,
|
||||
"taskType": "RETRIEVAL_DOCUMENT",
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
|
||||
from litellm import get_model_info
|
||||
|
||||
|
||||
MODEL_NAME = "nvidia.nemotron-super-3-120b"
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
BedrockConverseMessagesProcessor,
|
||||
)
|
||||
|
||||
|
||||
MODEL = "anthropic.claude-v2"
|
||||
PROVIDER = "bedrock_converse"
|
||||
|
||||
@ -545,9 +546,9 @@ def test_bedrock_converse_sorts_text_before_tooluse_sync():
|
||||
tool_indices = [i for i, b in enumerate(content) if "toolUse" in b]
|
||||
|
||||
# All text blocks must come before all toolUse blocks
|
||||
assert max(text_indices) < min(
|
||||
tool_indices
|
||||
), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
|
||||
assert max(text_indices) < min(tool_indices), (
|
||||
f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -566,9 +567,9 @@ async def test_bedrock_converse_sorts_text_before_tooluse_async():
|
||||
text_indices = [i for i, b in enumerate(content) if "text" in b]
|
||||
tool_indices = [i for i, b in enumerate(content) if "toolUse" in b]
|
||||
|
||||
assert max(text_indices) < min(
|
||||
tool_indices
|
||||
), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
|
||||
assert max(text_indices) < min(tool_indices), (
|
||||
f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -576,9 +577,7 @@ async def test_bedrock_converse_content_ordering_sync_async_parity():
|
||||
"""Sync and async paths should produce identical content block ordering."""
|
||||
messages = _make_tooluse_before_text_messages()
|
||||
sync_result = _bedrock_converse_messages_pt(messages, MODEL, PROVIDER)
|
||||
async_result = (
|
||||
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
|
||||
messages, MODEL, PROVIDER
|
||||
)
|
||||
async_result = await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
|
||||
messages, MODEL, PROVIDER
|
||||
)
|
||||
assert sync_result == async_result
|
||||
|
||||
@ -10,6 +10,7 @@ from dotenv import load_dotenv
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
# Fake Vertex AI Gemini response for mocking
|
||||
FAKE_VERTEX_GEMINI_RESPONSE = {
|
||||
"candidates": [
|
||||
|
||||
@ -12,6 +12,7 @@ from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
@ -387,7 +387,9 @@ def test_process_chunk_completed_response_updates_id_and_usage_cost(monkeypatch)
|
||||
# Chunk must include a top-level "response" key so BaseResponsesAPIStreamingIterator
|
||||
# runs _update_responses_api_response_id_with_model_id (see streaming_iterator.py).
|
||||
event = iterator._process_chunk(
|
||||
json.dumps({"type": "response.completed", "response": {"id": "resp_live"}})
|
||||
json.dumps(
|
||||
{"type": "response.completed", "response": {"id": "resp_live"}}
|
||||
)
|
||||
)
|
||||
finally:
|
||||
litellm.include_cost_in_streaming_usage = original_include_cost
|
||||
|
||||
@ -22,8 +22,10 @@ sys.path.insert(0, os.path.abspath("../.."))
|
||||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
|
||||
# Large document for caching tests (needs 1024+ tokens for Claude models)
|
||||
LARGE_DOCUMENT_FOR_CACHING = """
|
||||
LARGE_DOCUMENT_FOR_CACHING = (
|
||||
"""
|
||||
This is a comprehensive legal agreement between Party A and Party B.
|
||||
|
||||
ARTICLE 1: DEFINITIONS
|
||||
@ -75,7 +77,9 @@ ARTICLE 9: GENERAL PROVISIONS
|
||||
9.5 Waiver of any provision shall not constitute ongoing waiver.
|
||||
|
||||
IN WITNESS WHEREOF, the parties have executed this Agreement.
|
||||
""" * 8 # Repeat to ensure we have enough tokens (need 1024+ for Claude models)
|
||||
"""
|
||||
* 8
|
||||
) # Repeat to ensure we have enough tokens (need 1024+ for Claude models)
|
||||
|
||||
|
||||
class TestBedrockAnthropicPromptCachingRegression:
|
||||
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Tests for Crusoe provider integration
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
@ -97,9 +96,9 @@ def test_crusoe_models_configuration():
|
||||
for model in crusoe_models:
|
||||
model_info = get_model_info(model)
|
||||
assert model_info is not None, f"Model info not found for {model}"
|
||||
assert (
|
||||
model_info.get("litellm_provider") == "crusoe"
|
||||
), f"{model} should have crusoe as provider"
|
||||
assert model_info.get("litellm_provider") == "crusoe", (
|
||||
f"{model} should have crusoe as provider"
|
||||
)
|
||||
assert model_info.get("mode") == "chat", f"{model} should be in chat mode"
|
||||
finally:
|
||||
litellm.model_cost = original_model_cost
|
||||
|
||||
@ -1362,12 +1362,8 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults():
|
||||
)
|
||||
|
||||
# For Gemini 3, should not force thinkingLevel by default
|
||||
assert (
|
||||
"thinkingLevel" not in result
|
||||
), "Should not force thinkingLevel for Gemini 3"
|
||||
assert (
|
||||
"thinkingBudget" not in result
|
||||
), "Should NOT have thinkingBudget for Gemini 3"
|
||||
assert "thinkingLevel" not in result, "Should not force thinkingLevel for Gemini 3"
|
||||
assert "thinkingBudget" not in result, "Should NOT have thinkingBudget for Gemini 3"
|
||||
assert result["includeThoughts"] is True
|
||||
|
||||
# Test 2: Anthropic thinking disabled for Gemini 3
|
||||
@ -1399,10 +1395,7 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults():
|
||||
)
|
||||
|
||||
assert result_zero["includeThoughts"] is False
|
||||
assert (
|
||||
"thinkingLevel" not in result_zero
|
||||
or result_zero.get("thinkingLevel") is None
|
||||
)
|
||||
assert "thinkingLevel" not in result_zero or result_zero.get("thinkingLevel") is None
|
||||
|
||||
# Test 4: Gemini 3 flash-preview should also follow provider defaults by default
|
||||
result_gemini3flashpreview = VertexGeminiConfig._map_thinking_param(
|
||||
@ -1532,12 +1525,8 @@ def test_anthropic_thinking_param_via_map_openai_params():
|
||||
# Check that thinkingConfig was created without forced thinkingLevel
|
||||
assert "thinkingConfig" in result, "Should have thinkingConfig in optional_params"
|
||||
thinking_config = result["thinkingConfig"]
|
||||
assert (
|
||||
"thinkingLevel" not in thinking_config
|
||||
), "Should not force thinkingLevel for Gemini 3 by default"
|
||||
assert (
|
||||
"thinkingBudget" not in thinking_config
|
||||
), "Should NOT have thinkingBudget for Gemini 3"
|
||||
assert "thinkingLevel" not in thinking_config, "Should not force thinkingLevel for Gemini 3 by default"
|
||||
assert "thinkingBudget" not in thinking_config, "Should NOT have thinkingBudget for Gemini 3"
|
||||
assert thinking_config["includeThoughts"] is True
|
||||
|
||||
# Test with Gemini 2 model
|
||||
|
||||
@ -2,6 +2,7 @@ import io
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import asyncio
|
||||
|
||||
@ -38,6 +38,7 @@ from litellm.llms.vertex_ai.gemini.transformation import (
|
||||
)
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
litellm.num_retries = 3
|
||||
litellm.cache = None
|
||||
user_message = "Write a short poem about the sky"
|
||||
@ -1103,7 +1104,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [{"text": """{
|
||||
"parts": [
|
||||
{
|
||||
"text": """{
|
||||
"recipes": [
|
||||
{"recipe_name": "Chocolate Chip Cookies"},
|
||||
{"recipe_name": "Oatmeal Raisin Cookies"},
|
||||
@ -1111,7 +1114,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
|
||||
{"recipe_name": "Sugar Cookies"},
|
||||
{"recipe_name": "Snickerdoodles"}
|
||||
]
|
||||
}"""}],
|
||||
}"""
|
||||
}
|
||||
],
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"safetyRatings": [
|
||||
|
||||
@ -450,9 +450,12 @@ def c():
|
||||
litellm.enable_caching_on_provider_specific_optional_params = False
|
||||
|
||||
|
||||
embedding_large_text = """
|
||||
embedding_large_text = (
|
||||
"""
|
||||
small text
|
||||
""" * 5
|
||||
"""
|
||||
* 5
|
||||
)
|
||||
|
||||
|
||||
# # test_caching_with_models()
|
||||
|
||||
@ -1638,13 +1638,15 @@ def custom_callback(
|
||||
|
||||
#################################################
|
||||
|
||||
print(f"""
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
Seed: {kwargs["seed"]},
|
||||
temperature: {kwargs["temperature"]},
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
assert kwargs["user"] == "ishaans app"
|
||||
assert kwargs["model"] == "gpt-3.5-turbo-1106"
|
||||
|
||||
@ -93,7 +93,8 @@ class testCustomCallbackProxy(CustomLogger):
|
||||
|
||||
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
|
||||
|
||||
print(f"""
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
@ -101,7 +102,8 @@ class testCustomCallbackProxy(CustomLogger):
|
||||
Cost: {cost},
|
||||
Response: {response}
|
||||
Proxy Metadata: {metadata}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user