Drop dep bumps + black-26 reformat to clear fork CI policy
PR was blocked by .github/workflows/guard-fork-dependencies.yml: fork PRs cannot modify uv.lock.

Reverting:
- uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295 files of mechanical Black 26 reformat coupled to it
- pyproject.toml diskcache extra change (kept the runtime mitigation in litellm/caching/disk_cache.py via JSONDisk)

Kept:
- Dockerfile cache narrowing (drops ~660 MB of uv build cache that surfaced cached setuptools as CVE findings)
- litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872
- ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json: next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard)
- tests/test_litellm/caching/test_disk_cache.py
- tests/code_coverage_tests/liccheck.ini: harmless black authorization

Black + gitpython + langchain dep upgrades will need a follow-up from a maintainer pushing a branch in the canonical BerriAI/litellm repo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
63bda3f001
commit
5bafa8b3a2
@ -2,7 +2,6 @@ import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
# Asynchronously fetch data from a given URL
|
||||
async def fetch_data(url):
|
||||
try:
|
||||
@ -16,24 +15,22 @@ async def fetch_data(url):
|
||||
resp_json = await resp.json()
|
||||
print("Fetch the data from URL.")
|
||||
# Return the 'data' field from the JSON response
|
||||
return resp_json["data"]
|
||||
return resp_json['data']
|
||||
except Exception as e:
|
||||
# Print an error message if fetching data fails
|
||||
print("Error fetching data from URL:", e)
|
||||
return None
|
||||
|
||||
|
||||
# Synchronize local data with remote data
|
||||
def sync_local_data_with_remote(local_data, remote_data):
    """Merge ``remote_data`` into ``local_data`` in place.

    Args:
        local_data: Mapping of key -> dict of fields; mutated by this call.
        remote_data: Mapping of key -> dict of fields to merge in.

    For a key present in both mappings, the local entry is updated with the
    remote entry's fields (remote values win, local-only fields are kept).
    For a key present only in ``remote_data``, the remote entry object itself
    is stored under that key in ``local_data`` (it is not copied).

    Returns:
        None. The result is the mutated ``local_data``.
    """
    # A single pass over remote_data adds new keys in remote_data's own order.
    # Iterating over set differences instead would add them in hash order,
    # which is not reproducible between interpreter runs.
    for key, remote_entry in remote_data.items():
        if key in local_data:
            local_data[key].update(remote_entry)
        else:
            local_data[key] = remote_entry
|
||||
|
||||
|
||||
# Write data to the json file
|
||||
def write_to_file(file_path, data):
|
||||
try:
|
||||
@ -46,7 +43,6 @@ def write_to_file(file_path, data):
|
||||
# Print an error message if writing to file fails
|
||||
print("Error updating JSON file:", e)
|
||||
|
||||
|
||||
# Update the existing models and add the missing models for OpenRouter
|
||||
def transform_openrouter_data(data):
|
||||
transformed = {}
|
||||
@ -58,41 +54,33 @@ def transform_openrouter_data(data):
|
||||
}
|
||||
|
||||
# Add 'max_output_tokens' as a field if it is not None
|
||||
if (
|
||||
"top_provider" in row
|
||||
and "max_completion_tokens" in row["top_provider"]
|
||||
and row["top_provider"]["max_completion_tokens"] is not None
|
||||
):
|
||||
obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"])
|
||||
if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None:
|
||||
obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"])
|
||||
|
||||
# Add the field 'output_cost_per_token'
|
||||
obj.update(
|
||||
{
|
||||
"output_cost_per_token": float(row["pricing"]["completion"]),
|
||||
}
|
||||
)
|
||||
obj.update({
|
||||
"output_cost_per_token": float(row["pricing"]["completion"]),
|
||||
})
|
||||
|
||||
# Add field 'input_cost_per_image' if it exists and is non-zero
|
||||
if (
|
||||
"pricing" in row
|
||||
and "image" in row["pricing"]
|
||||
and float(row["pricing"]["image"]) != 0.0
|
||||
):
|
||||
obj["input_cost_per_image"] = float(row["pricing"]["image"])
|
||||
if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0:
|
||||
obj['input_cost_per_image'] = float(row["pricing"]["image"])
|
||||
|
||||
# Add the fields 'litellm_provider' and 'mode'
|
||||
obj.update({"litellm_provider": "openrouter", "mode": "chat"})
|
||||
obj.update({
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
})
|
||||
|
||||
# Add the 'supports_vision' field if the modality is 'multimodal'
|
||||
if row.get("architecture", {}).get("modality") == "multimodal":
|
||||
obj["supports_vision"] = True
|
||||
if row.get('architecture', {}).get('modality') == 'multimodal':
|
||||
obj['supports_vision'] = True
|
||||
|
||||
# Use a composite key to store the transformed object
|
||||
transformed[f'openrouter/{row["id"]}'] = obj
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
# Update the existing models and add the missing models for Vercel AI Gateway
|
||||
def transform_vercel_ai_gateway_data(data):
|
||||
transformed = {}
|
||||
@ -101,30 +89,20 @@ def transform_vercel_ai_gateway_data(data):
|
||||
"max_tokens": row["context_window"],
|
||||
"input_cost_per_token": float(row["pricing"]["input"]),
|
||||
"output_cost_per_token": float(row["pricing"]["output"]),
|
||||
"max_output_tokens": row["max_tokens"],
|
||||
"max_input_tokens": row["context_window"],
|
||||
'max_output_tokens': row['max_tokens'],
|
||||
'max_input_tokens': row["context_window"],
|
||||
}
|
||||
|
||||
# Handle cache pricing if available
|
||||
if "pricing" in row:
|
||||
if (
|
||||
"input_cache_read" in row["pricing"]
|
||||
and row["pricing"]["input_cache_read"] is not None
|
||||
):
|
||||
obj["cache_read_input_token_cost"] = float(
|
||||
f"{float(row['pricing']['input_cache_read']):e}"
|
||||
)
|
||||
|
||||
if (
|
||||
"input_cache_write" in row["pricing"]
|
||||
and row["pricing"]["input_cache_write"] is not None
|
||||
):
|
||||
obj["cache_creation_input_token_cost"] = float(
|
||||
f"{float(row['pricing']['input_cache_write']):e}"
|
||||
)
|
||||
if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None:
|
||||
obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}")
|
||||
|
||||
if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None:
|
||||
obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}")
|
||||
|
||||
mode = "embedding" if "embedding" in row["id"].lower() else "chat"
|
||||
|
||||
|
||||
obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode})
|
||||
|
||||
transformed[f'vercel_ai_gateway/{row["id"]}'] = obj
|
||||
@ -148,31 +126,24 @@ def load_local_data(file_path):
|
||||
print("Error decoding JSON:", e)
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
local_file_path = (
|
||||
"model_prices_and_context_window.json" # Path to the local data file
|
||||
)
|
||||
openrouter_url = (
|
||||
"https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
||||
)
|
||||
vercel_ai_gateway_url = (
|
||||
"https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
||||
)
|
||||
local_file_path = "model_prices_and_context_window.json" # Path to the local data file
|
||||
openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data
|
||||
vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data
|
||||
|
||||
# Load local data from file
|
||||
local_data = load_local_data(local_file_path)
|
||||
|
||||
|
||||
# Fetch OpenRouter data
|
||||
openrouter_data = asyncio.run(fetch_data(openrouter_url))
|
||||
# Transform the fetched OpenRouter data
|
||||
openrouter_data = transform_openrouter_data(openrouter_data)
|
||||
|
||||
|
||||
# Fetch Vercel AI Gateway data
|
||||
vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url))
|
||||
# Transform the fetched Vercel AI Gateway data
|
||||
vercel_data = transform_vercel_ai_gateway_data(vercel_data)
|
||||
|
||||
|
||||
# Combine both datasets
|
||||
all_remote_data = {**openrouter_data, **vercel_data}
|
||||
|
||||
@ -183,7 +154,6 @@ def main():
|
||||
else:
|
||||
print("Failed to fetch model data from either local file or URL.")
|
||||
|
||||
|
||||
# Entry point of the script
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
512
.github/workflows/run_llm_translation_tests.py
vendored
512
.github/workflows/run_llm_translation_tests.py
vendored
@ -16,75 +16,64 @@ from pathlib import Path
|
||||
import json
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
class Colors:
    """ANSI escape sequences used to colour terminal output."""

    # Foreground colours (bright variants, SGR codes 91-96).
    GREEN = "\033[92m"
    RED = "\033[91m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    PURPLE = "\033[95m"
    CYAN = "\033[96m"

    # Attribute codes: RESET clears all styling, BOLD switches on bold text.
    RESET = "\033[0m"
    BOLD = "\033[1m"
|
||||
|
||||
def print_colored(message: str, color: str = Colors.RESET):
    """Print ``message`` prefixed with ``color`` and followed by the reset code."""
    painted = f"{color}{message}{Colors.RESET}"
    print(painted)
|
||||
|
||||
|
||||
def get_provider_from_test_file(test_file: str) -> str:
    """Map a test file name to the provider name shown in the report.

    Args:
        test_file: Test module name, e.g. ``"test_anthropic_completion"``.

    Returns:
        The provider display name when ``test_file`` contains one of the known
        ``test_<provider>`` markers; ``"Cross-Provider Tests (<test_file>)"``
        for the known cross-provider modules; otherwise ``"Other Tests"``.
    """
    provider_mapping = {
        "test_anthropic": "Anthropic",
        "test_azure": "Azure",
        "test_bedrock": "AWS Bedrock",
        "test_openai": "OpenAI",
        "test_vertex": "Google Vertex AI",
        "test_gemini": "Google Vertex AI",
        "test_cohere": "Cohere",
        "test_databricks": "Databricks",
        "test_groq": "Groq",
        "test_together": "Together AI",
        "test_mistral": "Mistral",
        "test_deepseek": "DeepSeek",
        "test_replicate": "Replicate",
        "test_huggingface": "HuggingFace",
        "test_fireworks": "Fireworks AI",
        "test_perplexity": "Perplexity",
        "test_cloudflare": "Cloudflare",
        "test_voyage": "Voyage AI",
        "test_xai": "xAI",
        "test_nvidia": "NVIDIA",
        "test_watsonx": "IBM watsonx",
        "test_azure_ai": "Azure AI",
        "test_snowflake": "Snowflake",
        "test_infinity": "Infinity",
        "test_jina": "Jina AI",
        "test_deepgram": "Deepgram",
        "test_clarifai": "Clarifai",
        "test_triton": "Triton",
    }

    # Matching is by substring, so a marker that is a prefix of another
    # ("test_azure" vs "test_azure_ai") would shadow the longer one if tried
    # first. Trying the longest markers first lets the most specific one win;
    # sorted() is stable, so equally long markers keep their declared order.
    for key in sorted(provider_mapping, key=len, reverse=True):
        if key in test_file:
            return provider_mapping[key]

    # For cross-provider test files
    cross_provider_markers = (
        "test_optional_params",
        "test_prompt_factory",
        "test_router",
        "test_text_completion",
    )
    if any(name in test_file for name in cross_provider_markers):
        return f"Cross-Provider Tests ({test_file})"

    return "Other Tests"
|
||||
|
||||
def format_duration(seconds: float) -> str:
|
||||
"""Format duration in human-readable format"""
|
||||
@ -100,355 +89,290 @@ def format_duration(seconds: float) -> str:
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0 when ``whole`` is not positive."""
    return (part / whole) * 100 if whole > 0 else 0


def _provider_total(stats):
    """Return the number of test cases recorded for one provider across all outcomes."""
    return stats["passed"] + stats["failed"] + stats["errors"] + stats["skipped"]


def _write_test_details(f, summary, tests, show_time=True, show_message=False):
    """Write one collapsible ``<details>`` list of tests to ``f``; writes nothing when ``tests`` is empty."""
    if not tests:
        return
    f.write(f"<details>\n<summary>{summary}</summary>\n\n")
    for test in tests:
        if show_time:
            f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n")
        else:
            f.write(f"- `{test['name']}`\n")
        if show_message and test["message"]:
            # Truncate long error messages
            message = test["message"]
            if len(message) > 200:
                message = message[:200] + "..."
            f.write(f" > {message}\n")
    f.write("\n</details>\n\n")


def generate_markdown_report(
    junit_xml_path: str,
    output_path: str,
    tag: Optional[str] = None,
    commit: Optional[str] = None,
):
    """Generate a markdown report from a JUnit XML result file.

    Args:
        junit_xml_path: Path of the JUnit XML file to read.
        output_path: Path of the markdown file to write (overwritten).
        tag: Value shown in the "Tag" row; ``N/A`` is shown when falsy.
        commit: Value shown in the "Commit" row; ``N/A`` is shown when falsy.

    Raises:
        Exception: Any error raised while parsing or writing is reported via
            ``print_colored`` and then re-raised unchanged.
    """
    # Function-scope import: only ``datetime`` is known to be in scope at module level.
    from datetime import timezone

    # Child element that marks a non-passing test case -> (status label, stats key).
    # The order is the precedence used when a test case carries several of them.
    outcome_elements = (
        ("failure", "FAILED", "failed"),
        ("error", "ERROR", "errors"),
        ("skipped", "SKIPPED", "skipped"),
    )

    try:
        root = ET.parse(junit_xml_path).getroot()

        # Handle both testsuite and testsuites root
        suites = root.findall("testsuite") if root.tag == "testsuites" else [root]

        # Overall statistics (taken from the suite attributes)
        total_tests = 0
        total_failures = 0
        total_errors = 0
        total_skipped = 0
        total_time = 0.0

        # Provider breakdown (counted per test case)
        provider_stats = defaultdict(
            lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0}
        )
        provider_tests = defaultdict(list)

        for suite in suites:
            total_tests += int(suite.get("tests", 0))
            total_failures += int(suite.get("failures", 0))
            total_errors += int(suite.get("errors", 0))
            total_skipped += int(suite.get("skipped", 0))
            total_time += float(suite.get("time", 0))

            for testcase in suite.findall("testcase"):
                classname = testcase.get("classname", "")
                test_name = testcase.get("name", "")
                test_time = float(testcase.get("time", 0))

                # Extract test file name from classname: the component just
                # before the last dot, or "unknown" when there is no dot.
                parts = classname.split(".")
                test_file = parts[-2] if len(parts) > 1 else "unknown"

                provider = get_provider_from_test_file(test_file)
                provider_stats[provider]["time"] += test_time

                # Check test status; a test case without any outcome element passed.
                status, stat_key, message = "PASSED", "passed", ""
                for xml_tag, label, key in outcome_elements:
                    element = testcase.find(xml_tag)
                    if element is not None:
                        status, stat_key = label, key
                        message = element.get("message", "")
                        break

                provider_stats[provider][stat_key] += 1
                provider_tests[provider].append(
                    {
                        "name": test_name,
                        "status": status,
                        "time": test_time,
                        "message": message,
                    }
                )

        passed = total_tests - total_failures - total_errors - total_skipped

        # Same text as the former utcnow() call, without the deprecated naive-UTC API.
        generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

        # Generate the markdown report. The encoding is explicit because the
        # summary box contains non-ASCII box-drawing characters, which a
        # non-UTF-8 locale default encoding cannot write.
        with open(output_path, "w", encoding="utf-8") as f:
            # Header
            f.write("# LLM Translation Test Results\n\n")

            # Metadata table
            f.write("## Test Run Information\n\n")
            f.write("| Field | Value |\n")
            f.write("|-------|-------|\n")
            f.write(f"| **Tag** | `{tag or 'N/A'}` |\n")
            f.write(f"| **Date** | {generated_at} |\n")
            f.write(f"| **Commit** | `{commit or 'N/A'}` |\n")
            f.write(f"| **Duration** | {format_duration(total_time)} |\n")
            f.write("\n")

            # Overall statistics with visual elements
            f.write("## Overall Statistics\n\n")

            # Summary box
            f.write("```\n")
            f.write(f"Total Tests: {total_tests}\n")
            f.write(f"├── Passed: {passed:>4} ({_percent(passed, total_tests):.1f}%)\n")
            f.write(f"├── Failed: {total_failures:>4} ({_percent(total_failures, total_tests):.1f}%)\n")
            f.write(f"├── Errors: {total_errors:>4} ({_percent(total_errors, total_tests):.1f}%)\n")
            f.write(f"└── Skipped: {total_skipped:>4} ({_percent(total_skipped, total_tests):.1f}%)\n")
            f.write("```\n\n")

            # Provider summary table. The separator row has no trailing
            # newline because every data row starts with one.
            f.write("## Results by Provider\n\n")
            f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n")
            f.write("|----------|-------|------|------|-------|------|-----------|----------|")

            # Sort providers: specific providers first, then cross-provider tests
            sorted_providers = []
            cross_provider = []
            for p in sorted(provider_stats.keys()):
                if "Cross-Provider" in p or p == "Other Tests":
                    cross_provider.append(p)
                else:
                    sorted_providers.append(p)

            for provider in sorted_providers + cross_provider:
                stats = provider_stats[provider]
                total = _provider_total(stats)
                pass_rate = _percent(stats["passed"], total)

                f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ")
                f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ")
                f.write(f"{format_duration(stats['time'])} |")

            # Detailed test results by provider
            f.write("\n\n## Detailed Test Results\n\n")

            for provider in sorted_providers:
                if not provider_tests[provider]:
                    continue
                stats = provider_stats[provider]
                total = _provider_total(stats)

                f.write(f"### {provider}\n\n")
                f.write(f"**Summary:** {stats['passed']}/{total} passed ")
                f.write(f"({_percent(stats['passed'], total):.1f}%) ")
                f.write(f"in {format_duration(stats['time'])}\n\n")

                # Group tests by status
                tests_by_status = defaultdict(list)
                for test in provider_tests[provider]:
                    tests_by_status[test["status"]].append(test)

                # Failed tests first (with their messages), then errors,
                # passed and skipped tests; empty groups are omitted.
                _write_test_details(f, "Failed Tests", tests_by_status["FAILED"], show_message=True)
                _write_test_details(f, "Error Tests", tests_by_status["ERROR"])
                _write_test_details(f, "Passed Tests", tests_by_status["PASSED"])
                _write_test_details(f, "Skipped Tests", tests_by_status["SKIPPED"], show_time=False)

            # Cross-provider tests in a separate section
            if cross_provider:
                f.write("### Cross-Provider Tests\n\n")
                for provider in cross_provider:
                    if not provider_tests[provider]:
                        continue
                    stats = provider_stats[provider]
                    total = _provider_total(stats)

                    f.write(f"#### {provider}\n\n")
                    f.write(f"**Summary:** {stats['passed']}/{total} passed ")
                    f.write(f"({_percent(stats['passed'], total):.1f}%)\n\n")

                    # For cross-provider tests, just show counts
                    f.write(f"- Passed: {stats['passed']}\n")
                    if stats["failed"] > 0:
                        f.write(f"- Failed: {stats['failed']}\n")
                    if stats["errors"] > 0:
                        f.write(f"- Errors: {stats['errors']}\n")
                    if stats["skipped"] > 0:
                        f.write(f"- Skipped: {stats['skipped']}\n")
                    f.write("\n")

        print_colored(f"Report generated: {output_path}", Colors.GREEN)

    except Exception as e:
        # Report the failure on the console, then let the caller see it too.
        print_colored(f"Error generating report: {e}", Colors.RED)
        raise
|
||||
|
||||
|
||||
def run_tests(
|
||||
test_path: str = "tests/llm_translation/",
|
||||
junit_xml: str = "test-results/junit.xml",
|
||||
report_path: str = "test-results/llm_translation_report.md",
|
||||
tag: str = None,
|
||||
commit: str = None,
|
||||
) -> int:
|
||||
def run_tests(test_path: str = "tests/llm_translation/",
|
||||
junit_xml: str = "test-results/junit.xml",
|
||||
report_path: str = "test-results/llm_translation_report.md",
|
||||
tag: str = None,
|
||||
commit: str = None) -> int:
|
||||
"""Run the LLM translation tests and generate report"""
|
||||
|
||||
|
||||
# Create test results directory
|
||||
os.makedirs(os.path.dirname(junit_xml), exist_ok=True)
|
||||
|
||||
|
||||
print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE)
|
||||
print_colored(f"Test directory: {test_path}", Colors.CYAN)
|
||||
print_colored(f"Output: {junit_xml}", Colors.CYAN)
|
||||
print()
|
||||
|
||||
|
||||
# Run pytest
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--no-sync",
|
||||
"pytest",
|
||||
test_path,
|
||||
"uv", "run", "--no-sync", "pytest", test_path,
|
||||
f"--junitxml={junit_xml}",
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"--maxfail=500",
|
||||
"-n",
|
||||
"auto",
|
||||
"-n", "auto"
|
||||
]
|
||||
|
||||
|
||||
# Add timeout if pytest-timeout is installed
|
||||
try:
|
||||
subprocess.run(
|
||||
["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"],
|
||||
capture_output=True, check=True)
|
||||
cmd.extend(["--timeout=300"])
|
||||
except:
|
||||
print_colored(
|
||||
"Warning: pytest-timeout not installed, skipping timeout option",
|
||||
Colors.YELLOW,
|
||||
)
|
||||
|
||||
print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW)
|
||||
|
||||
print_colored("Running pytest with command:", Colors.YELLOW)
|
||||
print(f" {' '.join(cmd)}")
|
||||
print()
|
||||
|
||||
|
||||
# Run the tests
|
||||
result = subprocess.run(cmd, capture_output=False)
|
||||
|
||||
|
||||
# Generate the report regardless of test outcome
|
||||
if os.path.exists(junit_xml):
|
||||
print()
|
||||
print_colored("Generating test report...", Colors.BLUE)
|
||||
generate_markdown_report(junit_xml, report_path, tag, commit)
|
||||
|
||||
|
||||
# Print summary to console
|
||||
print()
|
||||
print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE)
|
||||
|
||||
|
||||
# Parse XML for quick summary
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
if root.tag == "testsuites":
|
||||
suites = root.findall("testsuite")
|
||||
|
||||
if root.tag == 'testsuites':
|
||||
suites = root.findall('testsuite')
|
||||
else:
|
||||
suites = [root]
|
||||
|
||||
total = sum(int(s.get("tests", 0)) for s in suites)
|
||||
failures = sum(int(s.get("failures", 0)) for s in suites)
|
||||
errors = sum(int(s.get("errors", 0)) for s in suites)
|
||||
skipped = sum(int(s.get("skipped", 0)) for s in suites)
|
||||
|
||||
total = sum(int(s.get('tests', 0)) for s in suites)
|
||||
failures = sum(int(s.get('failures', 0)) for s in suites)
|
||||
errors = sum(int(s.get('errors', 0)) for s in suites)
|
||||
skipped = sum(int(s.get('skipped', 0)) for s in suites)
|
||||
passed = total - failures - errors - skipped
|
||||
|
||||
|
||||
print(f" Total: {total}")
|
||||
print_colored(f" Passed: {passed}", Colors.GREEN)
|
||||
if failures > 0:
|
||||
@ -457,75 +381,59 @@ def run_tests(
|
||||
print_colored(f" Errors: {errors}", Colors.RED)
|
||||
if skipped > 0:
|
||||
print_colored(f" Skipped: {skipped}", Colors.YELLOW)
|
||||
|
||||
|
||||
if total > 0:
|
||||
pass_rate = (passed / total) * 100
|
||||
color = (
|
||||
Colors.GREEN
|
||||
if pass_rate >= 80
|
||||
else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
||||
)
|
||||
color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED
|
||||
print_colored(f" Pass Rate: {pass_rate:.1f}%", color)
|
||||
else:
|
||||
print_colored("No test results found!", Colors.RED)
|
||||
|
||||
|
||||
print()
|
||||
print_colored("Test run complete!", Colors.BOLD + Colors.GREEN)
|
||||
|
||||
|
||||
return result.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run LLM Translation Tests")
    parser.add_argument(
        "--test-path", default="tests/llm_translation/", help="Path to test directory"
    )
    parser.add_argument(
        "--junit-xml",
        default="test-results/junit.xml",
        help="Path for JUnit XML output",
    )
    parser.add_argument(
        "--report",
        default="test-results/llm_translation_report.md",
        help="Path for markdown report",
    )
    parser.add_argument("--tag", help="Git tag or version")
    parser.add_argument("--commit", help="Git commit SHA")

    args = parser.parse_args()

    def _git_output(*git_args):
        """Return the stripped stdout of ``git <git_args>``, or None when the call fails.

        Looking up git metadata is best effort: a missing git executable or a
        non-zero exit status simply yields None. Only OS and subprocess errors
        are swallowed, so Ctrl-C and SystemExit still propagate.
        """
        try:
            completed = subprocess.run(
                ["git", *git_args], capture_output=True, text=True
            )
        except (OSError, subprocess.SubprocessError):
            return None
        if completed.returncode != 0:
            return None
        return completed.stdout.strip()

    # Get git info if not provided; keep the given value when git has no answer.
    if not args.commit:
        args.commit = _git_output("rev-parse", "HEAD") or args.commit

    if not args.tag:
        args.tag = _git_output("describe", "--tags", "--abbrev=0") or args.tag

    exit_code = run_tests(
        test_path=args.test_path,
        junit_xml=args.junit_xml,
        report_path=args.report,
        tag=args.tag,
        commit=args.commit,
    )

    # Propagate the pytest exit status to the calling shell / CI step.
    sys.exit(exit_code)
|
||||
|
||||
@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
import testing.postgresql
|
||||
|
||||
|
||||
DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE)
|
||||
DEFAULT_BASE_BRANCH = "litellm_internal_staging"
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from tabulate import tabulate
|
||||
from termcolor import colored
|
||||
import os
|
||||
|
||||
|
||||
# Define the list of models to benchmark
|
||||
# select any LLM listed here: https://docs.litellm.ai/docs/providers
|
||||
models = ["gpt-3.5-turbo", "claude-2"]
|
||||
|
||||
@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Braintrust Prompt Wrapper",
|
||||
description="Wrapper server for Braintrust prompts to work with LiteLLM",
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway
|
||||
|
||||
This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy.
|
||||
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
|
||||
LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI,
|
||||
and Azure realtime APIs without changing your agent code.
|
||||
"""
|
||||
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
"""
|
||||
LiteLLM Migration Script!
|
||||
|
||||
Takes a config.yaml and calls /model/new
|
||||
Takes a config.yaml and calls /model/new
|
||||
|
||||
Inputs:
|
||||
- File path to config.yaml
|
||||
- Proxy base url to your hosted proxy
|
||||
|
||||
Step 1: Reads your config.yaml
|
||||
Step 2: reads `model_list` and loops through all models
|
||||
Step 2: reads `model_list` and loops through all models
|
||||
Step 3: calls `<proxy-base-url>/model/new` for each model
|
||||
"""
|
||||
|
||||
|
||||
@ -518,7 +518,8 @@ if __name__ == "__main__":
|
||||
print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply")
|
||||
print("=" * 80)
|
||||
print("\nExample curl command:")
|
||||
print(f"""
|
||||
print(
|
||||
f"""
|
||||
curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
||||
-H "Authorization: Bearer {bearer_token}" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\
|
||||
}}
|
||||
]
|
||||
}}'
|
||||
""")
|
||||
"""
|
||||
)
|
||||
print("=" * 80)
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915
|
||||
print("LiteLLM_VerificationTokenView Exists!") # noqa
|
||||
except Exception:
|
||||
# If an error occurs, the view does not exist, so create it
|
||||
await db.execute_raw("""
|
||||
await db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
print("LiteLLM_VerificationTokenView Created!") # noqa
|
||||
|
||||
|
||||
@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams
|
||||
class EnterpriseCallbackControls:
|
||||
@staticmethod
|
||||
def is_callback_disabled_dynamically(
|
||||
callback: litellm.CALLBACK_TYPES,
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||
|
||||
Args:
|
||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||
litellm_params: Parameters containing proxy server request info
|
||||
|
||||
Returns:
|
||||
bool: True if the callback should be disabled, False otherwise
|
||||
"""
|
||||
from litellm.litellm_core_utils.custom_logger_registry import (
|
||||
CustomLoggerRegistry,
|
||||
)
|
||||
|
||||
try:
|
||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
|
||||
litellm_params, standard_callback_dynamic_params
|
||||
callback: litellm.CALLBACK_TYPES,
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||
|
||||
Args:
|
||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||
litellm_params: Parameters containing proxy server request info
|
||||
|
||||
Returns:
|
||||
bool: True if the callback should be disabled, False otherwise
|
||||
"""
|
||||
from litellm.litellm_core_utils.custom_logger_registry import (
|
||||
CustomLoggerRegistry,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
|
||||
)
|
||||
if disabled_callbacks is not None:
|
||||
#########################################################
|
||||
# premium user check
|
||||
#########################################################
|
||||
if (
|
||||
not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling()
|
||||
):
|
||||
return False
|
||||
#########################################################
|
||||
if isinstance(callback, str):
|
||||
if callback.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(
|
||||
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
||||
)
|
||||
return True
|
||||
elif isinstance(callback, CustomLogger):
|
||||
# get the string name of the callback
|
||||
callback_str = (
|
||||
CustomLoggerRegistry.get_callback_str_from_class_type(
|
||||
callback.__class__
|
||||
)
|
||||
)
|
||||
if (
|
||||
callback_str is not None
|
||||
and callback_str.lower() in disabled_callbacks
|
||||
):
|
||||
verbose_logger.debug(
|
||||
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
|
||||
return False
|
||||
|
||||
try:
|
||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
|
||||
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
|
||||
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
|
||||
if disabled_callbacks is not None:
|
||||
#########################################################
|
||||
# premium user check
|
||||
#########################################################
|
||||
if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
|
||||
return False
|
||||
#########################################################
|
||||
if isinstance(callback, str):
|
||||
if callback.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||
return True
|
||||
elif isinstance(callback, CustomLogger):
|
||||
# get the string name of the callback
|
||||
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
|
||||
if callback_str is not None and callback_str.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error checking disabled callbacks header: {str(e)}"
|
||||
)
|
||||
return False
|
||||
@staticmethod
|
||||
def get_disabled_callbacks(
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
) -> Optional[List[str]]:
|
||||
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
|
||||
"""
|
||||
Get the disabled callbacks from the standard callback dynamic params.
|
||||
"""
|
||||
@ -92,24 +71,18 @@ class EnterpriseCallbackControls:
|
||||
request_headers = get_proxy_server_request_headers(litellm_params)
|
||||
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
||||
if disabled_callbacks is not None:
|
||||
disabled_callbacks = set(
|
||||
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
|
||||
)
|
||||
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
|
||||
return list(disabled_callbacks)
|
||||
|
||||
|
||||
#########################################################
|
||||
# check if disabled via request body
|
||||
#########################################################
|
||||
if (
|
||||
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||
is not None
|
||||
):
|
||||
return standard_callback_dynamic_params.get(
|
||||
"litellm_disabled_callbacks", None
|
||||
)
|
||||
|
||||
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
|
||||
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _should_allow_dynamic_callback_disabling():
|
||||
import litellm
|
||||
@ -117,14 +90,10 @@ class EnterpriseCallbackControls:
|
||||
|
||||
# Check if admin has disabled this feature
|
||||
if litellm.allow_dynamic_callback_disabling is not True:
|
||||
verbose_logger.debug(
|
||||
"Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling"
|
||||
)
|
||||
verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
|
||||
return False
|
||||
|
||||
|
||||
if premium_user:
|
||||
return True
|
||||
verbose_logger.warning(
|
||||
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
return False
|
||||
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
|
||||
return False
|
||||
@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger):
|
||||
)
|
||||
|
||||
# Calculate percentage and alert threshold
|
||||
percentage = (
|
||||
threshold_pct
|
||||
if threshold_pct is not None
|
||||
else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100)
|
||||
percentage = threshold_pct if threshold_pct is not None else int(
|
||||
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100
|
||||
)
|
||||
threshold_fraction = percentage / 100.0
|
||||
alert_threshold_str = (
|
||||
@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger):
|
||||
continue
|
||||
|
||||
_id = user_info.token or user_info.user_id or "default_id"
|
||||
_cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
||||
_cache_key = (
|
||||
f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}"
|
||||
)
|
||||
|
||||
result = await _cache.async_get_cache(key=_cache_key)
|
||||
if result is not None:
|
||||
@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger):
|
||||
continue
|
||||
recipient_emails = list(set(emails))
|
||||
|
||||
event_message = (
|
||||
f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
||||
)
|
||||
event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached"
|
||||
webhook_event = WebhookEvent(
|
||||
event="max_budget_alert",
|
||||
event_message=event_message,
|
||||
|
||||
@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||
|
||||
from .base_email import BaseEmailLogger
|
||||
|
||||
|
||||
SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send"
|
||||
|
||||
|
||||
@ -78,4 +79,4 @@ class SendGridEmailLogger(BaseEmailLogger):
|
||||
verbose_logger.debug(
|
||||
f"SendGrid response status={response.status_code}, body={response.text}"
|
||||
)
|
||||
return
|
||||
return
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
This is the litellm SMTP email integration
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Enterprise specific logging utils
|
||||
"""
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata
|
||||
|
||||
|
||||
|
||||
@ -153,11 +153,11 @@ async def get_audit_logs(
|
||||
|
||||
# Return paginated response
|
||||
return PaginatedAuditLogResponse(
|
||||
audit_logs=(
|
||||
[AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs]
|
||||
if audit_logs
|
||||
else []
|
||||
),
|
||||
audit_logs=[
|
||||
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
|
||||
]
|
||||
if audit_logs
|
||||
else [],
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
||||
@ -7,4 +7,4 @@ including custom SSO handlers and advanced authentication features.
|
||||
|
||||
from .custom_sso_handler import EnterpriseCustomSSOHandler
|
||||
|
||||
__all__ = ["EnterpriseCustomSSOHandler"]
|
||||
__all__ = ["EnterpriseCustomSSOHandler"]
|
||||
@ -53,9 +53,7 @@ class CheckBatchCost:
|
||||
"user_api_key_alias": getattr(user_row, "user_alias", None),
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
|
||||
return {}
|
||||
|
||||
async def _cleanup_stale_managed_objects(self) -> None:
|
||||
@ -64,22 +62,11 @@ class CheckBatchCost:
|
||||
in non-terminal states as 'stale_expired'. These will never complete and
|
||||
should not be polled.
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
||||
)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||
result = await self.prisma_client.db.litellm_managedobjecttable.update_many(
|
||||
where={
|
||||
"file_purpose": "batch",
|
||||
"status": {
|
||||
"not_in": [
|
||||
"completed",
|
||||
"complete",
|
||||
"failed",
|
||||
"expired",
|
||||
"cancelled",
|
||||
"stale_expired",
|
||||
]
|
||||
},
|
||||
"status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]},
|
||||
"created_at": {"lt": cutoff},
|
||||
},
|
||||
data={"status": "stale_expired"},
|
||||
@ -133,12 +120,9 @@ class CheckBatchCost:
|
||||
|
||||
try:
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
prom_logger = PrometheusLogger.get_instance()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"CheckBatchCost: could not get Prometheus logger: {e}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}")
|
||||
prom_logger = None
|
||||
|
||||
processed_models: List[Tuple[Optional[str], Optional[str]]] = []
|
||||
@ -177,11 +161,7 @@ class CheckBatchCost:
|
||||
order={"created_at": "asc"},
|
||||
)
|
||||
except Exception as query_err:
|
||||
if (
|
||||
"batch_processed" not in str(query_err).lower()
|
||||
and "unknown column" not in str(query_err).lower()
|
||||
and "does not exist" not in str(query_err).lower()
|
||||
):
|
||||
if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower():
|
||||
raise
|
||||
# Permanent schema gap — cache the result so future cycles skip straight to fallback
|
||||
self._has_batch_processed_column = False
|
||||
@ -236,13 +216,14 @@ class CheckBatchCost:
|
||||
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
|
||||
)
|
||||
if prom_logger:
|
||||
prom_logger.record_check_batch_cost_error(
|
||||
"provider_retrieval_error"
|
||||
)
|
||||
prom_logger.record_check_batch_cost_error("provider_retrieval_error")
|
||||
continue
|
||||
|
||||
## RETRIEVE THE BATCH JOB OUTPUT FILE
|
||||
if response.status == "completed" and response.output_file_id is not None:
|
||||
if (
|
||||
response.status == "completed"
|
||||
and response.output_file_id is not None
|
||||
):
|
||||
verbose_proxy_logger.info(
|
||||
f"Batch ID: {batch_id} is complete, tracking cost and usage"
|
||||
)
|
||||
@ -269,25 +250,20 @@ class CheckBatchCost:
|
||||
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
|
||||
if decoded:
|
||||
try:
|
||||
raw_output_file_id = decoded.split("llm_output_file_id,")[
|
||||
1
|
||||
].split(";")[0]
|
||||
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
|
||||
except (IndexError, AttributeError):
|
||||
pass
|
||||
|
||||
credentials = (
|
||||
self.llm_router.get_deployment_credentials_with_provider(model_id)
|
||||
or {}
|
||||
)
|
||||
credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
|
||||
_file_content = await afile_content(
|
||||
file_id=raw_output_file_id,
|
||||
**credentials,
|
||||
)
|
||||
|
||||
# Access content - handle both direct attribute and method call
|
||||
if hasattr(_file_content, "content"):
|
||||
if hasattr(_file_content, 'content'):
|
||||
content_bytes = _file_content.content # type: ignore[union-attr]
|
||||
elif hasattr(_file_content, "read"):
|
||||
elif hasattr(_file_content, 'read'):
|
||||
content_bytes = await _file_content.read() # type: ignore[misc]
|
||||
else:
|
||||
content_bytes = _file_content # type: ignore[assignment]
|
||||
@ -314,9 +290,7 @@ class CheckBatchCost:
|
||||
f"Skipping job {unified_object_id} because it is not a valid deployment info"
|
||||
)
|
||||
if prom_logger:
|
||||
prom_logger.record_check_batch_cost_error(
|
||||
"deployment_not_found"
|
||||
)
|
||||
prom_logger.record_check_batch_cost_error("deployment_not_found")
|
||||
continue
|
||||
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
|
||||
litellm_model_name = deployment_info.litellm_params.model
|
||||
@ -328,11 +302,7 @@ class CheckBatchCost:
|
||||
|
||||
# Pass deployment model_info so custom batch pricing
|
||||
# (input_cost_per_token_batches etc.) is used for cost calc
|
||||
deployment_model_info = (
|
||||
deployment_info.model_info.model_dump()
|
||||
if deployment_info.model_info
|
||||
else {}
|
||||
)
|
||||
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
|
||||
batch_cost, batch_usage, batch_models = (
|
||||
await calculate_batch_cost_and_usage(
|
||||
file_content_dictionary=file_content_as_dict,
|
||||
@ -379,9 +349,7 @@ class CheckBatchCost:
|
||||
|
||||
# Record batch duration (completed_at - created_at)
|
||||
if prom_logger and response.completed_at and response.created_at:
|
||||
duration_seconds = float(
|
||||
response.completed_at - response.created_at
|
||||
)
|
||||
duration_seconds = float(response.completed_at - response.created_at)
|
||||
if duration_seconds >= 0:
|
||||
prom_logger.record_managed_batch_duration(
|
||||
duration_seconds=duration_seconds,
|
||||
@ -390,9 +358,7 @@ class CheckBatchCost:
|
||||
)
|
||||
|
||||
# Track this job for the final metrics summary
|
||||
processed_models.append(
|
||||
(model_name, str(llm_provider) if llm_provider else None)
|
||||
)
|
||||
processed_models.append((model_name, str(llm_provider) if llm_provider else None))
|
||||
|
||||
# mark the job as complete
|
||||
try:
|
||||
|
||||
@ -33,7 +33,9 @@ class CheckResponsesCost:
|
||||
self.prisma_client: PrismaClient = prisma_client
|
||||
self.llm_router: Router = llm_router
|
||||
|
||||
async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int:
|
||||
async def _expire_stale_rows(
|
||||
self, cutoff: datetime, batch_size: int
|
||||
) -> int:
|
||||
"""Execute the bounded UPDATE that marks stale rows as 'stale_expired'.
|
||||
|
||||
Isolated so it can be swapped / mocked in tests without touching the
|
||||
@ -72,9 +74,7 @@ class CheckResponsesCost:
|
||||
rows per invocation to avoid overwhelming the DB when there is a large
|
||||
backlog.
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(
|
||||
days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS
|
||||
)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS)
|
||||
result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE)
|
||||
if result > 0:
|
||||
verbose_proxy_logger.warning(
|
||||
@ -105,7 +105,7 @@ class CheckResponsesCost:
|
||||
take=MAX_OBJECTS_PER_POLL_CYCLE,
|
||||
order={"created_at": "asc"},
|
||||
)
|
||||
|
||||
|
||||
verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check")
|
||||
completed_jobs = []
|
||||
|
||||
@ -120,33 +120,29 @@ class CheckResponsesCost:
|
||||
# Get the stored response object to extract model information
|
||||
stored_response = job.file_object
|
||||
model_name = stored_response.get("model", None)
|
||||
|
||||
|
||||
# Decrypt the response ID
|
||||
responses_id_security, _, _ = (
|
||||
ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
||||
)
|
||||
|
||||
responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id)
|
||||
|
||||
# Prepare metadata with model information for cost tracking
|
||||
litellm_metadata = {
|
||||
"user_api_key_user_id": job.created_by or "default-user-id",
|
||||
}
|
||||
|
||||
|
||||
# Add model information if available
|
||||
if model_name:
|
||||
litellm_metadata["model"] = model_name
|
||||
litellm_metadata["model_group"] = (
|
||||
model_name # Use same value for model_group
|
||||
)
|
||||
|
||||
litellm_metadata["model_group"] = model_name # Use same value for model_group
|
||||
|
||||
response = await litellm.aget_responses(
|
||||
response_id=responses_id_security,
|
||||
litellm_metadata=litellm_metadata,
|
||||
)
|
||||
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Response {unified_object_id} status: {response.status}, model: {model_name}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Skipping job {unified_object_id} due to error: {e}"
|
||||
@ -159,7 +155,7 @@ class CheckResponsesCost:
|
||||
f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses."
|
||||
)
|
||||
completed_jobs.append(job)
|
||||
|
||||
|
||||
elif response.status in ["failed", "cancelled"]:
|
||||
verbose_proxy_logger.info(
|
||||
f"Response {unified_object_id} has status {response.status}, marking as complete"
|
||||
@ -175,3 +171,4 @@ class CheckResponsesCost:
|
||||
verbose_proxy_logger.info(
|
||||
f"Marked {len(completed_jobs)} response jobs as completed"
|
||||
)
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
):
|
||||
"""
|
||||
Managed vector stores with target_model_names support.
|
||||
|
||||
|
||||
This class provides functionality to:
|
||||
- Create vector stores across multiple models
|
||||
- Retrieve vector stores by unified ID
|
||||
@ -77,14 +77,14 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> str:
|
||||
"""
|
||||
Generate the format string for the unified vector store ID.
|
||||
|
||||
|
||||
Format:
|
||||
litellm_proxy:vector_store;unified_id,<uuid>;target_model_names,<models>;resource_id,<vs_id>;model_id,<model_id>
|
||||
"""
|
||||
# VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary
|
||||
# Extract provider resource ID from the response
|
||||
provider_resource_id = resource_object.get("id", "")
|
||||
|
||||
|
||||
# Model ID is stored in hidden params if the response object supports it
|
||||
# For TypedDict responses, we need to check if _hidden_params was added
|
||||
hidden_params: Dict[str, Any] = {}
|
||||
@ -109,18 +109,20 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> VectorStoreCreateResponse:
|
||||
"""
|
||||
Create a vector store for a specific model.
|
||||
|
||||
|
||||
Args:
|
||||
llm_router: LiteLLM router instance
|
||||
model: Model name to create vector store for
|
||||
request_data: Request data for vector store creation
|
||||
litellm_parent_otel_span: OpenTelemetry span for tracing
|
||||
|
||||
|
||||
Returns:
|
||||
VectorStoreCreateResponse from the provider
|
||||
"""
|
||||
# Use the router to create the vector store
|
||||
response = await llm_router.avector_store_create(model=model, **request_data)
|
||||
response = await llm_router.avector_store_create(
|
||||
model=model, **request_data
|
||||
)
|
||||
return response
|
||||
|
||||
# ============================================================================
|
||||
@ -137,14 +139,14 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> VectorStoreCreateResponse:
|
||||
"""
|
||||
Create a vector store across multiple models.
|
||||
|
||||
|
||||
Args:
|
||||
create_request: Vector store creation request parameters
|
||||
llm_router: LiteLLM router instance
|
||||
target_model_names_list: List of target model names
|
||||
litellm_parent_otel_span: OpenTelemetry span for tracing
|
||||
user_api_key_dict: User API key authentication details
|
||||
|
||||
|
||||
Returns:
|
||||
VectorStoreCreateResponse with unified ID
|
||||
"""
|
||||
@ -194,7 +196,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
# VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID
|
||||
response = responses[0].copy()
|
||||
response["id"] = unified_id
|
||||
|
||||
|
||||
verbose_logger.info(
|
||||
f"Successfully created managed vector store with unified ID: {unified_id}"
|
||||
)
|
||||
@ -210,13 +212,13 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List vector stores created by a user.
|
||||
|
||||
|
||||
Args:
|
||||
user_api_key_dict: User API key authentication details
|
||||
limit: Maximum number of vector stores to return
|
||||
after: Cursor for pagination
|
||||
order: Sort order ('asc' or 'desc')
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with list of vector stores and pagination info
|
||||
"""
|
||||
@ -236,23 +238,23 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has access to a vector store.
|
||||
|
||||
|
||||
Args:
|
||||
vector_store_id: The unified vector store ID
|
||||
user_api_key_dict: User API key authentication details
|
||||
|
||||
|
||||
Returns:
|
||||
True if user has access, False otherwise
|
||||
"""
|
||||
is_unified_id = is_base64_encoded_unified_id(vector_store_id)
|
||||
|
||||
|
||||
if is_unified_id:
|
||||
# Check access for managed vector store
|
||||
return await self.can_user_access_unified_resource_id(
|
||||
vector_store_id,
|
||||
user_api_key_dict,
|
||||
)
|
||||
|
||||
|
||||
# Not a managed vector store, allow access
|
||||
return True
|
||||
|
||||
@ -261,22 +263,24 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has access to a managed vector store in request data.
|
||||
|
||||
|
||||
Args:
|
||||
data: Request data containing vector_store_id
|
||||
user_api_key_dict: User API key authentication details
|
||||
|
||||
|
||||
Returns:
|
||||
True if this is a managed vector store and user has access
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have access
|
||||
"""
|
||||
vector_store_id = cast(Optional[str], data.get("vector_store_id"))
|
||||
is_unified_id = (
|
||||
is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False
|
||||
is_base64_encoded_unified_id(vector_store_id)
|
||||
if vector_store_id
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
if is_unified_id and vector_store_id:
|
||||
if await self.can_user_access_unified_resource_id(
|
||||
vector_store_id, user_api_key_dict
|
||||
@ -287,7 +291,7 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
status_code=403,
|
||||
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
||||
)
|
||||
|
||||
|
||||
return False
|
||||
|
||||
# ============================================================================
|
||||
@ -303,18 +307,18 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> Union[Exception, str, Dict, None]:
|
||||
"""
|
||||
Pre-call hook to handle vector store operations.
|
||||
|
||||
|
||||
This hook intercepts vector store requests and:
|
||||
- Validates access for managed vector stores
|
||||
- Transforms unified IDs to provider-specific IDs
|
||||
- Adds model routing information
|
||||
|
||||
|
||||
Args:
|
||||
user_api_key_dict: User API key authentication details
|
||||
cache: Cache instance
|
||||
data: Request data
|
||||
call_type: Type of call being made
|
||||
|
||||
|
||||
Returns:
|
||||
Modified request data or None
|
||||
"""
|
||||
@ -326,40 +330,40 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
# Handle vector store search operations
|
||||
if call_type == "avector_store_search":
|
||||
vector_store_id = data.get("vector_store_id")
|
||||
|
||||
|
||||
if vector_store_id:
|
||||
# Check if it's a managed vector store ID
|
||||
decoded_id = is_base64_encoded_unified_id(vector_store_id)
|
||||
|
||||
|
||||
if decoded_id:
|
||||
verbose_logger.debug(
|
||||
f"Processing managed vector store search: {vector_store_id}"
|
||||
)
|
||||
|
||||
|
||||
# Check access
|
||||
has_access = await self.can_user_access_unified_resource_id(
|
||||
vector_store_id, user_api_key_dict
|
||||
)
|
||||
|
||||
|
||||
if not has_access:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}",
|
||||
)
|
||||
|
||||
|
||||
# Parse the unified ID to extract components
|
||||
parsed_id = parse_unified_id(vector_store_id)
|
||||
|
||||
|
||||
if parsed_id:
|
||||
# Extract the model ID and provider resource ID
|
||||
model_id = parsed_id.get("model_id")
|
||||
provider_resource_id = parsed_id.get("provider_resource_id")
|
||||
target_model_names = parsed_id.get("target_model_names", [])
|
||||
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}"
|
||||
)
|
||||
|
||||
|
||||
# Determine which model to use for routing
|
||||
# Priority: model_id (deployment ID) > first target_model_name
|
||||
routing_model = None
|
||||
@ -367,28 +371,28 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
routing_model = model_id
|
||||
elif target_model_names and len(target_model_names) > 0:
|
||||
routing_model = target_model_names[0]
|
||||
|
||||
|
||||
# Set the model for routing
|
||||
if routing_model:
|
||||
data["model"] = routing_model
|
||||
verbose_logger.info(
|
||||
f"Routing vector store search to model: {routing_model}"
|
||||
)
|
||||
|
||||
|
||||
# Replace the unified ID with the provider-specific ID
|
||||
if provider_resource_id:
|
||||
data["vector_store_id"] = provider_resource_id
|
||||
verbose_logger.debug(
|
||||
f"Replaced unified ID with provider resource ID: {provider_resource_id}"
|
||||
)
|
||||
|
||||
|
||||
# Handle vector store retrieve/delete operations
|
||||
elif call_type in ("avector_store_retrieve", "avector_store_delete"):
|
||||
await self.check_managed_vector_store_access(data, user_api_key_dict)
|
||||
|
||||
|
||||
# If it's a managed vector store, we'll handle it in the endpoint
|
||||
# No need to transform here as the endpoint will route to the hook
|
||||
|
||||
|
||||
return data
|
||||
|
||||
# ============================================================================
|
||||
@ -403,15 +407,15 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> Any:
|
||||
"""
|
||||
Post-call hook to transform responses.
|
||||
|
||||
|
||||
This hook can be used to transform responses if needed.
|
||||
For now, it just passes through the response.
|
||||
|
||||
|
||||
Args:
|
||||
data: Request data
|
||||
user_api_key_dict: User API key authentication details
|
||||
response: Response from the provider
|
||||
|
||||
|
||||
Returns:
|
||||
Potentially modified response
|
||||
"""
|
||||
@ -432,21 +436,21 @@ class _PROXY_LiteLLMManagedVectorStores(
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Filter deployments based on vector store availability.
|
||||
|
||||
|
||||
This is used by the router to select only deployments that have
|
||||
the vector store available.
|
||||
|
||||
|
||||
Note: This method signature is a compromise between CustomLogger and BaseManagedResource
|
||||
parent classes which have incompatible signatures. The type: ignore[override] is necessary
|
||||
due to this multiple inheritance conflict.
|
||||
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
healthy_deployments: List of healthy deployments
|
||||
messages: Messages (unused for vector stores, required by CustomLogger interface)
|
||||
request_kwargs: Request kwargs containing vector_store_id and mappings
|
||||
parent_otel_span: OpenTelemetry span for tracing
|
||||
|
||||
|
||||
Returns:
|
||||
Filtered list of deployments
|
||||
"""
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Enterprise internal user management endpoints
|
||||
"""
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
@ -147,12 +147,12 @@ async def list_vector_stores(
|
||||
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
|
||||
prisma_client=prisma_client
|
||||
)
|
||||
|
||||
|
||||
# Also clean up in-memory registry to remove any deleted vector stores
|
||||
if litellm.vector_store_registry is not None:
|
||||
db_vector_store_ids = {
|
||||
vs.get("vector_store_id")
|
||||
for vs in vector_stores_from_db
|
||||
vs.get("vector_store_id")
|
||||
for vs in vector_stores_from_db
|
||||
if vs.get("vector_store_id")
|
||||
}
|
||||
# Remove any in-memory vector stores that no longer exist in database
|
||||
|
||||
@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum):
|
||||
soft_budget_crossed = "Soft Budget Crossed"
|
||||
max_budget_alert = "Max Budget Alert"
|
||||
|
||||
|
||||
class EmailEventSettings(BaseModel):
|
||||
event: EmailEvent
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EmailEventSettingsUpdateRequest(BaseModel):
|
||||
settings: List[EmailEventSettings]
|
||||
|
||||
|
||||
class EmailEventSettingsResponse(BaseModel):
|
||||
settings: List[EmailEventSettings]
|
||||
|
||||
|
||||
class DefaultEmailSettings(BaseModel):
|
||||
"""Default settings for email events"""
|
||||
|
||||
settings: Dict[EmailEvent, bool] = Field(
|
||||
default_factory=lambda: {
|
||||
EmailEvent.virtual_key_created: True, # On by default
|
||||
@ -65,12 +57,10 @@ class DefaultEmailSettings(BaseModel):
|
||||
EmailEvent.max_budget_alert: True, # On by default
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, bool]:
|
||||
"""Convert to dictionary with string keys for storage"""
|
||||
return {event.value: enabled for event, enabled in self.settings.items()}
|
||||
|
||||
@classmethod
|
||||
def get_defaults(cls) -> Dict[str, bool]:
|
||||
"""Get the default settings as a dictionary with string keys"""
|
||||
return cls().to_dict()
|
||||
return cls().to_dict()
|
||||
@ -6,6 +6,7 @@ Always uses fastuuid for performance.
|
||||
|
||||
import fastuuid as _uuid # type: ignore
|
||||
|
||||
|
||||
# Expose a module-like alias so callers can use: uuid.uuid4()
|
||||
uuid = _uuid
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Dict, Optional
|
||||
|
||||
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
||||
|
||||
|
||||
# HTTP status code -> Anthropic error type
|
||||
# Source: https://docs.anthropic.com/en/api/errors
|
||||
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
|
||||
# Known Anthropic error types
|
||||
# Source: https://docs.anthropic.com/en/api/errors
|
||||
AnthropicErrorType = Literal[
|
||||
|
||||
@ -97,7 +97,7 @@ def _build_reasoning_item(
|
||||
|
||||
|
||||
def _reasoning_item_to_response_input(
|
||||
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]],
|
||||
r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a stored ChatCompletionReasoningItem back to a Responses API input item."""
|
||||
r_input: Dict[str, Any] = {
|
||||
|
||||
@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text.
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
_CODE_KEYWORDS = re.compile(
|
||||
r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b"
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union
|
||||
|
||||
|
||||
FileContentProvider = Literal[
|
||||
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
|
||||
]
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
"""
|
||||
Google GenAI Adapters for LiteLLM
|
||||
|
||||
This module provides adapters for transforming Google GenAI generate_content requests
|
||||
This module provides adapters for transforming Google GenAI generate_content requests
|
||||
to/from LiteLLM completion format with full support for:
|
||||
- Text content transformation
|
||||
- Tool calling (function declarations, function calls, function responses)
|
||||
- Tool calling (function declarations, function calls, function responses)
|
||||
- Streaming (both regular and tool calling)
|
||||
- Mixed content (text + tool calls)
|
||||
"""
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""
|
||||
Handles Batching + sending Httpx Post requests to slack
|
||||
Handles Batching + sending Httpx Post requests to slack
|
||||
|
||||
Slack alerts are sent every 10s or when events are greater than X events
|
||||
Slack alerts are sent every 10s or when events are greater than X events
|
||||
|
||||
see custom_batch_logger.py for more details / defaults
|
||||
see custom_batch_logger.py for more details / defaults
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -18,7 +18,7 @@ else:
|
||||
|
||||
|
||||
def process_slack_alerting_variables(
|
||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]],
|
||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
|
||||
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
||||
"""
|
||||
process alert_to_webhook_url
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Base class for Additional Logging Utils for CustomLoggers
|
||||
Base class for Additional Logging Utils for CustomLoggers
|
||||
|
||||
- Health Check for the logging util
|
||||
- Get Request / Response Payload for the logging util
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Custom Logger that handles batching logic
|
||||
Custom Logger that handles batching logic
|
||||
|
||||
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||
"""
|
||||
|
||||
@ -9,6 +9,7 @@ import polars as pl
|
||||
|
||||
from .schema import FOCUS_NORMALIZED_SCHEMA
|
||||
|
||||
|
||||
_TAG_KEYS = (
|
||||
"team_id",
|
||||
"team_alias",
|
||||
|
||||
@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_traces_and_spans_from_payload(
|
||||
payload: List[Dict[str, Any]],
|
||||
payload: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Separate traces and spans from payload.
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""
|
||||
s3 Bucket Logging Integration
|
||||
|
||||
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||
async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||
async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3
|
||||
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
|
||||
"""
|
||||
|
||||
|
||||
@ -5,28 +5,28 @@ This module provides SDK methods for Google's Interactions API.
|
||||
|
||||
Usage:
|
||||
import litellm
|
||||
|
||||
|
||||
# Create an interaction with a model
|
||||
response = litellm.interactions.create(
|
||||
model="gemini-2.5-flash",
|
||||
input="Hello, how are you?"
|
||||
)
|
||||
|
||||
|
||||
# Create an interaction with an agent
|
||||
response = litellm.interactions.create(
|
||||
agent="deep-research-pro-preview-12-2025",
|
||||
input="Research the current state of cancer research"
|
||||
)
|
||||
|
||||
|
||||
# Async version
|
||||
response = await litellm.interactions.acreate(...)
|
||||
|
||||
|
||||
# Get an interaction
|
||||
response = litellm.interactions.get(interaction_id="...")
|
||||
|
||||
|
||||
# Delete an interaction
|
||||
result = litellm.interactions.delete(interaction_id="...")
|
||||
|
||||
|
||||
# Cancel an interaction
|
||||
result = litellm.interactions.cancel(interaction_id="...")
|
||||
|
||||
|
||||
@ -8,25 +8,25 @@ Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
|
||||
|
||||
Usage:
|
||||
import litellm
|
||||
|
||||
|
||||
# Create an interaction with a model
|
||||
response = litellm.interactions.create(
|
||||
model="gemini-2.5-flash",
|
||||
input="Hello, how are you?"
|
||||
)
|
||||
|
||||
|
||||
# Create an interaction with an agent
|
||||
response = litellm.interactions.create(
|
||||
agent="deep-research-pro-preview-12-2025",
|
||||
input="Research the current state of cancer research"
|
||||
)
|
||||
|
||||
|
||||
# Async version
|
||||
response = await litellm.interactions.acreate(...)
|
||||
|
||||
|
||||
# Get an interaction
|
||||
response = litellm.interactions.get(interaction_id="...")
|
||||
|
||||
|
||||
# Delete an interaction
|
||||
result = litellm.interactions.delete(interaction_id="...")
|
||||
"""
|
||||
|
||||
@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
_metadata["raw_request"] = "redacted by litellm. \
|
||||
_metadata["raw_request"] = (
|
||||
"redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
)
|
||||
else:
|
||||
curl_command = self._get_request_curl_command(
|
||||
api_base=additional_args.get("api_base", ""),
|
||||
@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
_metadata["raw_request"] = "Unable to Log \
|
||||
raw request: {}".format(str(e))
|
||||
_metadata["raw_request"] = (
|
||||
"Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
||||
try:
|
||||
self.logger_fn(
|
||||
|
||||
@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str:
|
||||
prompt_str = """Use this JSON schema:
|
||||
```json
|
||||
{}
|
||||
```""".format(response_schema)
|
||||
```""".format(
|
||||
response_schema
|
||||
)
|
||||
return prompt_str
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""
|
||||
This is a cache for LangfuseLoggers.
|
||||
|
||||
Langfuse Python SDK initializes a thread for each client.
|
||||
Langfuse Python SDK initializes a thread for each client.
|
||||
|
||||
This ensures we do
|
||||
This ensures we do
|
||||
1. Proper cleanup of Langfuse initialized clients.
|
||||
2. Re-use created langfuse clients.
|
||||
"""
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE parsing helpers (module-level to keep the class lean)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -4,10 +4,10 @@ Support for o1 and o3 model families
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
- Temperature => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
||||
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
||||
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1"
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Legacy /v1/embedding handler for Bedrock Cohere.
|
||||
Legacy /v1/embedding handler for Bedrock Cohere.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-built response templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Cost calculator for Dashscope Chat models.
|
||||
Cost calculator for Dashscope Chat models.
|
||||
|
||||
Handles tiered pricing and prompt caching scenarios.
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
|
||||
Calls done in OpenAI/openai.py as DataRobot is openai-compatible.
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
|
||||
Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Cost calculator for DeepSeek Chat models.
|
||||
Cost calculator for DeepSeek Chat models.
|
||||
|
||||
Handles prompt caching scenario.
|
||||
"""
|
||||
|
||||
@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params
|
||||
|
||||
from ..common_utils import ElevenLabsException
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
|
||||
|
||||
|
||||
def _usage_video_resolution_from_parameters(
|
||||
parameters: Dict[str, Any],
|
||||
parameters: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""Normalize Veo ``parameters.resolution`` for usage and cost tracking."""
|
||||
res = parameters.get("resolution")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
|
||||
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
|
||||
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
|
||||
Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
|
||||
Calls done in OpenAI/openai.py as Novita AI is openai-compatible.
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
|
||||
Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
|
||||
|
||||
This is OpenAI compatible
|
||||
This is OpenAI compatible
|
||||
|
||||
This file only contains param mapping logic
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer
|
||||
|
||||
This is OpenAI compatible
|
||||
This is OpenAI compatible
|
||||
|
||||
This file only contains param mapping logic
|
||||
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
"""
|
||||
Support for o1/o3 model family
|
||||
Support for o1/o3 model family
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
|
||||
|
||||
@ -201,7 +201,7 @@ class BaseOpenAILLM:
|
||||
|
||||
@staticmethod
|
||||
def get_openai_client_initialization_param_fields(
|
||||
client_type: Literal["openai", "azure"],
|
||||
client_type: Literal["openai", "azure"]
|
||||
) -> Tuple[str, ...]:
|
||||
"""Returns a tuple of fields that are used to initialize the OpenAI client"""
|
||||
if client_type == "openai":
|
||||
|
||||
@ -49,6 +49,7 @@ from litellm.types.utils import (
|
||||
)
|
||||
from litellm.llms.openrouter.common_utils import OpenRouterException
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
|
||||
|
||||
In the Huggingface TGI format.
|
||||
In the Huggingface TGI format.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
|
||||
|
||||
In the Huggingface TGI format.
|
||||
In the Huggingface TGI format.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]:
|
||||
|
||||
|
||||
def _parse_service_key_once(
|
||||
service_key: Optional[Union[str, dict]],
|
||||
service_key: Optional[Union[str, dict]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Pre-parse service_key if it's a string to avoid repeated JSON parsing.
|
||||
|
||||
@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig
|
||||
|
||||
from ..utils import SnowflakeBaseConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||
|
||||
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Support for OpenAI's `/v1/embeddings` endpoint.
|
||||
Support for OpenAI's `/v1/embeddings` endpoint.
|
||||
|
||||
Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
|
||||
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic for context caching.
|
||||
Transformation logic for context caching.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
@ -19,7 +19,7 @@ from ..gemini.transformation import (
|
||||
|
||||
|
||||
def get_first_continuous_block_idx(
|
||||
filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message)
|
||||
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
|
||||
) -> int:
|
||||
"""
|
||||
Find the array index that ends the first continuous sequence of message blocks.
|
||||
|
||||
@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||
contents.append(ContentType(role="user", parts=tool_call_responses))
|
||||
|
||||
if len(contents) == 0:
|
||||
verbose_logger.warning("""
|
||||
verbose_logger.warning(
|
||||
"""
|
||||
No contents in messages. Contents are required. See
|
||||
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
|
||||
If the original request did not comply to OpenAI API requirements it should have failed by now,
|
||||
but LiteLLM does not check for missing messages.
|
||||
Setting an empty content to prevent an 400 error.
|
||||
Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
|
||||
""")
|
||||
"""
|
||||
)
|
||||
contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
|
||||
return contents
|
||||
except Exception as e:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
|
||||
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM):
|
||||
########## End of logging ############
|
||||
####### Send the request ###################
|
||||
if _is_async is True:
|
||||
return self.async_audio_speech( # type: ignore
|
||||
return self.async_audio_speech( # type:ignore
|
||||
logging_obj=logging_obj, url=url, headers=headers, request=request
|
||||
)
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
|
||||
Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
|
||||
|
||||
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
|
||||
"""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""
|
||||
This module is used to transform the request and response for the Voyage contextualized embeddings API.
|
||||
This would be used for all the contextualized embeddings models in Voyage.
|
||||
This module is used to transform the request and response for the Voyage contextualized embeddings API.
|
||||
This would be used for all the contextualized embeddings models in Voyage.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@ -324,7 +324,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_chat_completion_request_schema(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
|
||||
@ -380,7 +380,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_responses_api_request_schema(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
|
||||
@ -410,7 +410,7 @@ class CustomOpenAPISpec:
|
||||
|
||||
@staticmethod
|
||||
def add_llm_api_request_schema_body(
|
||||
openapi_schema: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Add LLM API request schema bodies to OpenAPI specification for documentation.
|
||||
|
||||
@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def convert_upload_files_to_file_data(
|
||||
form_data: Dict[str, Any],
|
||||
form_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert FastAPI UploadFile objects to file data tuples for litellm.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Contains utils used by OpenAI compatible endpoints
|
||||
Contains utils used by OpenAI compatible endpoints
|
||||
"""
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
What is this?
|
||||
What is this?
|
||||
|
||||
CRUD endpoints for managing pass-through endpoints
|
||||
"""
|
||||
|
||||
@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
||||
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
|
||||
raise
|
||||
# If an error occurs, the view does not exist, so create it
|
||||
await db.execute_raw("""
|
||||
await db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
|
||||
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ every text fragment.
|
||||
|
||||
from typing import Any, Callable, Dict, FrozenSet, Iterator, List
|
||||
|
||||
|
||||
# Call types whose body carries free-form chat / prompt text that
|
||||
# text-content guardrails (banned keywords, content moderation, secret
|
||||
# detection, …) should inspect. The proxy ingress passes ``route_type``
|
||||
|
||||
@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .akto import AktoGuardrail
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ The actual skill logic is in litellm/llms/litellm_proxy/skills/.
|
||||
|
||||
Usage:
|
||||
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
|
||||
|
||||
|
||||
# Register hook in proxy
|
||||
litellm.callbacks.append(SkillsInjectionHook())
|
||||
"""
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""
|
||||
BUDGET MANAGEMENT
|
||||
|
||||
All /budget management endpoints
|
||||
All /budget management endpoints
|
||||
|
||||
/budget/new
|
||||
/budget/new
|
||||
/budget/info
|
||||
/budget/update
|
||||
/budget/delete
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""
|
||||
CUSTOMER MANAGEMENT
|
||||
|
||||
All /customer management endpoints
|
||||
All /customer management endpoints
|
||||
|
||||
/customer/new
|
||||
/customer/new
|
||||
/customer/info
|
||||
/customer/update
|
||||
/customer/delete
|
||||
|
||||
@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment(
|
||||
"""
|
||||
|
||||
def _get_team_public_model_name(
|
||||
model_info: Optional[Union[dict, str]],
|
||||
model_info: Optional[Union[dict, str]]
|
||||
) -> Optional[str]:
|
||||
if isinstance(model_info, dict):
|
||||
value = model_info.get("team_public_model_name")
|
||||
|
||||
@ -7,7 +7,7 @@ variables.
|
||||
|
||||
Environment Variables:
|
||||
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
|
||||
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
|
||||
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
|
||||
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
|
||||
|
||||
If these are not set, the default Microsoft endpoints are used.
|
||||
|
||||
@ -4347,7 +4347,9 @@ async def list_team(
|
||||
except Exception as e:
|
||||
team_exception = """Invalid team object for team_id: {}. team_object={}.
|
||||
Error: {}
|
||||
""".format(team.team_id, team.model_dump(), str(e))
|
||||
""".format(
|
||||
team.team_id, team.model_dump(), str(e)
|
||||
)
|
||||
verbose_proxy_logger.exception(team_exception)
|
||||
continue
|
||||
# Sort the responses by team_alias
|
||||
|
||||
@ -3,7 +3,7 @@ User Agent Analytics Endpoints
|
||||
|
||||
This module provides optimized endpoints for tracking user agent activity metrics including:
|
||||
- Daily Active Users (DAU) by tags for configurable number of days
|
||||
- Weekly Active Users (WAU) by tags for configurable number of weeks
|
||||
- Weekly Active Users (WAU) by tags for configurable number of weeks
|
||||
- Monthly Active Users (MAU) by tags for configurable number of months
|
||||
- Summary analytics by tags
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import (
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
|
||||
|
||||
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
|
||||
"POST /v0/agents": "cursor:agent:create",
|
||||
"GET /v0/agents": "cursor:agent:list",
|
||||
|
||||
@ -292,7 +292,9 @@ class ProxyInitializationHelpers:
|
||||
_endpoint_str = (
|
||||
f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
|
||||
)
|
||||
curl_command = _endpoint_str + """
|
||||
curl_command = (
|
||||
_endpoint_str
|
||||
+ """
|
||||
--header 'Content-Type: application/json' \\
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
@ -305,6 +307,7 @@ class ProxyInitializationHelpers:
|
||||
}'
|
||||
\n
|
||||
"""
|
||||
)
|
||||
print() # noqa
|
||||
print( # noqa
|
||||
'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
|
||||
@ -380,9 +383,11 @@ class ProxyInitializationHelpers:
|
||||
with open(os.devnull, "w") as devnull:
|
||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
print(f"""
|
||||
print( # noqa
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""") # noqa # noqa
|
||||
"""
|
||||
) # noqa
|
||||
|
||||
@staticmethod
|
||||
def _is_port_in_use(port):
|
||||
|
||||
@ -2688,9 +2688,11 @@ def run_ollama_serve():
|
||||
with open(os.devnull, "w") as devnull:
|
||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _get_process_rss_mb() -> Optional[float]:
|
||||
|
||||
@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse:
|
||||
async def get_spend_by_tags(
|
||||
prisma_client: PrismaClient, start_date=None, end_date=None
|
||||
):
|
||||
response = await prisma_client.db.query_raw("""
|
||||
response = await prisma_client.db.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||
COUNT(*) AS log_count,
|
||||
SUM(spend) AS total_spend
|
||||
FROM "LiteLLM_SpendLogs"
|
||||
GROUP BY individual_request_tag;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -2712,7 +2712,8 @@ class PrismaClient:
|
||||
required_view = "LiteLLM_VerificationTokenView"
|
||||
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
|
||||
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
|
||||
ret = await self.db.query_raw(f"""
|
||||
ret = await self.db.query_raw(
|
||||
f"""
|
||||
WITH existing_views AS (
|
||||
SELECT viewname
|
||||
FROM pg_views
|
||||
@ -2724,7 +2725,8 @@ class PrismaClient:
|
||||
(SELECT COUNT(*) FROM existing_views) AS view_count,
|
||||
ARRAY_AGG(viewname) AS view_names
|
||||
FROM existing_views
|
||||
""")
|
||||
"""
|
||||
)
|
||||
expected_total_views = len(expected_views)
|
||||
if ret[0]["view_count"] == expected_total_views:
|
||||
verbose_proxy_logger.info("All necessary views exist!")
|
||||
@ -2733,7 +2735,8 @@ class PrismaClient:
|
||||
## check if required view exists ##
|
||||
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
|
||||
await self.health_check() # make sure we can connect to db
|
||||
await self.db.execute_raw("""
|
||||
await self.db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -2743,7 +2746,8 @@ class PrismaClient:
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
"LiteLLM_VerificationTokenView Created in DB!"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
What is this?
|
||||
What is this?
|
||||
|
||||
Logging Pass-Through Endpoints
|
||||
"""
|
||||
|
||||
@ -826,7 +826,7 @@ class Router:
|
||||
|
||||
@staticmethod
|
||||
def _normalize_strategy(
|
||||
strategy: Union[RoutingStrategy, str, None],
|
||||
strategy: Union[RoutingStrategy, str, None]
|
||||
) -> Optional[str]:
|
||||
if strategy is None:
|
||||
return None
|
||||
|
||||
@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str
|
||||
|
||||
|
||||
def _recent_tool_results(
|
||||
messages: Optional[List[Dict[str, Any]]],
|
||||
messages: Optional[List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract the current turn's tool result payloads from the request messages.
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user