From 5bafa8b3a23c96d720fdc0f8f6eb380787362bfd Mon Sep 17 00:00:00 2001 From: user <70670632+stuxf@users.noreply.github.com> Date: Thu, 7 May 2026 23:04:52 +0000 Subject: [PATCH] Drop dep bumps + black-26 reformat to clear fork CI policy PR was blocked by .github/workflows/guard-fork-dependencies.yml: fork PRs cannot modify uv.lock. Reverting: - uv.lock + pyproject.toml black bump (24.10.0 -> 26.3.1) and the 295 files of mechanical Black 26 reformat coupled to it - pyproject.toml diskcache extra change (kept the runtime mitigation in litellm/caching/disk_cache.py via JSONDisk) Kept: - Dockerfile cache narrowing (drops ~660 MB of uv build cache that surfaced cached setuptools as CVE findings) - litellm/caching/disk_cache.py: dc.JSONDisk to neutralize CVE-2025-69872 - ui/litellm-dashboard/package-lock.json + litellm-js/spend-logs/package-lock.json: next/postcss/hono/uuid CVE bumps (these are not blocked by the fork guard) - tests/test_litellm/caching/test_disk_cache.py - tests/code_coverage_tests/liccheck.ini: harmless black authorization Black + gitpython + langchain dep upgrades will need a follow-up from a maintainer pushing a branch in the canonical BerriAI/litellm repo. Co-Authored-By: Claude Opus 4.7 (1M context) --- ...to_update_price_and_context_window_file.py | 90 +-- .../workflows/run_llm_translation_tests.py | 512 ++++++------- ci_cd/run_migration.py | 1 + cookbook/benchmark/benchmark.py | 1 + .../braintrust_prompt_wrapper_server.py | 1 + cookbook/livekit_agent_sdk/main.py | 2 +- cookbook/misc/migrate_proxy_config.py | 4 +- .../mock_bedrock_guardrail_server.py | 6 +- db_scripts/create_views.py | 6 +- .../enterprise_callbacks/callback_controls.py | 139 ++-- .../send_emails/base_email.py | 14 +- .../send_emails/sendgrid_email.py | 3 +- .../send_emails/smtp_email.py | 1 - .../litellm_core_utils/litellm_logging.py | 1 - .../proxy/audit_logging_endpoints.py | 10 +- .../litellm_enterprise/proxy/auth/__init__.py | 2 +- .../proxy/common_utils/check_batch_cost.py | 70 +- .../common_utils/check_responses_cost.py | 33 +- .../proxy/hooks/managed_vector_stores.py | 92 +-- .../internal_user_endpoints.py | 1 + .../proxy/vector_stores/endpoints.py | 6 +- .../types/enterprise_callbacks/send_emails.py | 12 +- litellm/_uuid.py | 1 + .../exceptions/exception_mapping_utils.py | 1 + .../exceptions/exceptions.py | 1 + .../transformation.py | 2 +- litellm/compression/content_detection.py | 1 + litellm/files/types.py | 1 + litellm/google_genai/adapters/__init__.py | 4 +- .../SlackAlerting/batching_handler.py | 6 +- litellm/integrations/SlackAlerting/utils.py | 2 +- .../integrations/additional_logging_utils.py | 2 +- litellm/integrations/custom_batch_logger.py | 2 +- litellm/integrations/focus/transformer.py | 1 + litellm/integrations/opik/utils.py | 2 +- litellm/integrations/s3_v2.py | 4 +- litellm/interactions/__init__.py | 12 +- litellm/interactions/main.py | 10 +- litellm/litellm_core_utils/litellm_logging.py | 12 +- .../prompt_templates/factory.py | 4 +- .../specialty_caches/dynamic_logging_cache.py | 4 +- .../messages/agentic_streaming_iterator.py | 1 + .../azure/chat/o_series_transformation.py | 8 +- .../azure_ai/embed/cohere_transformation.py | 2 +- .../llms/azure_ai/rerank/transformation.py | 2 +- .../embed/amazon_titan_g1_transformation.py | 2 +- .../bedrock/embed/cohere_transformation.py | 2 +- .../bedrock_mantle/chat/transformation.py | 1 + litellm/llms/cohere/embed/handler.py | 2 +- litellm/llms/custom_httpx/mock_transport.py | 1 + litellm/llms/dashscope/cost_calculator.py | 2 +- litellm/llms/datarobot/chat/transformation.py | 2 +- .../llms/deepinfra/rerank/transformation.py | 2 +- litellm/llms/deepseek/cost_calculator.py | 2 +- .../text_to_speech/transformation.py | 1 + litellm/llms/gemini/videos/transformation.py | 2 +- .../llms/infinity/rerank/transformation.py | 2 +- litellm/llms/jina_ai/rerank/transformation.py | 2 +- .../llms/lm_studio/embed/transformation.py | 2 +- litellm/llms/novita/chat/transformation.py | 2 +- .../llms/nvidia_nim/chat/transformation.py | 4 +- litellm/llms/nvidia_nim/embed.py | 2 +- .../openai/chat/o_series_transformation.py | 12 +- litellm/llms/openai/common_utils.py | 2 +- .../image_generation/transformation.py | 1 + .../sagemaker/completion/transformation.py | 2 +- .../sagemaker/embedding/transformation.py | 2 +- litellm/llms/sap/credentials.py | 2 +- litellm/llms/snowflake/chat/transformation.py | 1 + litellm/llms/together_ai/chat.py | 2 +- litellm/llms/together_ai/embed.py | 2 +- .../llms/together_ai/rerank/transformation.py | 2 +- .../context_caching/transformation.py | 4 +- .../llms/vertex_ai/gemini/transformation.py | 6 +- .../batch_embed_content_transformation.py | 2 +- .../text_to_speech/text_to_speech_handler.py | 2 +- .../llms/vllm/completion/transformation.py | 2 +- .../embedding/transformation_contextual.py | 4 +- .../proxy/common_utils/custom_openapi_spec.py | 6 +- .../proxy/common_utils/http_parsing_utils.py | 2 +- .../common_utils/openai_endpoint_utils.py | 2 +- .../pass_through_endpoints.py | 2 +- litellm/proxy/db/create_views.py | 6 +- litellm/proxy/guardrails/_content_utils.py | 1 + .../guardrail_hooks/akto/__init__.py | 1 + .../proxy/hooks/litellm_skills/__init__.py | 2 +- .../budget_management_endpoints.py | 4 +- .../customer_endpoints.py | 4 +- .../model_management_endpoints.py | 2 +- .../sso/custom_microsoft_sso.py | 2 +- .../management_endpoints/team_endpoints.py | 4 +- .../user_agent_analytics_endpoints.py | 2 +- .../cursor_passthrough_logging_handler.py | 1 + litellm/proxy/proxy_cli.py | 11 +- litellm/proxy/proxy_server.py | 6 +- .../spend_management_endpoints.py | 6 +- litellm/proxy/utils.py | 12 +- .../vertex_ai_endpoints/langfuse_endpoints.py | 2 +- litellm/router.py | 2 +- .../router_strategy/adaptive_router/hooks.py | 2 +- .../adaptive_router/signals.py | 1 + litellm/router_strategy/budget_limiter.py | 8 +- litellm/router_utils/get_retry_from_policy.py | 2 +- .../router_utils/pattern_match_deployments.py | 2 +- .../track_deployment_metrics.py | 2 +- litellm/secret_managers/aws_secret_manager.py | 2 +- .../secret_managers/aws_secret_manager_v2.py | 2 +- litellm/vector_store_files/utils.py | 4 +- pyproject.toml | 8 +- scripts/adaptive_router_demo/eval.py | 54 +- scripts/adaptive_router_demo/traffic.py | 59 +- scripts/benchmark_mock.py | 1 + scripts/benchmark_proxy_vs_provider.py | 2 +- scripts/eval_compression.py | 180 +++-- .../benchmark_get_all_latest_health_checks.py | 2 +- .../test_batch_custom_pricing.py | 1 + .../test_hosted_vllm_batches_and_files.py | 1 + tests/benchmarks/test_benchmarks.py | 1 + .../check_fastuuid_usage.py | 1 + tests/enterprise/conftest.py | 2 + .../integrations/test_prometheus.py | 86 +-- .../proxy/auth/test_user_api_key_auth.py | 11 +- .../test_apply_guardrail_endpoint.py | 4 +- .../proxy/hooks/test_managed_files.py | 674 ++++++++---------- .../test_internal_user_endpoints.py | 63 +- .../test_project_endpoints_prisma.py | 4 +- .../guardrails_tests/test_akto_guardrails.py | 1 + .../guardrails_tests/test_custom_guardrail.py | 1 + .../test_eu_ai_act_article5.py | 1 + .../test_sg_mas_ai_guardrails.py | 1 + .../test_sg_pdpa_guardrails.py | 1 + .../image_gen_tests/test_image_generation.py | 1 + tests/image_gen_tests/test_image_variation.py | 1 + .../llms/bedrock/embed/test_embedding.py | 1 + .../llms/bedrock/test_nova_imported_models.py | 1 + .../vertex_ai/test_gemini_batch_embeddings.py | 16 +- tests/litellm/test_bedrock_nemotron_super.py | 1 + .../test_bedrock_converse_dedup_factory.py | 19 +- .../test_aws_secret_manager.py | 1 + .../test_litellm_overhead.py | 1 + .../test_anthropic_responses_api.py | 1 + ...est_anthropic_tool_result_empty_call_id.py | 2 +- ...t_base_responses_api_streaming_iterator.py | 4 +- .../test_responses_hooks.py | 4 +- .../test_bedrock_anthropic_regression.py | 8 +- .../test_bedrock_invoke_tests.py | 1 + tests/llm_translation/test_crusoe.py | 7 +- tests/llm_translation/test_gemini.py | 21 +- .../test_gemini_image_usage.py | 2 +- .../create_mock_standard_logging_payload.py | 1 + .../test_amazing_vertex_completion.py | 9 +- tests/local_testing/test_cache_preset_key.py | 2 +- tests/local_testing/test_caching.py | 7 +- tests/local_testing/test_completion.py | 6 +- .../test_configs/custom_callbacks.py | 6 +- tests/local_testing/test_get_llm_provider.py | 1 + tests/local_testing/test_langsmith.py | 1 + tests/local_testing/test_router_caching.py | 1 + .../create_mock_standard_logging_payload.py | 1 + .../test_bedrock_knowledgebase_hook.py | 1 + .../test_datadog_llm_obs.py | 1 + .../test_gcs_pub_sub.py | 1 + .../test_generic_api_callback.py | 1 + .../test_langsmith_unit_test.py | 1 + .../test_log_db_redis_services.py | 1 + .../test_view_request_resp_logs.py | 1 + tests/mcp_tests/test_mcp_logging.py | 1 + tests/mcp_tests/test_mcp_server.py | 1 + tests/mcp_tests/test_proxy_mcp_e2e.py | 1 + tests/ocr_tests/base_ocr_unit_tests.py | 1 + .../test_ocr_azure_document_intelligence.py | 5 +- .../tests/bursty_load_test_completion.py | 1 + .../tests/load_test_embedding_100.py | 1 + .../test_openai_request_with_traceparent.py | 1 + .../test_openai_batches_endpoint.py | 1 + .../test_openai_files_endpoints.py | 1 + tests/otel_tests/test_e2e_model_access.py | 1 + .../test_team_member_permissions.py | 12 +- .../test_openai_assistants_passthrough.py | 1 + tests/pass_through_tests/test_vertex_ai.py | 1 + ..._anthropic_messages_prompt_caching_test.py | 8 +- .../test_route_check_unit_tests.py | 1 + .../test_usage_endpoints.py | 4 +- .../test_claude_agent_sdk.py | 1 + tests/proxy_unit_tests/conftest.py | 1 + .../test_configs/custom_callbacks.py | 6 +- .../proxy_unit_tests/test_jwt_key_mapping.py | 1 + .../test_search_api_logging.py | 2 +- .../create_mock_standard_logging_payload.py | 1 + .../test_router_handle_error.py | 1 + tests/search_tests/test_google_pse_search.py | 1 + .../test_callbacks_in_db.py | 4 +- .../test_bedrock_agentcore_a2a.py | 1 + .../send_emails/test_base_email.py | 179 ++--- .../send_emails/test_resend_email.py | 11 +- .../send_emails/test_sendgrid_email.py | 6 +- .../test_callback_controls.py | 328 +++------ .../test_batch_retrieve_input_file_id.py | 23 +- ..._retrieve_returns_unified_input_file_id.py | 4 +- .../proxy/test_enterprise_routes.py | 48 +- .../proxy/test_file_deletion_blocking.py | 188 +++-- .../proxy/test_managed_files_hook.py | 18 +- .../google_genai/test_google_genai_adapter.py | 1 - .../test_google_genai_adapter_fixes.py | 1 - .../google_genai/test_google_genai_handler.py | 1 - .../google_genai/test_google_genai_main.py | 1 - .../test_google_genai_transformation.py | 1 - .../arize/test_arize_otel_coexistence.py | 1 + .../dotprompt/test_prompt_manager.py | 18 +- .../integrations/test_custom_guardrail.py | 9 +- .../test_websearch_thinking_constraint.py | 1 + .../test_azure_assistant_cost_tracking.py | 2 +- .../litellm_core_utils/test_core_helpers.py | 24 +- .../test_streaming_handler.py | 14 +- .../litellm_core_utils/test_token_counter.py | 1 + ...al_pass_through_adapters_transformation.py | 4 +- .../test_agentic_streaming_iterator.py | 1 + .../messages/test_parallel_tool_calls.py | 1 + .../test_reasoning_auto_summary_messages.py | 22 +- .../llms/azure/test_azure_cost_calculation.py | 1 + .../azure_ai/test_azure_ai_cost_calculator.py | 25 +- .../test_managed_resource_isolation.py | 1 + .../llms/bedrock/chat/test_invoke_handler.py | 1 + tests/test_litellm/llms/crusoe/test_crusoe.py | 13 +- .../custom_httpx/test_aiohttp_so_keepalive.py | 5 +- .../llms/custom_httpx/test_mock_transport.py | 1 + .../llms/databricks/test_databricks_e2e.py | 6 +- ...gram_audio_transcription_transformation.py | 2 +- .../test_github_copilot_authenticator.py | 64 +- .../test_huggingface_embedding_handler.py | 1 + .../test_moonshot_chat_transformation.py | 4 +- .../test_perplexity_cost_calculator.py | 2 +- .../perplexity/test_perplexity_integration.py | 2 +- ...eway_audio_transcription_transformation.py | 1 + .../llms/test_polling_url_origin_match.py | 1 + .../llms/test_predibase_transformation.py | 25 +- .../gemini/test_context_circulation.py | 1 + ...test_vertex_and_google_ai_studio_gemini.py | 29 +- ...test_batch_embed_content_transformation.py | 14 +- .../vertex_ai/test_vertex_ai_common_utils.py | 6 +- .../test_vertex_model_garden_openapi.py | 5 +- .../auth/test_user_api_key_auth_mcp.py | 4 +- .../mcp_server/test_db_credentials.py | 1 + .../mcp_server/test_is_tool_name_prefixed.py | 1 + .../mcp_server/test_mcp_stale_session.py | 116 ++- .../mcp_server/test_semantic_tool_filter.py | 28 +- .../test_agent_header_isolation.py | 1 + .../agent_endpoints/test_agent_headers.py | 1 + .../proxy/auth/test_onboarding.py | 1 + .../proxy/common_utils/test_cache_codec.py | 8 +- tests/test_litellm/proxy/conftest.py | 1 + .../test_google_api_endpoints.py | 1 - .../content_filter/test_content_filter.py | 6 +- .../openai/test_moderations.py | 1 - .../proxy/guardrails/test_content_utils.py | 1 + .../guardrails/test_custom_code_security.py | 1 + .../proxy/guardrails/test_llm_as_a_judge.py | 61 +- .../test_qostodian_nexus_guardrail.py | 6 +- .../proxy/hooks/test_batch_file_validation.py | 1 + .../test_activity_tenant_scoping.py | 1 + .../test_budget_endpoints.py | 1 + .../test_delete_verification_tokens_failed.py | 1 + .../test_project_org_authz.py | 1 + .../test_team_default_params.py | 1 + .../test_workflow_management_endpoints.py | 1 + .../usage_endpoints/test_ai_usage_chat.py | 1 + .../test_llm_pass_through_endpoints.py | 6 +- .../proxy/test_fastapi_offline_routes.py | 2 +- ...test_filter_models_by_team_access_group.py | 4 +- .../proxy/test_model_level_guardrails.py | 1 + .../proxy/test_redis_auth_cache_flag.py | 1 + .../test_litellm/proxy/test_team_org_move.py | 31 +- .../test_responses_api_bridge_flag.py | 16 +- .../adaptive_router/test_router_dispatch.py | 4 +- .../test_acompletion_session_reuse_e2e.py | 1 + .../test_anthropic_skills_transformation.py | 1 + .../test_dashscope_image_generation.py | 1 + .../test_deepseek_model_metadata.py | 1 + tests/test_litellm/test_model_cost_aliases.py | 1 + tests/test_litellm/test_nested_drop_params.py | 1 + ...penai_embedding_encoding_format_default.py | 4 +- .../test_litellm/test_router_google_genai.py | 1 - .../test_streaming_connection_cleanup.py | 1 + tests/test_litellm/test_utils.py | 24 +- tests/test_litellm/types/test_completion.py | 2 +- .../test_prometheus_label_value_sanitize.py | 3 +- tests/test_passthrough_endpoints.py | 1 + .../vector_store_tests/rag/base_rag_tests.py | 4 +- .../rag/test_rag_vertex_ai.py | 8 +- .../test_milvus_vector_store.py | 1 + .../test_vertex_ai_search_api_vector_store.py | 1 + .../fixtures/mock_llm_server/server.py | 1 + 292 files changed, 1814 insertions(+), 2305 deletions(-) diff --git a/.github/workflows/auto_update_price_and_context_window_file.py b/.github/workflows/auto_update_price_and_context_window_file.py index b92a0568e3..461d8d347d 100644 --- a/.github/workflows/auto_update_price_and_context_window_file.py +++ b/.github/workflows/auto_update_price_and_context_window_file.py @@ -2,7 +2,6 @@ import asyncio import aiohttp import json - # Asynchronously fetch data from a given URL async def fetch_data(url): try: @@ -16,24 +15,22 @@ async def fetch_data(url): resp_json = await resp.json() print("Fetch the data from URL.") # Return the 'data' field from the JSON response - return resp_json["data"] + return resp_json['data'] except Exception as e: # Print an error message if fetching data fails print("Error fetching data from URL:", e) return None - # Synchronize local data with remote data def sync_local_data_with_remote(local_data, remote_data): # Update existing keys in local_data with values from remote_data - for key in set(local_data) & set(remote_data): + for key in (set(local_data) & set(remote_data)): local_data[key].update(remote_data[key]) # Add new keys from remote_data to local_data - for key in set(remote_data) - set(local_data): + for key in (set(remote_data) - set(local_data)): local_data[key] = remote_data[key] - # Write data to the json file def write_to_file(file_path, data): try: @@ -46,7 +43,6 @@ def write_to_file(file_path, data): # Print an error message if writing to file fails print("Error updating JSON file:", e) - # Update the existing models and add the missing models for OpenRouter def transform_openrouter_data(data): transformed = {} @@ -58,41 +54,33 @@ def transform_openrouter_data(data): } # Add 'max_output_tokens' as a field if it is not None - if ( - "top_provider" in row - and "max_completion_tokens" in row["top_provider"] - and row["top_provider"]["max_completion_tokens"] is not None - ): - obj["max_output_tokens"] = int(row["top_provider"]["max_completion_tokens"]) + if "top_provider" in row and "max_completion_tokens" in row["top_provider"] and row["top_provider"]["max_completion_tokens"] is not None: + obj['max_output_tokens'] = int(row["top_provider"]["max_completion_tokens"]) # Add the field 'output_cost_per_token' - obj.update( - { - "output_cost_per_token": float(row["pricing"]["completion"]), - } - ) + obj.update({ + "output_cost_per_token": float(row["pricing"]["completion"]), + }) # Add field 'input_cost_per_image' if it exists and is non-zero - if ( - "pricing" in row - and "image" in row["pricing"] - and float(row["pricing"]["image"]) != 0.0 - ): - obj["input_cost_per_image"] = float(row["pricing"]["image"]) + if "pricing" in row and "image" in row["pricing"] and float(row["pricing"]["image"]) != 0.0: + obj['input_cost_per_image'] = float(row["pricing"]["image"]) # Add the fields 'litellm_provider' and 'mode' - obj.update({"litellm_provider": "openrouter", "mode": "chat"}) + obj.update({ + "litellm_provider": "openrouter", + "mode": "chat" + }) # Add the 'supports_vision' field if the modality is 'multimodal' - if row.get("architecture", {}).get("modality") == "multimodal": - obj["supports_vision"] = True + if row.get('architecture', {}).get('modality') == 'multimodal': + obj['supports_vision'] = True # Use a composite key to store the transformed object transformed[f'openrouter/{row["id"]}'] = obj return transformed - # Update the existing models and add the missing models for Vercel AI Gateway def transform_vercel_ai_gateway_data(data): transformed = {} @@ -101,30 +89,20 @@ def transform_vercel_ai_gateway_data(data): "max_tokens": row["context_window"], "input_cost_per_token": float(row["pricing"]["input"]), "output_cost_per_token": float(row["pricing"]["output"]), - "max_output_tokens": row["max_tokens"], - "max_input_tokens": row["context_window"], + 'max_output_tokens': row['max_tokens'], + 'max_input_tokens': row["context_window"], } # Handle cache pricing if available if "pricing" in row: - if ( - "input_cache_read" in row["pricing"] - and row["pricing"]["input_cache_read"] is not None - ): - obj["cache_read_input_token_cost"] = float( - f"{float(row['pricing']['input_cache_read']):e}" - ) - - if ( - "input_cache_write" in row["pricing"] - and row["pricing"]["input_cache_write"] is not None - ): - obj["cache_creation_input_token_cost"] = float( - f"{float(row['pricing']['input_cache_write']):e}" - ) + if "input_cache_read" in row["pricing"] and row["pricing"]["input_cache_read"] is not None: + obj['cache_read_input_token_cost'] = float(f"{float(row['pricing']['input_cache_read']):e}") + + if "input_cache_write" in row["pricing"] and row["pricing"]["input_cache_write"] is not None: + obj['cache_creation_input_token_cost'] = float(f"{float(row['pricing']['input_cache_write']):e}") mode = "embedding" if "embedding" in row["id"].lower() else "chat" - + obj.update({"litellm_provider": "vercel_ai_gateway", "mode": mode}) transformed[f'vercel_ai_gateway/{row["id"]}'] = obj @@ -148,31 +126,24 @@ def load_local_data(file_path): print("Error decoding JSON:", e) return None - def main(): - local_file_path = ( - "model_prices_and_context_window.json" # Path to the local data file - ) - openrouter_url = ( - "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data - ) - vercel_ai_gateway_url = ( - "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data - ) + local_file_path = "model_prices_and_context_window.json" # Path to the local data file + openrouter_url = "https://openrouter.ai/api/v1/models" # URL to fetch OpenRouter data + vercel_ai_gateway_url = "https://ai-gateway.vercel.sh/v1/models" # URL to fetch Vercel AI Gateway data # Load local data from file local_data = load_local_data(local_file_path) - + # Fetch OpenRouter data openrouter_data = asyncio.run(fetch_data(openrouter_url)) # Transform the fetched OpenRouter data openrouter_data = transform_openrouter_data(openrouter_data) - + # Fetch Vercel AI Gateway data vercel_data = asyncio.run(fetch_data(vercel_ai_gateway_url)) # Transform the fetched Vercel AI Gateway data vercel_data = transform_vercel_ai_gateway_data(vercel_data) - + # Combine both datasets all_remote_data = {**openrouter_data, **vercel_data} @@ -183,7 +154,6 @@ def main(): else: print("Failed to fetch model data from either local file or URL.") - # Entry point of the script if __name__ == "__main__": main() diff --git a/.github/workflows/run_llm_translation_tests.py b/.github/workflows/run_llm_translation_tests.py index 22be769a73..3f3a70efe9 100644 --- a/.github/workflows/run_llm_translation_tests.py +++ b/.github/workflows/run_llm_translation_tests.py @@ -16,75 +16,64 @@ from pathlib import Path import json from typing import Dict, List, Tuple, Optional - # ANSI color codes for terminal output class Colors: - GREEN = "\033[92m" - RED = "\033[91m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - PURPLE = "\033[95m" - CYAN = "\033[96m" - RESET = "\033[0m" - BOLD = "\033[1m" - + GREEN = '\033[92m' + RED = '\033[91m' + YELLOW = '\033[93m' + BLUE = '\033[94m' + PURPLE = '\033[95m' + CYAN = '\033[96m' + RESET = '\033[0m' + BOLD = '\033[1m' def print_colored(message: str, color: str = Colors.RESET): """Print colored message to terminal""" print(f"{color}{message}{Colors.RESET}") - def get_provider_from_test_file(test_file: str) -> str: """Map test file names to provider names""" provider_mapping = { - "test_anthropic": "Anthropic", - "test_azure": "Azure", - "test_bedrock": "AWS Bedrock", - "test_openai": "OpenAI", - "test_vertex": "Google Vertex AI", - "test_gemini": "Google Vertex AI", - "test_cohere": "Cohere", - "test_databricks": "Databricks", - "test_groq": "Groq", - "test_together": "Together AI", - "test_mistral": "Mistral", - "test_deepseek": "DeepSeek", - "test_replicate": "Replicate", - "test_huggingface": "HuggingFace", - "test_fireworks": "Fireworks AI", - "test_perplexity": "Perplexity", - "test_cloudflare": "Cloudflare", - "test_voyage": "Voyage AI", - "test_xai": "xAI", - "test_nvidia": "NVIDIA", - "test_watsonx": "IBM watsonx", - "test_azure_ai": "Azure AI", - "test_snowflake": "Snowflake", - "test_infinity": "Infinity", - "test_jina": "Jina AI", - "test_deepgram": "Deepgram", - "test_clarifai": "Clarifai", - "test_triton": "Triton", + 'test_anthropic': 'Anthropic', + 'test_azure': 'Azure', + 'test_bedrock': 'AWS Bedrock', + 'test_openai': 'OpenAI', + 'test_vertex': 'Google Vertex AI', + 'test_gemini': 'Google Vertex AI', + 'test_cohere': 'Cohere', + 'test_databricks': 'Databricks', + 'test_groq': 'Groq', + 'test_together': 'Together AI', + 'test_mistral': 'Mistral', + 'test_deepseek': 'DeepSeek', + 'test_replicate': 'Replicate', + 'test_huggingface': 'HuggingFace', + 'test_fireworks': 'Fireworks AI', + 'test_perplexity': 'Perplexity', + 'test_cloudflare': 'Cloudflare', + 'test_voyage': 'Voyage AI', + 'test_xai': 'xAI', + 'test_nvidia': 'NVIDIA', + 'test_watsonx': 'IBM watsonx', + 'test_azure_ai': 'Azure AI', + 'test_snowflake': 'Snowflake', + 'test_infinity': 'Infinity', + 'test_jina': 'Jina AI', + 'test_deepgram': 'Deepgram', + 'test_clarifai': 'Clarifai', + 'test_triton': 'Triton', } - + for key, provider in provider_mapping.items(): if key in test_file: return provider - + # For cross-provider test files - if any( - name in test_file - for name in [ - "test_optional_params", - "test_prompt_factory", - "test_router", - "test_text_completion", - ] - ): - return f"Cross-Provider Tests ({test_file})" - - return "Other Tests" - + if any(name in test_file for name in ['test_optional_params', 'test_prompt_factory', + 'test_router', 'test_text_completion']): + return f'Cross-Provider Tests ({test_file})' + + return 'Other Tests' def format_duration(seconds: float) -> str: """Format duration in human-readable format""" @@ -100,355 +89,290 @@ def format_duration(seconds: float) -> str: return f"{hours}h {minutes}m" -def generate_markdown_report( - junit_xml_path: str, output_path: str, tag: str = None, commit: str = None -): +def generate_markdown_report(junit_xml_path: str, output_path: str, tag: str = None, commit: str = None): """Generate a beautiful markdown report from JUnit XML""" try: tree = ET.parse(junit_xml_path) root = tree.getroot() - + # Handle both testsuite and testsuites root - if root.tag == "testsuites": - suites = root.findall("testsuite") + if root.tag == 'testsuites': + suites = root.findall('testsuite') else: suites = [root] - + # Overall statistics total_tests = 0 total_failures = 0 total_errors = 0 total_skipped = 0 total_time = 0.0 - + # Provider breakdown - provider_stats = defaultdict( - lambda: {"passed": 0, "failed": 0, "skipped": 0, "errors": 0, "time": 0.0} - ) + provider_stats = defaultdict(lambda: {'passed': 0, 'failed': 0, 'skipped': 0, 'errors': 0, 'time': 0.0}) provider_tests = defaultdict(list) - + for suite in suites: - total_tests += int(suite.get("tests", 0)) - total_failures += int(suite.get("failures", 0)) - total_errors += int(suite.get("errors", 0)) - total_skipped += int(suite.get("skipped", 0)) - total_time += float(suite.get("time", 0)) - - for testcase in suite.findall("testcase"): - classname = testcase.get("classname", "") - test_name = testcase.get("name", "") - test_time = float(testcase.get("time", 0)) - + total_tests += int(suite.get('tests', 0)) + total_failures += int(suite.get('failures', 0)) + total_errors += int(suite.get('errors', 0)) + total_skipped += int(suite.get('skipped', 0)) + total_time += float(suite.get('time', 0)) + + for testcase in suite.findall('testcase'): + classname = testcase.get('classname', '') + test_name = testcase.get('name', '') + test_time = float(testcase.get('time', 0)) + # Extract test file name from classname - if "." in classname: - parts = classname.split(".") - test_file = parts[-2] if len(parts) > 1 else "unknown" + if '.' in classname: + parts = classname.split('.') + test_file = parts[-2] if len(parts) > 1 else 'unknown' else: - test_file = "unknown" - + test_file = 'unknown' + provider = get_provider_from_test_file(test_file) - provider_stats[provider]["time"] += test_time - + provider_stats[provider]['time'] += test_time + # Check test status - if testcase.find("failure") is not None: - provider_stats[provider]["failed"] += 1 - failure = testcase.find("failure") - failure_msg = ( - failure.get("message", "") if failure is not None else "" - ) - provider_tests[provider].append( - { - "name": test_name, - "status": "FAILED", - "time": test_time, - "message": failure_msg, - } - ) - elif testcase.find("error") is not None: - provider_stats[provider]["errors"] += 1 - error = testcase.find("error") - error_msg = error.get("message", "") if error is not None else "" - provider_tests[provider].append( - { - "name": test_name, - "status": "ERROR", - "time": test_time, - "message": error_msg, - } - ) - elif testcase.find("skipped") is not None: - provider_stats[provider]["skipped"] += 1 - skip = testcase.find("skipped") - skip_msg = skip.get("message", "") if skip is not None else "" - provider_tests[provider].append( - { - "name": test_name, - "status": "SKIPPED", - "time": test_time, - "message": skip_msg, - } - ) + if testcase.find('failure') is not None: + provider_stats[provider]['failed'] += 1 + failure = testcase.find('failure') + failure_msg = failure.get('message', '') if failure is not None else '' + provider_tests[provider].append({ + 'name': test_name, + 'status': 'FAILED', + 'time': test_time, + 'message': failure_msg + }) + elif testcase.find('error') is not None: + provider_stats[provider]['errors'] += 1 + error = testcase.find('error') + error_msg = error.get('message', '') if error is not None else '' + provider_tests[provider].append({ + 'name': test_name, + 'status': 'ERROR', + 'time': test_time, + 'message': error_msg + }) + elif testcase.find('skipped') is not None: + provider_stats[provider]['skipped'] += 1 + skip = testcase.find('skipped') + skip_msg = skip.get('message', '') if skip is not None else '' + provider_tests[provider].append({ + 'name': test_name, + 'status': 'SKIPPED', + 'time': test_time, + 'message': skip_msg + }) else: - provider_stats[provider]["passed"] += 1 - provider_tests[provider].append( - { - "name": test_name, - "status": "PASSED", - "time": test_time, - "message": "", - } - ) - + provider_stats[provider]['passed'] += 1 + provider_tests[provider].append({ + 'name': test_name, + 'status': 'PASSED', + 'time': test_time, + 'message': '' + }) + passed = total_tests - total_failures - total_errors - total_skipped - + # Generate the markdown report - with open(output_path, "w") as f: + with open(output_path, 'w') as f: # Header f.write("# LLM Translation Test Results\n\n") - + # Metadata table f.write("## Test Run Information\n\n") f.write("| Field | Value |\n") f.write("|-------|-------|\n") f.write(f"| **Tag** | `{tag or 'N/A'}` |\n") - f.write( - f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n" - ) + f.write(f"| **Date** | {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')} |\n") f.write(f"| **Commit** | `{commit or 'N/A'}` |\n") f.write(f"| **Duration** | {format_duration(total_time)} |\n") f.write("\n") - + # Overall statistics with visual elements f.write("## Overall Statistics\n\n") - + # Summary box f.write("```\n") f.write(f"Total Tests: {total_tests}\n") - f.write( - f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n" - ) - f.write( - f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n" - ) - f.write( - f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n" - ) - f.write( - f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n" - ) + f.write(f"├── Passed: {passed:>4} ({(passed/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n") + f.write(f"├── Failed: {total_failures:>4} ({(total_failures/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n") + f.write(f"├── Errors: {total_errors:>4} ({(total_errors/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n") + f.write(f"└── Skipped: {total_skipped:>4} ({(total_skipped/total_tests)*100 if total_tests > 0 else 0:.1f}%)\n") f.write("```\n\n") - + + # Provider summary table f.write("## Results by Provider\n\n") - f.write( - "| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n" - ) - f.write( - "|----------|-------|------|------|-------|------|-----------|----------|" - ) - + f.write("| Provider | Total | Pass | Fail | Error | Skip | Pass Rate | Duration |\n") + f.write("|----------|-------|------|------|-------|------|-----------|----------|") + # Sort providers: specific providers first, then cross-provider tests sorted_providers = [] cross_provider = [] for p in sorted(provider_stats.keys()): - if "Cross-Provider" in p or p == "Other Tests": + if 'Cross-Provider' in p or p == 'Other Tests': cross_provider.append(p) else: sorted_providers.append(p) - + all_providers = sorted_providers + cross_provider - + for provider in all_providers: stats = provider_stats[provider] - total = ( - stats["passed"] - + stats["failed"] - + stats["errors"] - + stats["skipped"] - ) - pass_rate = (stats["passed"] / total * 100) if total > 0 else 0 - - f.write( - f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | " - ) + total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped'] + pass_rate = (stats['passed'] / total * 100) if total > 0 else 0 + + f.write(f"\n| {provider} | {total} | {stats['passed']} | {stats['failed']} | ") f.write(f"{stats['errors']} | {stats['skipped']} | {pass_rate:.1f}% | ") f.write(f"{format_duration(stats['time'])} |") - + # Detailed test results by provider f.write("\n\n## Detailed Test Results\n\n") - + for provider in sorted_providers: if provider_tests[provider]: stats = provider_stats[provider] - total = ( - stats["passed"] - + stats["failed"] - + stats["errors"] - + stats["skipped"] - ) - + total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped'] + f.write(f"### {provider}\n\n") f.write(f"**Summary:** {stats['passed']}/{total} passed ") - f.write( - f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) " - ) + f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%) ") f.write(f"in {format_duration(stats['time'])}\n\n") - + # Group tests by status tests_by_status = defaultdict(list) for test in provider_tests[provider]: - tests_by_status[test["status"]].append(test) - + tests_by_status[test['status']].append(test) + # Show failed tests first (if any) - if tests_by_status["FAILED"]: + if tests_by_status['FAILED']: f.write("
\nFailed Tests\n\n") - for test in tests_by_status["FAILED"]: + for test in tests_by_status['FAILED']: f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") - if test["message"]: + if test['message']: # Truncate long error messages - msg = ( - test["message"][:200] + "..." - if len(test["message"]) > 200 - else test["message"] - ) + msg = test['message'][:200] + '...' if len(test['message']) > 200 else test['message'] f.write(f" > {msg}\n") f.write("\n
\n\n") - + # Show errors (if any) - if tests_by_status["ERROR"]: + if tests_by_status['ERROR']: f.write("
\nError Tests\n\n") - for test in tests_by_status["ERROR"]: + for test in tests_by_status['ERROR']: f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") f.write("\n
\n\n") - + # Show passed tests in collapsible section - if tests_by_status["PASSED"]: + if tests_by_status['PASSED']: f.write("
\nPassed Tests\n\n") - for test in tests_by_status["PASSED"]: + for test in tests_by_status['PASSED']: f.write(f"- `{test['name']}` ({test['time']:.2f}s)\n") f.write("\n
\n\n") - + # Show skipped tests (if any) - if tests_by_status["SKIPPED"]: + if tests_by_status['SKIPPED']: f.write("
\nSkipped Tests\n\n") - for test in tests_by_status["SKIPPED"]: + for test in tests_by_status['SKIPPED']: f.write(f"- `{test['name']}`\n") f.write("\n
\n\n") - + # Cross-provider tests in a separate section if cross_provider: f.write("### Cross-Provider Tests\n\n") for provider in cross_provider: if provider_tests[provider]: stats = provider_stats[provider] - total = ( - stats["passed"] - + stats["failed"] - + stats["errors"] - + stats["skipped"] - ) - + total = stats['passed'] + stats['failed'] + stats['errors'] + stats['skipped'] + f.write(f"#### {provider}\n\n") f.write(f"**Summary:** {stats['passed']}/{total} passed ") - f.write( - f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n" - ) - + f.write(f"({(stats['passed']/total)*100 if total > 0 else 0:.1f}%)\n\n") + # For cross-provider tests, just show counts f.write(f"- Passed: {stats['passed']}\n") - if stats["failed"] > 0: + if stats['failed'] > 0: f.write(f"- Failed: {stats['failed']}\n") - if stats["errors"] > 0: + if stats['errors'] > 0: f.write(f"- Errors: {stats['errors']}\n") - if stats["skipped"] > 0: + if stats['skipped'] > 0: f.write(f"- Skipped: {stats['skipped']}\n") f.write("\n") - + + print_colored(f"Report generated: {output_path}", Colors.GREEN) - + except Exception as e: print_colored(f"Error generating report: {e}", Colors.RED) raise - -def run_tests( - test_path: str = "tests/llm_translation/", - junit_xml: str = "test-results/junit.xml", - report_path: str = "test-results/llm_translation_report.md", - tag: str = None, - commit: str = None, -) -> int: +def run_tests(test_path: str = "tests/llm_translation/", + junit_xml: str = "test-results/junit.xml", + report_path: str = "test-results/llm_translation_report.md", + tag: str = None, + commit: str = None) -> int: """Run the LLM translation tests and generate report""" - + # Create test results directory os.makedirs(os.path.dirname(junit_xml), exist_ok=True) - + print_colored("Starting LLM Translation Tests", Colors.BOLD + Colors.BLUE) print_colored(f"Test directory: {test_path}", Colors.CYAN) print_colored(f"Output: {junit_xml}", Colors.CYAN) print() - + # Run pytest cmd = [ - "uv", - "run", - "--no-sync", - "pytest", - test_path, + "uv", "run", "--no-sync", "pytest", test_path, f"--junitxml={junit_xml}", "-v", "--tb=short", "--maxfail=500", - "-n", - "auto", + "-n", "auto" ] - + # Add timeout if pytest-timeout is installed try: - subprocess.run( - ["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"], - capture_output=True, - check=True, - ) + subprocess.run(["uv", "run", "--no-sync", "python", "-c", "import pytest_timeout"], + capture_output=True, check=True) cmd.extend(["--timeout=300"]) except: - print_colored( - "Warning: pytest-timeout not installed, skipping timeout option", - Colors.YELLOW, - ) - + print_colored("Warning: pytest-timeout not installed, skipping timeout option", Colors.YELLOW) + print_colored("Running pytest with command:", Colors.YELLOW) print(f" {' '.join(cmd)}") print() - + # Run the tests result = subprocess.run(cmd, capture_output=False) - + # Generate the report regardless of test outcome if os.path.exists(junit_xml): print() print_colored("Generating test report...", Colors.BLUE) generate_markdown_report(junit_xml, report_path, tag, commit) - + # Print summary to console print() print_colored("Test Summary:", Colors.BOLD + Colors.PURPLE) - + # Parse XML for quick summary tree = ET.parse(junit_xml) root = tree.getroot() - - if root.tag == "testsuites": - suites = root.findall("testsuite") + + if root.tag == 'testsuites': + suites = root.findall('testsuite') else: suites = [root] - - total = sum(int(s.get("tests", 0)) for s in suites) - failures = sum(int(s.get("failures", 0)) for s in suites) - errors = sum(int(s.get("errors", 0)) for s in suites) - skipped = sum(int(s.get("skipped", 0)) for s in suites) + + total = sum(int(s.get('tests', 0)) for s in suites) + failures = sum(int(s.get('failures', 0)) for s in suites) + errors = sum(int(s.get('errors', 0)) for s in suites) + skipped = sum(int(s.get('skipped', 0)) for s in suites) passed = total - failures - errors - skipped - + print(f" Total: {total}") print_colored(f" Passed: {passed}", Colors.GREEN) if failures > 0: @@ -457,75 +381,59 @@ def run_tests( print_colored(f" Errors: {errors}", Colors.RED) if skipped > 0: print_colored(f" Skipped: {skipped}", Colors.YELLOW) - + if total > 0: pass_rate = (passed / total) * 100 - color = ( - Colors.GREEN - if pass_rate >= 80 - else Colors.YELLOW if pass_rate >= 60 else Colors.RED - ) + color = Colors.GREEN if pass_rate >= 80 else Colors.YELLOW if pass_rate >= 60 else Colors.RED print_colored(f" Pass Rate: {pass_rate:.1f}%", color) else: print_colored("No test results found!", Colors.RED) - + print() print_colored("Test run complete!", Colors.BOLD + Colors.GREEN) - + return result.returncode - if __name__ == "__main__": import argparse - + parser = argparse.ArgumentParser(description="Run LLM Translation Tests") - parser.add_argument( - "--test-path", default="tests/llm_translation/", help="Path to test directory" - ) - parser.add_argument( - "--junit-xml", - default="test-results/junit.xml", - help="Path for JUnit XML output", - ) - parser.add_argument( - "--report", - default="test-results/llm_translation_report.md", - help="Path for markdown report", - ) + parser.add_argument("--test-path", default="tests/llm_translation/", + help="Path to test directory") + parser.add_argument("--junit-xml", default="test-results/junit.xml", + help="Path for JUnit XML output") + parser.add_argument("--report", default="test-results/llm_translation_report.md", + help="Path for markdown report") parser.add_argument("--tag", help="Git tag or version") parser.add_argument("--commit", help="Git commit SHA") - + args = parser.parse_args() - + # Get git info if not provided if not args.commit: try: - result = subprocess.run( - ["git", "rev-parse", "HEAD"], capture_output=True, text=True - ) + result = subprocess.run(["git", "rev-parse", "HEAD"], + capture_output=True, text=True) if result.returncode == 0: args.commit = result.stdout.strip() except: pass - + if not args.tag: try: - result = subprocess.run( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - ) + result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"], + capture_output=True, text=True) if result.returncode == 0: args.tag = result.stdout.strip() except: pass - + exit_code = run_tests( test_path=args.test_path, junit_xml=args.junit_xml, report_path=args.report, tag=args.tag, - commit=args.commit, + commit=args.commit ) - + sys.exit(exit_code) diff --git a/ci_cd/run_migration.py b/ci_cd/run_migration.py index 760aee66da..feec4046ee 100644 --- a/ci_cd/run_migration.py +++ b/ci_cd/run_migration.py @@ -9,6 +9,7 @@ from pathlib import Path import testing.postgresql + DESTRUCTIVE_PATTERN = re.compile(r"\bDROP\s+(COLUMN|TABLE|INDEX)\b", re.IGNORECASE) DEFAULT_BASE_BRANCH = "litellm_internal_staging" diff --git a/cookbook/benchmark/benchmark.py b/cookbook/benchmark/benchmark.py index 2b3ddacad8..b38d185a16 100644 --- a/cookbook/benchmark/benchmark.py +++ b/cookbook/benchmark/benchmark.py @@ -6,6 +6,7 @@ from tabulate import tabulate from termcolor import colored import os + # Define the list of models to benchmark # select any LLM listed here: https://docs.litellm.ai/docs/providers models = ["gpt-3.5-turbo", "claude-2"] diff --git a/cookbook/litellm_proxy_server/braintrust_prompt_wrapper_server.py b/cookbook/litellm_proxy_server/braintrust_prompt_wrapper_server.py index a7049c68ce..6379314c5b 100644 --- a/cookbook/litellm_proxy_server/braintrust_prompt_wrapper_server.py +++ b/cookbook/litellm_proxy_server/braintrust_prompt_wrapper_server.py @@ -22,6 +22,7 @@ from fastapi import FastAPI, HTTPException, Header, Query from fastapi.responses import JSONResponse import uvicorn + app = FastAPI( title="Braintrust Prompt Wrapper", description="Wrapper server for Braintrust prompts to work with LiteLLM", diff --git a/cookbook/livekit_agent_sdk/main.py b/cookbook/livekit_agent_sdk/main.py index 26a42b40d0..c68e5534ea 100644 --- a/cookbook/livekit_agent_sdk/main.py +++ b/cookbook/livekit_agent_sdk/main.py @@ -2,7 +2,7 @@ Simple xAI Voice Agent using LiveKit SDK with LiteLLM Gateway This example shows how to use LiveKit's xAI realtime plugin through LiteLLM proxy. -LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI, +LiteLLM acts as a unified interface, allowing you to switch between xAI, OpenAI, and Azure realtime APIs without changing your agent code. """ diff --git a/cookbook/misc/migrate_proxy_config.py b/cookbook/misc/migrate_proxy_config.py index dc039070a5..31c3f32c08 100644 --- a/cookbook/misc/migrate_proxy_config.py +++ b/cookbook/misc/migrate_proxy_config.py @@ -1,14 +1,14 @@ """ LiteLLM Migration Script! -Takes a config.yaml and calls /model/new +Takes a config.yaml and calls /model/new Inputs: - File path to config.yaml - Proxy base url to your hosted proxy Step 1: Reads your config.yaml -Step 2: reads `model_list` and loops through all models +Step 2: reads `model_list` and loops through all models Step 3: calls `/model/new` for each model """ diff --git a/cookbook/mock_guardrail_server/mock_bedrock_guardrail_server.py b/cookbook/mock_guardrail_server/mock_bedrock_guardrail_server.py index 9f70bd20c2..7bf9cc3248 100644 --- a/cookbook/mock_guardrail_server/mock_bedrock_guardrail_server.py +++ b/cookbook/mock_guardrail_server/mock_bedrock_guardrail_server.py @@ -518,7 +518,8 @@ if __name__ == "__main__": print(f"Endpoint: POST /guardrail/{{id}}/version/{{version}}/apply") print("=" * 80) print("\nExample curl command:") - print(f""" + print( + f""" curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\ -H "Authorization: Bearer {bearer_token}" \\ -H "Content-Type: application/json" \\ @@ -532,7 +533,8 @@ curl -X POST "http://{host}:{port}/guardrail/test-guardrail/version/1/apply" \\ }} ] }}' - """) + """ + ) print("=" * 80) uvicorn.run(app, host=host, port=port) diff --git a/db_scripts/create_views.py b/db_scripts/create_views.py index 82f5b451db..3027b38958 100644 --- a/db_scripts/create_views.py +++ b/db_scripts/create_views.py @@ -34,7 +34,8 @@ async def check_view_exists(): # noqa: PLR0915 print("LiteLLM_VerificationTokenView Exists!") # noqa except Exception: # If an error occurs, the view does not exist, so create it - await db.execute_raw(""" + await db.execute_raw( + """ CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -44,7 +45,8 @@ async def check_view_exists(): # noqa: PLR0915 t.rpm_limit AS team_rpm_limit FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; - """) + """ + ) print("LiteLLM_VerificationTokenView Created!") # noqa diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py b/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py index 7353b995d2..8824f4c02d 100644 --- a/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py @@ -14,74 +14,53 @@ from litellm.types.utils import StandardCallbackDynamicParams class EnterpriseCallbackControls: @staticmethod def is_callback_disabled_dynamically( - callback: litellm.CALLBACK_TYPES, - litellm_params: dict, - standard_callback_dynamic_params: StandardCallbackDynamicParams, - ) -> bool: - """ - Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params. - - Args: - callback: The callback to check (can be string, CustomLogger instance, or callable) - litellm_params: Parameters containing proxy server request info - - Returns: - bool: True if the callback should be disabled, False otherwise - """ - from litellm.litellm_core_utils.custom_logger_registry import ( - CustomLoggerRegistry, - ) - - try: - disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks( - litellm_params, standard_callback_dynamic_params + callback: litellm.CALLBACK_TYPES, + litellm_params: dict, + standard_callback_dynamic_params: StandardCallbackDynamicParams + ) -> bool: + """ + Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params. + + Args: + callback: The callback to check (can be string, CustomLogger instance, or callable) + litellm_params: Parameters containing proxy server request info + + Returns: + bool: True if the callback should be disabled, False otherwise + """ + from litellm.litellm_core_utils.custom_logger_registry import ( + CustomLoggerRegistry, ) - verbose_logger.debug( - f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}" - ) - verbose_logger.debug( - f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}" - ) - if disabled_callbacks is not None: - ######################################################### - # premium user check - ######################################################### - if ( - not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling() - ): - return False - ######################################################### - if isinstance(callback, str): - if callback.lower() in disabled_callbacks: - verbose_logger.debug( - f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}" - ) - return True - elif isinstance(callback, CustomLogger): - # get the string name of the callback - callback_str = ( - CustomLoggerRegistry.get_callback_str_from_class_type( - callback.__class__ - ) - ) - if ( - callback_str is not None - and callback_str.lower() in disabled_callbacks - ): - verbose_logger.debug( - f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}" - ) - return True - return False - except Exception as e: - verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}") - return False + try: + disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params) + verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}") + verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}") + if disabled_callbacks is not None: + ######################################################### + # premium user check + ######################################################### + if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling(): + return False + ######################################################### + if isinstance(callback, str): + if callback.lower() in disabled_callbacks: + verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}") + return True + elif isinstance(callback, CustomLogger): + # get the string name of the callback + callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__) + if callback_str is not None and callback_str.lower() in disabled_callbacks: + verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}") + return True + return False + except Exception as e: + verbose_logger.debug( + f"Error checking disabled callbacks header: {str(e)}" + ) + return False @staticmethod - def get_disabled_callbacks( - litellm_params: dict, - standard_callback_dynamic_params: StandardCallbackDynamicParams, - ) -> Optional[List[str]]: + def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]: """ Get the disabled callbacks from the standard callback dynamic params. """ @@ -92,24 +71,18 @@ class EnterpriseCallbackControls: request_headers = get_proxy_server_request_headers(litellm_params) disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None) if disabled_callbacks is not None: - disabled_callbacks = set( - [cb.strip().lower() for cb in disabled_callbacks.split(",")] - ) + disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")]) return list(disabled_callbacks) + ######################################################### # check if disabled via request body ######################################################### - if ( - standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) - is not None - ): - return standard_callback_dynamic_params.get( - "litellm_disabled_callbacks", None - ) - + if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None: + return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) + return None - + @staticmethod def _should_allow_dynamic_callback_disabling(): import litellm @@ -117,14 +90,10 @@ class EnterpriseCallbackControls: # Check if admin has disabled this feature if litellm.allow_dynamic_callback_disabling is not True: - verbose_logger.debug( - "Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling" - ) + verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling") return False - + if premium_user: return True - verbose_logger.warning( - f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}" - ) - return False + verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}") + return False \ No newline at end of file diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py index 4fb6679a6e..89c3b85468 100644 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py @@ -349,10 +349,8 @@ class BaseEmailLogger(CustomLogger): ) # Calculate percentage and alert threshold - percentage = ( - threshold_pct - if threshold_pct is not None - else int(EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100) + percentage = threshold_pct if threshold_pct is not None else int( + EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100 ) threshold_fraction = percentage / 100.0 alert_threshold_str = ( @@ -611,7 +609,9 @@ class BaseEmailLogger(CustomLogger): continue _id = user_info.token or user_info.user_id or "default_id" - _cache_key = f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}" + _cache_key = ( + f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}" + ) result = await _cache.async_get_cache(key=_cache_key) if result is not None: @@ -630,9 +630,7 @@ class BaseEmailLogger(CustomLogger): continue recipient_emails = list(set(emails)) - event_message = ( - f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached" - ) + event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached" webhook_event = WebhookEvent( event="max_budget_alert", event_message=event_message, diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py index a1e8def2bb..8fc2d66d53 100644 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py @@ -15,6 +15,7 @@ from litellm.llms.custom_httpx.http_handler import ( from .base_email import BaseEmailLogger + SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send" @@ -78,4 +79,4 @@ class SendGridEmailLogger(BaseEmailLogger): verbose_logger.debug( f"SendGrid response status={response.status_code}, body={response.text}" ) - return + return \ No newline at end of file diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py index 8e4dbde437..8efdaf231b 100644 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py @@ -1,7 +1,6 @@ """ This is the litellm SMTP email integration """ - import asyncio from typing import List diff --git a/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py b/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py index 24941e90ab..44ba0063ff 100644 --- a/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py +++ b/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py @@ -1,7 +1,6 @@ """ Enterprise specific logging utils """ - from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata diff --git a/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py b/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py index 4f2eaa3c46..18ac29b978 100644 --- a/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py +++ b/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py @@ -153,11 +153,11 @@ async def get_audit_logs( # Return paginated response return PaginatedAuditLogResponse( - audit_logs=( - [AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs] - if audit_logs - else [] - ), + audit_logs=[ + AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs + ] + if audit_logs + else [], total=total_count, page=page, page_size=page_size, diff --git a/enterprise/litellm_enterprise/proxy/auth/__init__.py b/enterprise/litellm_enterprise/proxy/auth/__init__.py index dc70b57ab5..f67826ca7f 100644 --- a/enterprise/litellm_enterprise/proxy/auth/__init__.py +++ b/enterprise/litellm_enterprise/proxy/auth/__init__.py @@ -7,4 +7,4 @@ including custom SSO handlers and advanced authentication features. from .custom_sso_handler import EnterpriseCustomSSOHandler -__all__ = ["EnterpriseCustomSSOHandler"] +__all__ = ["EnterpriseCustomSSOHandler"] \ No newline at end of file diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py index 3aef2bd77c..356f6ecd4b 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py @@ -53,9 +53,7 @@ class CheckBatchCost: "user_api_key_alias": getattr(user_row, "user_alias", None), } except Exception as e: - verbose_proxy_logger.error( - f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}" - ) + verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}") return {} async def _cleanup_stale_managed_objects(self) -> None: @@ -64,22 +62,11 @@ class CheckBatchCost: in non-terminal states as 'stale_expired'. These will never complete and should not be polled. """ - cutoff = datetime.now(timezone.utc) - timedelta( - days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS - ) + cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS) result = await self.prisma_client.db.litellm_managedobjecttable.update_many( where={ "file_purpose": "batch", - "status": { - "not_in": [ - "completed", - "complete", - "failed", - "expired", - "cancelled", - "stale_expired", - ] - }, + "status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]}, "created_at": {"lt": cutoff}, }, data={"status": "stale_expired"}, @@ -133,12 +120,9 @@ class CheckBatchCost: try: from litellm.integrations.prometheus import PrometheusLogger - prom_logger = PrometheusLogger.get_instance() except Exception as e: - verbose_proxy_logger.error( - f"CheckBatchCost: could not get Prometheus logger: {e}" - ) + verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}") prom_logger = None processed_models: List[Tuple[Optional[str], Optional[str]]] = [] @@ -177,11 +161,7 @@ class CheckBatchCost: order={"created_at": "asc"}, ) except Exception as query_err: - if ( - "batch_processed" not in str(query_err).lower() - and "unknown column" not in str(query_err).lower() - and "does not exist" not in str(query_err).lower() - ): + if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower(): raise # Permanent schema gap — cache the result so future cycles skip straight to fallback self._has_batch_processed_column = False @@ -236,13 +216,14 @@ class CheckBatchCost: f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}" ) if prom_logger: - prom_logger.record_check_batch_cost_error( - "provider_retrieval_error" - ) + prom_logger.record_check_batch_cost_error("provider_retrieval_error") continue ## RETRIEVE THE BATCH JOB OUTPUT FILE - if response.status == "completed" and response.output_file_id is not None: + if ( + response.status == "completed" + and response.output_file_id is not None + ): verbose_proxy_logger.info( f"Batch ID: {batch_id} is complete, tracking cost and usage" ) @@ -269,25 +250,20 @@ class CheckBatchCost: decoded = _is_base64_encoded_unified_file_id(raw_output_file_id) if decoded: try: - raw_output_file_id = decoded.split("llm_output_file_id,")[ - 1 - ].split(";")[0] + raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0] except (IndexError, AttributeError): pass - credentials = ( - self.llm_router.get_deployment_credentials_with_provider(model_id) - or {} - ) + credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {} _file_content = await afile_content( file_id=raw_output_file_id, **credentials, ) # Access content - handle both direct attribute and method call - if hasattr(_file_content, "content"): + if hasattr(_file_content, 'content'): content_bytes = _file_content.content # type: ignore[union-attr] - elif hasattr(_file_content, "read"): + elif hasattr(_file_content, 'read'): content_bytes = await _file_content.read() # type: ignore[misc] else: content_bytes = _file_content # type: ignore[assignment] @@ -314,9 +290,7 @@ class CheckBatchCost: f"Skipping job {unified_object_id} because it is not a valid deployment info" ) if prom_logger: - prom_logger.record_check_batch_cost_error( - "deployment_not_found" - ) + prom_logger.record_check_batch_cost_error("deployment_not_found") continue custom_llm_provider = deployment_info.litellm_params.custom_llm_provider litellm_model_name = deployment_info.litellm_params.model @@ -328,11 +302,7 @@ class CheckBatchCost: # Pass deployment model_info so custom batch pricing # (input_cost_per_token_batches etc.) is used for cost calc - deployment_model_info = ( - deployment_info.model_info.model_dump() - if deployment_info.model_info - else {} - ) + deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {} batch_cost, batch_usage, batch_models = ( await calculate_batch_cost_and_usage( file_content_dictionary=file_content_as_dict, @@ -379,9 +349,7 @@ class CheckBatchCost: # Record batch duration (completed_at - created_at) if prom_logger and response.completed_at and response.created_at: - duration_seconds = float( - response.completed_at - response.created_at - ) + duration_seconds = float(response.completed_at - response.created_at) if duration_seconds >= 0: prom_logger.record_managed_batch_duration( duration_seconds=duration_seconds, @@ -390,9 +358,7 @@ class CheckBatchCost: ) # Track this job for the final metrics summary - processed_models.append( - (model_name, str(llm_provider) if llm_provider else None) - ) + processed_models.append((model_name, str(llm_provider) if llm_provider else None)) # mark the job as complete try: diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py index 8b2d15c157..dc0168683c 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py @@ -33,7 +33,9 @@ class CheckResponsesCost: self.prisma_client: PrismaClient = prisma_client self.llm_router: Router = llm_router - async def _expire_stale_rows(self, cutoff: datetime, batch_size: int) -> int: + async def _expire_stale_rows( + self, cutoff: datetime, batch_size: int + ) -> int: """Execute the bounded UPDATE that marks stale rows as 'stale_expired'. Isolated so it can be swapped / mocked in tests without touching the @@ -72,9 +74,7 @@ class CheckResponsesCost: rows per invocation to avoid overwhelming the DB when there is a large backlog. """ - cutoff = datetime.now(timezone.utc) - timedelta( - days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS - ) + cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS) result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE) if result > 0: verbose_proxy_logger.warning( @@ -105,7 +105,7 @@ class CheckResponsesCost: take=MAX_OBJECTS_PER_POLL_CYCLE, order={"created_at": "asc"}, ) - + verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check") completed_jobs = [] @@ -120,33 +120,29 @@ class CheckResponsesCost: # Get the stored response object to extract model information stored_response = job.file_object model_name = stored_response.get("model", None) - + # Decrypt the response ID - responses_id_security, _, _ = ( - ResponsesIDSecurity()._decrypt_response_id(unified_object_id) - ) - + responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id) + # Prepare metadata with model information for cost tracking litellm_metadata = { "user_api_key_user_id": job.created_by or "default-user-id", } - + # Add model information if available if model_name: litellm_metadata["model"] = model_name - litellm_metadata["model_group"] = ( - model_name # Use same value for model_group - ) - + litellm_metadata["model_group"] = model_name # Use same value for model_group + response = await litellm.aget_responses( response_id=responses_id_security, litellm_metadata=litellm_metadata, ) - + verbose_proxy_logger.debug( f"Response {unified_object_id} status: {response.status}, model: {model_name}" ) - + except Exception as e: verbose_proxy_logger.info( f"Skipping job {unified_object_id} due to error: {e}" @@ -159,7 +155,7 @@ class CheckResponsesCost: f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses." ) completed_jobs.append(job) - + elif response.status in ["failed", "cancelled"]: verbose_proxy_logger.info( f"Response {unified_object_id} has status {response.status}, marking as complete" @@ -175,3 +171,4 @@ class CheckResponsesCost: verbose_proxy_logger.info( f"Marked {len(completed_jobs)} response jobs as completed" ) + diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py b/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py index 70634537c5..254d816039 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py @@ -41,7 +41,7 @@ class _PROXY_LiteLLMManagedVectorStores( ): """ Managed vector stores with target_model_names support. - + This class provides functionality to: - Create vector stores across multiple models - Retrieve vector stores by unified ID @@ -77,14 +77,14 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> str: """ Generate the format string for the unified vector store ID. - + Format: litellm_proxy:vector_store;unified_id,;target_model_names,;resource_id,;model_id, """ # VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary # Extract provider resource ID from the response provider_resource_id = resource_object.get("id", "") - + # Model ID is stored in hidden params if the response object supports it # For TypedDict responses, we need to check if _hidden_params was added hidden_params: Dict[str, Any] = {} @@ -109,18 +109,20 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> VectorStoreCreateResponse: """ Create a vector store for a specific model. - + Args: llm_router: LiteLLM router instance model: Model name to create vector store for request_data: Request data for vector store creation litellm_parent_otel_span: OpenTelemetry span for tracing - + Returns: VectorStoreCreateResponse from the provider """ # Use the router to create the vector store - response = await llm_router.avector_store_create(model=model, **request_data) + response = await llm_router.avector_store_create( + model=model, **request_data + ) return response # ============================================================================ @@ -137,14 +139,14 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> VectorStoreCreateResponse: """ Create a vector store across multiple models. - + Args: create_request: Vector store creation request parameters llm_router: LiteLLM router instance target_model_names_list: List of target model names litellm_parent_otel_span: OpenTelemetry span for tracing user_api_key_dict: User API key authentication details - + Returns: VectorStoreCreateResponse with unified ID """ @@ -194,7 +196,7 @@ class _PROXY_LiteLLMManagedVectorStores( # VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID response = responses[0].copy() response["id"] = unified_id - + verbose_logger.info( f"Successfully created managed vector store with unified ID: {unified_id}" ) @@ -210,13 +212,13 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> Dict[str, Any]: """ List vector stores created by a user. - + Args: user_api_key_dict: User API key authentication details limit: Maximum number of vector stores to return after: Cursor for pagination order: Sort order ('asc' or 'desc') - + Returns: Dictionary with list of vector stores and pagination info """ @@ -236,23 +238,23 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> bool: """ Check if user has access to a vector store. - + Args: vector_store_id: The unified vector store ID user_api_key_dict: User API key authentication details - + Returns: True if user has access, False otherwise """ is_unified_id = is_base64_encoded_unified_id(vector_store_id) - + if is_unified_id: # Check access for managed vector store return await self.can_user_access_unified_resource_id( vector_store_id, user_api_key_dict, ) - + # Not a managed vector store, allow access return True @@ -261,22 +263,24 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> bool: """ Check if user has access to a managed vector store in request data. - + Args: data: Request data containing vector_store_id user_api_key_dict: User API key authentication details - + Returns: True if this is a managed vector store and user has access - + Raises: HTTPException: If user doesn't have access """ vector_store_id = cast(Optional[str], data.get("vector_store_id")) is_unified_id = ( - is_base64_encoded_unified_id(vector_store_id) if vector_store_id else False + is_base64_encoded_unified_id(vector_store_id) + if vector_store_id + else False ) - + if is_unified_id and vector_store_id: if await self.can_user_access_unified_resource_id( vector_store_id, user_api_key_dict @@ -287,7 +291,7 @@ class _PROXY_LiteLLMManagedVectorStores( status_code=403, detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", ) - + return False # ============================================================================ @@ -303,18 +307,18 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> Union[Exception, str, Dict, None]: """ Pre-call hook to handle vector store operations. - + This hook intercepts vector store requests and: - Validates access for managed vector stores - Transforms unified IDs to provider-specific IDs - Adds model routing information - + Args: user_api_key_dict: User API key authentication details cache: Cache instance data: Request data call_type: Type of call being made - + Returns: Modified request data or None """ @@ -326,40 +330,40 @@ class _PROXY_LiteLLMManagedVectorStores( # Handle vector store search operations if call_type == "avector_store_search": vector_store_id = data.get("vector_store_id") - + if vector_store_id: # Check if it's a managed vector store ID decoded_id = is_base64_encoded_unified_id(vector_store_id) - + if decoded_id: verbose_logger.debug( f"Processing managed vector store search: {vector_store_id}" ) - + # Check access has_access = await self.can_user_access_unified_resource_id( vector_store_id, user_api_key_dict ) - + if not has_access: raise HTTPException( status_code=403, detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", ) - + # Parse the unified ID to extract components parsed_id = parse_unified_id(vector_store_id) - + if parsed_id: # Extract the model ID and provider resource ID model_id = parsed_id.get("model_id") provider_resource_id = parsed_id.get("provider_resource_id") target_model_names = parsed_id.get("target_model_names", []) - + verbose_logger.debug( f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}" ) - + # Determine which model to use for routing # Priority: model_id (deployment ID) > first target_model_name routing_model = None @@ -367,28 +371,28 @@ class _PROXY_LiteLLMManagedVectorStores( routing_model = model_id elif target_model_names and len(target_model_names) > 0: routing_model = target_model_names[0] - + # Set the model for routing if routing_model: data["model"] = routing_model verbose_logger.info( f"Routing vector store search to model: {routing_model}" ) - + # Replace the unified ID with the provider-specific ID if provider_resource_id: data["vector_store_id"] = provider_resource_id verbose_logger.debug( f"Replaced unified ID with provider resource ID: {provider_resource_id}" ) - + # Handle vector store retrieve/delete operations elif call_type in ("avector_store_retrieve", "avector_store_delete"): await self.check_managed_vector_store_access(data, user_api_key_dict) - + # If it's a managed vector store, we'll handle it in the endpoint # No need to transform here as the endpoint will route to the hook - + return data # ============================================================================ @@ -403,15 +407,15 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> Any: """ Post-call hook to transform responses. - + This hook can be used to transform responses if needed. For now, it just passes through the response. - + Args: data: Request data user_api_key_dict: User API key authentication details response: Response from the provider - + Returns: Potentially modified response """ @@ -432,21 +436,21 @@ class _PROXY_LiteLLMManagedVectorStores( ) -> List[Dict]: """ Filter deployments based on vector store availability. - + This is used by the router to select only deployments that have the vector store available. - + Note: This method signature is a compromise between CustomLogger and BaseManagedResource parent classes which have incompatible signatures. The type: ignore[override] is necessary due to this multiple inheritance conflict. - + Args: model: Model name healthy_deployments: List of healthy deployments messages: Messages (unused for vector stores, required by CustomLogger interface) request_kwargs: Request kwargs containing vector_store_id and mappings parent_otel_span: OpenTelemetry span for tracing - + Returns: Filtered list of deployments """ diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py b/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py index 48b6dd7634..2f53f9e928 100644 --- a/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py +++ b/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py @@ -2,6 +2,7 @@ Enterprise internal user management endpoints """ + from fastapi import APIRouter, Depends, HTTPException from litellm.proxy._types import UserAPIKeyAuth diff --git a/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py b/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py index cf8c38719d..5e79959986 100644 --- a/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py +++ b/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py @@ -147,12 +147,12 @@ async def list_vector_stores( vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( prisma_client=prisma_client ) - + # Also clean up in-memory registry to remove any deleted vector stores if litellm.vector_store_registry is not None: db_vector_store_ids = { - vs.get("vector_store_id") - for vs in vector_stores_from_db + vs.get("vector_store_id") + for vs in vector_stores_from_db if vs.get("vector_store_id") } # Remove any in-memory vector stores that no longer exist in database diff --git a/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py b/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py index d9d5a989ab..380b0a6fac 100644 --- a/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py +++ b/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py @@ -39,23 +39,15 @@ class EmailEvent(str, enum.Enum): soft_budget_crossed = "Soft Budget Crossed" max_budget_alert = "Max Budget Alert" - class EmailEventSettings(BaseModel): event: EmailEvent enabled: bool - - class EmailEventSettingsUpdateRequest(BaseModel): settings: List[EmailEventSettings] - - class EmailEventSettingsResponse(BaseModel): settings: List[EmailEventSettings] - - class DefaultEmailSettings(BaseModel): """Default settings for email events""" - settings: Dict[EmailEvent, bool] = Field( default_factory=lambda: { EmailEvent.virtual_key_created: True, # On by default @@ -65,12 +57,10 @@ class DefaultEmailSettings(BaseModel): EmailEvent.max_budget_alert: True, # On by default } ) - def to_dict(self) -> Dict[str, bool]: """Convert to dictionary with string keys for storage""" return {event.value: enabled for event, enabled in self.settings.items()} - @classmethod def get_defaults(cls) -> Dict[str, bool]: """Get the default settings as a dictionary with string keys""" - return cls().to_dict() + return cls().to_dict() \ No newline at end of file diff --git a/litellm/_uuid.py b/litellm/_uuid.py index 2b7c3b82d3..52acf647dd 100644 --- a/litellm/_uuid.py +++ b/litellm/_uuid.py @@ -6,6 +6,7 @@ Always uses fastuuid for performance. import fastuuid as _uuid # type: ignore + # Expose a module-like alias so callers can use: uuid.uuid4() uuid = _uuid diff --git a/litellm/anthropic_interface/exceptions/exception_mapping_utils.py b/litellm/anthropic_interface/exceptions/exception_mapping_utils.py index 4548185bbd..28020e763f 100644 --- a/litellm/anthropic_interface/exceptions/exception_mapping_utils.py +++ b/litellm/anthropic_interface/exceptions/exception_mapping_utils.py @@ -9,6 +9,7 @@ from typing import Dict, Optional from .exceptions import AnthropicErrorResponse, AnthropicErrorType + # HTTP status code -> Anthropic error type # Source: https://docs.anthropic.com/en/api/errors ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = { diff --git a/litellm/anthropic_interface/exceptions/exceptions.py b/litellm/anthropic_interface/exceptions/exceptions.py index b289e493e6..984390fa70 100644 --- a/litellm/anthropic_interface/exceptions/exceptions.py +++ b/litellm/anthropic_interface/exceptions/exceptions.py @@ -2,6 +2,7 @@ from typing_extensions import Literal, Required, TypedDict + # Known Anthropic error types # Source: https://docs.anthropic.com/en/api/errors AnthropicErrorType = Literal[ diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index 25860e72f2..da3b9184ed 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -97,7 +97,7 @@ def _build_reasoning_item( def _reasoning_item_to_response_input( - r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]], + r_item: Union[ChatCompletionReasoningItem, Dict[str, Any]] ) -> Dict[str, Any]: """Convert a stored ChatCompletionReasoningItem back to a Responses API input item.""" r_input: Dict[str, Any] = { diff --git a/litellm/compression/content_detection.py b/litellm/compression/content_detection.py index 975117eb60..0655a42daf 100644 --- a/litellm/compression/content_detection.py +++ b/litellm/compression/content_detection.py @@ -5,6 +5,7 @@ Auto-detect content type per message: code, JSON, or text. import json import re + _CODE_KEYWORDS = re.compile( r"\b(?:def |function |class |import |from |require\(|#include|fn |func |const |let |var |public |private |static )\b" ) diff --git a/litellm/files/types.py b/litellm/files/types.py index ba42a39f66..688bc86f0c 100644 --- a/litellm/files/types.py +++ b/litellm/files/types.py @@ -1,5 +1,6 @@ from typing import AsyncIterator, Dict, Iterator, Literal, NamedTuple, Union + FileContentProvider = Literal[ "openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus" ] diff --git a/litellm/google_genai/adapters/__init__.py b/litellm/google_genai/adapters/__init__.py index 6fbe7d95a5..bfa9e71267 100644 --- a/litellm/google_genai/adapters/__init__.py +++ b/litellm/google_genai/adapters/__init__.py @@ -1,10 +1,10 @@ """ Google GenAI Adapters for LiteLLM -This module provides adapters for transforming Google GenAI generate_content requests +This module provides adapters for transforming Google GenAI generate_content requests to/from LiteLLM completion format with full support for: - Text content transformation -- Tool calling (function declarations, function calls, function responses) +- Tool calling (function declarations, function calls, function responses) - Streaming (both regular and tool calling) - Mixed content (text + tool calls) """ diff --git a/litellm/integrations/SlackAlerting/batching_handler.py b/litellm/integrations/SlackAlerting/batching_handler.py index 828f3eb417..fdce2e0479 100644 --- a/litellm/integrations/SlackAlerting/batching_handler.py +++ b/litellm/integrations/SlackAlerting/batching_handler.py @@ -1,9 +1,9 @@ """ -Handles Batching + sending Httpx Post requests to slack +Handles Batching + sending Httpx Post requests to slack -Slack alerts are sent every 10s or when events are greater than X events +Slack alerts are sent every 10s or when events are greater than X events -see custom_batch_logger.py for more details / defaults +see custom_batch_logger.py for more details / defaults """ from typing import TYPE_CHECKING, Any diff --git a/litellm/integrations/SlackAlerting/utils.py b/litellm/integrations/SlackAlerting/utils.py index e258076817..e695266c88 100644 --- a/litellm/integrations/SlackAlerting/utils.py +++ b/litellm/integrations/SlackAlerting/utils.py @@ -18,7 +18,7 @@ else: def process_slack_alerting_variables( - alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]], + alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] ) -> Optional[Dict[AlertType, Union[List[str], str]]]: """ process alert_to_webhook_url diff --git a/litellm/integrations/additional_logging_utils.py b/litellm/integrations/additional_logging_utils.py index 59319140a1..795afd81d4 100644 --- a/litellm/integrations/additional_logging_utils.py +++ b/litellm/integrations/additional_logging_utils.py @@ -1,5 +1,5 @@ """ -Base class for Additional Logging Utils for CustomLoggers +Base class for Additional Logging Utils for CustomLoggers - Health Check for the logging util - Get Request / Response Payload for the logging util diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py index 6da7aec555..f9d4496c21 100644 --- a/litellm/integrations/custom_batch_logger.py +++ b/litellm/integrations/custom_batch_logger.py @@ -1,5 +1,5 @@ """ -Custom Logger that handles batching logic +Custom Logger that handles batching logic Use this if you want your logs to be stored in memory and flushed periodically. """ diff --git a/litellm/integrations/focus/transformer.py b/litellm/integrations/focus/transformer.py index 6f4433b4a0..b7d28e3dbb 100644 --- a/litellm/integrations/focus/transformer.py +++ b/litellm/integrations/focus/transformer.py @@ -9,6 +9,7 @@ import polars as pl from .schema import FOCUS_NORMALIZED_SCHEMA + _TAG_KEYS = ( "team_id", "team_alias", diff --git a/litellm/integrations/opik/utils.py b/litellm/integrations/opik/utils.py index 43577505c1..b0ab5991c9 100644 --- a/litellm/integrations/opik/utils.py +++ b/litellm/integrations/opik/utils.py @@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]: def get_traces_and_spans_from_payload( - payload: List[Dict[str, Any]], + payload: List[Dict[str, Any]] ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ Separate traces and spans from payload. diff --git a/litellm/integrations/s3_v2.py b/litellm/integrations/s3_v2.py index 4ed8a809a1..332e84dd07 100644 --- a/litellm/integrations/s3_v2.py +++ b/litellm/integrations/s3_v2.py @@ -1,8 +1,8 @@ """ s3 Bucket Logging Integration -async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 -async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 +async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 +async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually """ diff --git a/litellm/interactions/__init__.py b/litellm/interactions/__init__.py index 0077ac8207..e1125b649a 100644 --- a/litellm/interactions/__init__.py +++ b/litellm/interactions/__init__.py @@ -5,28 +5,28 @@ This module provides SDK methods for Google's Interactions API. Usage: import litellm - + # Create an interaction with a model response = litellm.interactions.create( model="gemini-2.5-flash", input="Hello, how are you?" ) - + # Create an interaction with an agent response = litellm.interactions.create( agent="deep-research-pro-preview-12-2025", input="Research the current state of cancer research" ) - + # Async version response = await litellm.interactions.acreate(...) - + # Get an interaction response = litellm.interactions.get(interaction_id="...") - + # Delete an interaction result = litellm.interactions.delete(interaction_id="...") - + # Cancel an interaction result = litellm.interactions.cancel(interaction_id="...") diff --git a/litellm/interactions/main.py b/litellm/interactions/main.py index be9f2b99e0..ab429ef6db 100644 --- a/litellm/interactions/main.py +++ b/litellm/interactions/main.py @@ -8,25 +8,25 @@ Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json): Usage: import litellm - + # Create an interaction with a model response = litellm.interactions.create( model="gemini-2.5-flash", input="Hello, how are you?" ) - + # Create an interaction with an agent response = litellm.interactions.create( agent="deep-research-pro-preview-12-2025", input="Research the current state of cancer research" ) - + # Async version response = await litellm.interactions.acreate(...) - + # Get an interaction response = litellm.interactions.get(interaction_id="...") - + # Delete an interaction result = litellm.interactions.delete(interaction_id="...") """ diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index cf20dbba70..a815442c2f 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -994,8 +994,10 @@ class Logging(LiteLLMLoggingBaseClass): try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata["raw_request"] = "redacted by litellm. \ + _metadata["raw_request"] = ( + "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" + ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -1029,8 +1031,12 @@ class Logging(LiteLLMLoggingBaseClass): error=str(e), ) ) - _metadata["raw_request"] = "Unable to Log \ - raw request: {}".format(str(e)) + _metadata["raw_request"] = ( + "Unable to Log \ + raw request: {}".format( + str(e) + ) + ) if getattr(self, "logger_fn", None) and callable(self.logger_fn): try: self.logger_fn( diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index b96785c59f..abe9e016e2 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -5533,7 +5533,9 @@ def default_response_schema_prompt(response_schema: dict) -> str: prompt_str = """Use this JSON schema: ```json {} - ```""".format(response_schema) + ```""".format( + response_schema + ) return prompt_str diff --git a/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py b/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py index 0a6a4e82c7..13341f27a6 100644 --- a/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py +++ b/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py @@ -1,9 +1,9 @@ """ This is a cache for LangfuseLoggers. -Langfuse Python SDK initializes a thread for each client. +Langfuse Python SDK initializes a thread for each client. -This ensures we do +This ensures we do 1. Proper cleanup of Langfuse initialized clients. 2. Re-use created langfuse clients. """ diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/agentic_streaming_iterator.py b/litellm/llms/anthropic/experimental_pass_through/messages/agentic_streaming_iterator.py index d693d50b8e..d0780c82d0 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/agentic_streaming_iterator.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/agentic_streaming_iterator.py @@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, cast from litellm._logging import verbose_logger + # --------------------------------------------------------------------------- # SSE parsing helpers (module-level to keep the class lean) # --------------------------------------------------------------------------- diff --git a/litellm/llms/azure/chat/o_series_transformation.py b/litellm/llms/azure/chat/o_series_transformation.py index 0a73597a4e..cae7513245 100644 --- a/litellm/llms/azure/chat/o_series_transformation.py +++ b/litellm/llms/azure/chat/o_series_transformation.py @@ -4,10 +4,10 @@ Support for o1 and o3 model families https://platform.openai.com/docs/guides/reasoning Translations handled by LiteLLM: -- modalities: image => drop param (if user opts in to dropping param) -- role: system ==> translate to role 'user' -- streaming => faked by LiteLLM -- Tools, response_format => drop param (if user opts in to dropping param) +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) - Logprobs => drop param (if user opts in to dropping param) - Temperature => drop param (if user opts in to dropping param) """ diff --git a/litellm/llms/azure_ai/embed/cohere_transformation.py b/litellm/llms/azure_ai/embed/cohere_transformation.py index bbbfb60fbd..64433c21b6 100644 --- a/litellm/llms/azure_ai/embed/cohere_transformation.py +++ b/litellm/llms/azure_ai/embed/cohere_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. +Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py index f64133afa8..b5993040ea 100644 --- a/litellm/llms/azure_ai/rerank/transformation.py +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. +Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. """ from typing import Optional diff --git a/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py index 64a79b7327..2747551af8 100644 --- a/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py +++ b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. +Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/bedrock/embed/cohere_transformation.py b/litellm/llms/bedrock/embed/cohere_transformation.py index 885c91f975..d00cb74aae 100644 --- a/litellm/llms/bedrock/embed/cohere_transformation.py +++ b/litellm/llms/bedrock/embed/cohere_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. +Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/bedrock_mantle/chat/transformation.py b/litellm/llms/bedrock_mantle/chat/transformation.py index 81a56030a5..e413bb22b2 100644 --- a/litellm/llms/bedrock_mantle/chat/transformation.py +++ b/litellm/llms/bedrock_mantle/chat/transformation.py @@ -16,6 +16,7 @@ from litellm.secret_managers.main import get_secret_str from ...openai_like.chat.transformation import OpenAILikeChatConfig + BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1" diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 81b6a1c7ae..3ab8baf7ba 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -1,5 +1,5 @@ """ -Legacy /v1/embedding handler for Bedrock Cohere. +Legacy /v1/embedding handler for Bedrock Cohere. """ import json diff --git a/litellm/llms/custom_httpx/mock_transport.py b/litellm/llms/custom_httpx/mock_transport.py index ad93cc134e..c9844753e0 100644 --- a/litellm/llms/custom_httpx/mock_transport.py +++ b/litellm/llms/custom_httpx/mock_transport.py @@ -13,6 +13,7 @@ from typing import Tuple import httpx + # --------------------------------------------------------------------------- # Pre-built response templates # --------------------------------------------------------------------------- diff --git a/litellm/llms/dashscope/cost_calculator.py b/litellm/llms/dashscope/cost_calculator.py index 8bb7f605b8..9b3e385116 100644 --- a/litellm/llms/dashscope/cost_calculator.py +++ b/litellm/llms/dashscope/cost_calculator.py @@ -1,5 +1,5 @@ """ -Cost calculator for Dashscope Chat models. +Cost calculator for Dashscope Chat models. Handles tiered pricing and prompt caching scenarios. """ diff --git a/litellm/llms/datarobot/chat/transformation.py b/litellm/llms/datarobot/chat/transformation.py index f81e242093..23ce63c25b 100644 --- a/litellm/llms/datarobot/chat/transformation.py +++ b/litellm/llms/datarobot/chat/transformation.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as DataRobot is openai-compatible. """ diff --git a/litellm/llms/deepinfra/rerank/transformation.py b/litellm/llms/deepinfra/rerank/transformation.py index e4bfbcb251..276735f475 100644 --- a/litellm/llms/deepinfra/rerank/transformation.py +++ b/litellm/llms/deepinfra/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format. +Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format. """ from typing import Any, Dict, List, Optional, Union diff --git a/litellm/llms/deepseek/cost_calculator.py b/litellm/llms/deepseek/cost_calculator.py index e652ebeac5..0f4490cb3d 100644 --- a/litellm/llms/deepseek/cost_calculator.py +++ b/litellm/llms/deepseek/cost_calculator.py @@ -1,5 +1,5 @@ """ -Cost calculator for DeepSeek Chat models. +Cost calculator for DeepSeek Chat models. Handles prompt caching scenario. """ diff --git a/litellm/llms/elevenlabs/text_to_speech/transformation.py b/litellm/llms/elevenlabs/text_to_speech/transformation.py index 612fc687ef..6a59911701 100644 --- a/litellm/llms/elevenlabs/text_to_speech/transformation.py +++ b/litellm/llms/elevenlabs/text_to_speech/transformation.py @@ -22,6 +22,7 @@ from litellm.types.utils import all_litellm_params from ..common_utils import ElevenLabsException + if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.llms.openai import HttpxBinaryResponseContent diff --git a/litellm/llms/gemini/videos/transformation.py b/litellm/llms/gemini/videos/transformation.py index 9714c8a392..c7116940b2 100644 --- a/litellm/llms/gemini/videos/transformation.py +++ b/litellm/llms/gemini/videos/transformation.py @@ -55,7 +55,7 @@ def _convert_image_to_gemini_format(image_file) -> Dict[str, str]: def _usage_video_resolution_from_parameters( - parameters: Dict[str, Any], + parameters: Dict[str, Any] ) -> Optional[str]: """Normalize Veo ``parameters.resolution`` for usage and cost tracking.""" res = parameters.get("resolution") diff --git a/litellm/llms/infinity/rerank/transformation.py b/litellm/llms/infinity/rerank/transformation.py index b980460545..314bf2f8a3 100644 --- a/litellm/llms/infinity/rerank/transformation.py +++ b/litellm/llms/infinity/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py index 56be754fc3..ad4416925a 100644 --- a/litellm/llms/jina_ai/rerank/transformation.py +++ b/litellm/llms/jina_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/lm_studio/embed/transformation.py b/litellm/llms/lm_studio/embed/transformation.py index 87f4f6e73d..1285550c30 100644 --- a/litellm/llms/lm_studio/embed/transformation.py +++ b/litellm/llms/lm_studio/embed/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format. +Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/novita/chat/transformation.py b/litellm/llms/novita/chat/transformation.py index 5a64a124ad..c05d2d7b2c 100644 --- a/litellm/llms/novita/chat/transformation.py +++ b/litellm/llms/novita/chat/transformation.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as Novita AI is openai-compatible. diff --git a/litellm/llms/nvidia_nim/chat/transformation.py b/litellm/llms/nvidia_nim/chat/transformation.py index 2ef92a9062..b8f8b04eb5 100644 --- a/litellm/llms/nvidia_nim/chat/transformation.py +++ b/litellm/llms/nvidia_nim/chat/transformation.py @@ -1,7 +1,7 @@ """ -Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer +Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer -This is OpenAI compatible +This is OpenAI compatible This file only contains param mapping logic diff --git a/litellm/llms/nvidia_nim/embed.py b/litellm/llms/nvidia_nim/embed.py index 61c8e8244e..24c6cc34e4 100644 --- a/litellm/llms/nvidia_nim/embed.py +++ b/litellm/llms/nvidia_nim/embed.py @@ -1,7 +1,7 @@ """ Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer -This is OpenAI compatible +This is OpenAI compatible This file only contains param mapping logic diff --git a/litellm/llms/openai/chat/o_series_transformation.py b/litellm/llms/openai/chat/o_series_transformation.py index 8db7ecf7b3..02ae2cc975 100644 --- a/litellm/llms/openai/chat/o_series_transformation.py +++ b/litellm/llms/openai/chat/o_series_transformation.py @@ -1,14 +1,14 @@ """ -Support for o1/o3 model family +Support for o1/o3 model family https://platform.openai.com/docs/guides/reasoning Translations handled by LiteLLM: -- modalities: image => drop param (if user opts in to dropping param) -- role: system ==> translate to role 'user' -- streaming => faked by LiteLLM -- Tools, response_format => drop param (if user opts in to dropping param) -- Logprobs => drop param (if user opts in to dropping param) +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) +- Logprobs => drop param (if user opts in to dropping param) """ from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index 381f215a13..c13a976c1b 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -201,7 +201,7 @@ class BaseOpenAILLM: @staticmethod def get_openai_client_initialization_param_fields( - client_type: Literal["openai", "azure"], + client_type: Literal["openai", "azure"] ) -> Tuple[str, ...]: """Returns a tuple of fields that are used to initialize the OpenAI client""" if client_type == "openai": diff --git a/litellm/llms/openrouter/image_generation/transformation.py b/litellm/llms/openrouter/image_generation/transformation.py index 9c2293eb3f..a55716a5e5 100644 --- a/litellm/llms/openrouter/image_generation/transformation.py +++ b/litellm/llms/openrouter/image_generation/transformation.py @@ -49,6 +49,7 @@ from litellm.types.utils import ( ) from litellm.llms.openrouter.common_utils import OpenRouterException + if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj else: diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index 8fd32bc446..3e4e2460cd 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -1,7 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke` -In the Huggingface TGI format. +In the Huggingface TGI format. """ import json diff --git a/litellm/llms/sagemaker/embedding/transformation.py b/litellm/llms/sagemaker/embedding/transformation.py index 09bdb9295e..0443017118 100644 --- a/litellm/llms/sagemaker/embedding/transformation.py +++ b/litellm/llms/sagemaker/embedding/transformation.py @@ -1,7 +1,7 @@ """ Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke` -In the Huggingface TGI format. +In the Huggingface TGI format. """ from typing import TYPE_CHECKING, Any, List, Optional, Union diff --git a/litellm/llms/sap/credentials.py b/litellm/llms/sap/credentials.py index dd307ddf49..0ae351783e 100644 --- a/litellm/llms/sap/credentials.py +++ b/litellm/llms/sap/credentials.py @@ -207,7 +207,7 @@ def resolve_resource_group(sources: List[Source]) -> Optional[str]: def _parse_service_key_once( - service_key: Optional[Union[str, dict]], + service_key: Optional[Union[str, dict]] ) -> Optional[Dict[str, Any]]: """ Pre-parse service_key if it's a string to avoid repeated JSON parsing. diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 23bb6f4475..3e590680a7 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -14,6 +14,7 @@ from ...openai_like.chat.transformation import OpenAIGPTConfig from ..utils import SnowflakeBaseConfig + if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj diff --git a/litellm/llms/together_ai/chat.py b/litellm/llms/together_ai/chat.py index 238849cc1e..7efb12fc1b 100644 --- a/litellm/llms/together_ai/chat.py +++ b/litellm/llms/together_ai/chat.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. diff --git a/litellm/llms/together_ai/embed.py b/litellm/llms/together_ai/embed.py index 6a39b94acf..577df0256c 100644 --- a/litellm/llms/together_ai/embed.py +++ b/litellm/llms/together_ai/embed.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/embeddings` endpoint. +Support for OpenAI's `/v1/embeddings` endpoint. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py index f4d642bd25..63b593dfe4 100644 --- a/litellm/llms/together_ai/rerank/transformation.py +++ b/litellm/llms/together_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/vertex_ai/context_caching/transformation.py b/litellm/llms/vertex_ai/context_caching/transformation.py index ef71357d1d..950edbeb47 100644 --- a/litellm/llms/vertex_ai/context_caching/transformation.py +++ b/litellm/llms/vertex_ai/context_caching/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic for context caching. +Transformation logic for context caching. Why separate file? Make it easy to see how transformation works """ @@ -19,7 +19,7 @@ from ..gemini.transformation import ( def get_first_continuous_block_idx( - filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message) + filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message) ) -> int: """ Find the array index that ends the first continuous sequence of message blocks. diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py index d20c26b8a2..9afa5dec46 100644 --- a/litellm/llms/vertex_ai/gemini/transformation.py +++ b/litellm/llms/vertex_ai/gemini/transformation.py @@ -632,14 +632,16 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915 contents.append(ContentType(role="user", parts=tool_call_responses)) if len(contents) == 0: - verbose_logger.warning(""" + verbose_logger.warning( + """ No contents in messages. Contents are required. See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body. If the original request did not comply to OpenAI API requirements it should have failed by now, but LiteLLM does not check for missing messages. Setting an empty content to prevent an 400 error. Relevant Issue - https://github.com/BerriAI/litellm/issues/9733 - """) + """ + ) contents.append(ContentType(role="user", parts=[PartType(text=" ")])) return contents except Exception as e: diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py index ba6e6f0c05..e1b365c9f4 100644 --- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py +++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. +Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py index b835ad7d8f..9d9015c2b9 100644 --- a/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py +++ b/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py @@ -139,7 +139,7 @@ class VertexTextToSpeechAPI(VertexLLM): ########## End of logging ############ ####### Send the request ################### if _is_async is True: - return self.async_audio_speech( # type: ignore + return self.async_audio_speech( # type:ignore logging_obj=logging_obj, url=url, headers=headers, request=request ) sync_handler = _get_httpx_client() diff --git a/litellm/llms/vllm/completion/transformation.py b/litellm/llms/vllm/completion/transformation.py index e03b07f989..ec4c07e95d 100644 --- a/litellm/llms/vllm/completion/transformation.py +++ b/litellm/llms/vllm/completion/transformation.py @@ -1,5 +1,5 @@ """ -Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. +Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead. """ diff --git a/litellm/llms/voyage/embedding/transformation_contextual.py b/litellm/llms/voyage/embedding/transformation_contextual.py index 1f5ca99f47..40328062e0 100644 --- a/litellm/llms/voyage/embedding/transformation_contextual.py +++ b/litellm/llms/voyage/embedding/transformation_contextual.py @@ -1,6 +1,6 @@ """ -This module is used to transform the request and response for the Voyage contextualized embeddings API. -This would be used for all the contextualized embeddings models in Voyage. +This module is used to transform the request and response for the Voyage contextualized embeddings API. +This would be used for all the contextualized embeddings models in Voyage. """ from typing import List, Optional, Union diff --git a/litellm/proxy/common_utils/custom_openapi_spec.py b/litellm/proxy/common_utils/custom_openapi_spec.py index fa3cb02195..a93749c395 100644 --- a/litellm/proxy/common_utils/custom_openapi_spec.py +++ b/litellm/proxy/common_utils/custom_openapi_spec.py @@ -324,7 +324,7 @@ class CustomOpenAPISpec: @staticmethod def add_chat_completion_request_schema( - openapi_schema: Dict[str, Any], + openapi_schema: Dict[str, Any] ) -> Dict[str, Any]: """ Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation. @@ -380,7 +380,7 @@ class CustomOpenAPISpec: @staticmethod def add_responses_api_request_schema( - openapi_schema: Dict[str, Any], + openapi_schema: Dict[str, Any] ) -> Dict[str, Any]: """ Add ResponsesAPIRequestParams schema to responses API endpoints for documentation. @@ -410,7 +410,7 @@ class CustomOpenAPISpec: @staticmethod def add_llm_api_request_schema_body( - openapi_schema: Dict[str, Any], + openapi_schema: Dict[str, Any] ) -> Dict[str, Any]: """ Add LLM API request schema bodies to OpenAPI specification for documentation. diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 7b5a068edf..71abdfa5e9 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]: async def convert_upload_files_to_file_data( - form_data: Dict[str, Any], + form_data: Dict[str, Any] ) -> Dict[str, Any]: """ Convert FastAPI UploadFile objects to file data tuples for litellm. diff --git a/litellm/proxy/common_utils/openai_endpoint_utils.py b/litellm/proxy/common_utils/openai_endpoint_utils.py index 905967fa46..c4bfe11aec 100644 --- a/litellm/proxy/common_utils/openai_endpoint_utils.py +++ b/litellm/proxy/common_utils/openai_endpoint_utils.py @@ -1,5 +1,5 @@ """ -Contains utils used by OpenAI compatible endpoints +Contains utils used by OpenAI compatible endpoints """ from typing import Optional, Set diff --git a/litellm/proxy/config_management_endpoints/pass_through_endpoints.py b/litellm/proxy/config_management_endpoints/pass_through_endpoints.py index 4ebd989dc5..5ff02b8bce 100644 --- a/litellm/proxy/config_management_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/config_management_endpoints/pass_through_endpoints.py @@ -1,5 +1,5 @@ """ -What is this? +What is this? CRUD endpoints for managing pass-through endpoints """ diff --git a/litellm/proxy/db/create_views.py b/litellm/proxy/db/create_views.py index 97525a528d..d84cebcf05 100644 --- a/litellm/proxy/db/create_views.py +++ b/litellm/proxy/db/create_views.py @@ -34,7 +34,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915 if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): raise # If an error occurs, the view does not exist, so create it - await db.execute_raw(""" + await db.execute_raw( + """ CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -46,7 +47,8 @@ async def create_missing_views(db: _db): # noqa: PLR0915 FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id; - """) + """ + ) verbose_logger.debug("LiteLLM_VerificationTokenView Created!") diff --git a/litellm/proxy/guardrails/_content_utils.py b/litellm/proxy/guardrails/_content_utils.py index 766ef0cf9f..7cad1352a7 100644 --- a/litellm/proxy/guardrails/_content_utils.py +++ b/litellm/proxy/guardrails/_content_utils.py @@ -10,6 +10,7 @@ every text fragment. from typing import Any, Callable, Dict, FrozenSet, Iterator, List + # Call types whose body carries free-form chat / prompt text that # text-content guardrails (banned keywords, content moderation, secret # detection, …) should inspect. The proxy ingress passes ``route_type`` diff --git a/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py index 1e3dd906b9..c4aaea709b 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py +++ b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py @@ -4,6 +4,7 @@ from litellm.types.guardrails import SupportedGuardrailIntegrations from .akto import AktoGuardrail + if TYPE_CHECKING: from litellm.types.guardrails import Guardrail, LitellmParams diff --git a/litellm/proxy/hooks/litellm_skills/__init__.py b/litellm/proxy/hooks/litellm_skills/__init__.py index 1507b652ab..057cf3d8b3 100644 --- a/litellm/proxy/hooks/litellm_skills/__init__.py +++ b/litellm/proxy/hooks/litellm_skills/__init__.py @@ -6,7 +6,7 @@ The actual skill logic is in litellm/llms/litellm_proxy/skills/. Usage: from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook - + # Register hook in proxy litellm.callbacks.append(SkillsInjectionHook()) """ diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20735b3552..81b133e6c8 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -1,9 +1,9 @@ """ BUDGET MANAGEMENT -All /budget management endpoints +All /budget management endpoints -/budget/new +/budget/new /budget/info /budget/update /budget/delete diff --git a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 1fd8320db2..4889f0b7f8 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -1,9 +1,9 @@ """ CUSTOMER MANAGEMENT -All /customer management endpoints +All /customer management endpoints -/customer/new +/customer/new /customer/info /customer/update /customer/delete diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index b14d58e833..af84bc123f 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -529,7 +529,7 @@ async def _update_existing_team_model_assignment( """ def _get_team_public_model_name( - model_info: Optional[Union[dict, str]], + model_info: Optional[Union[dict, str]] ) -> Optional[str]: if isinstance(model_info, dict): value = model_info.get("team_public_model_name") diff --git a/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py b/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py index 04e44c623d..191212d6f0 100644 --- a/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py +++ b/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py @@ -7,7 +7,7 @@ variables. Environment Variables: - MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL -- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL +- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL - MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL If these are not set, the default Microsoft endpoints are used. diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 564ea85200..259624f1e1 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -4347,7 +4347,9 @@ async def list_team( except Exception as e: team_exception = """Invalid team object for team_id: {}. team_object={}. Error: {} - """.format(team.team_id, team.model_dump(), str(e)) + """.format( + team.team_id, team.model_dump(), str(e) + ) verbose_proxy_logger.exception(team_exception) continue # Sort the responses by team_alias diff --git a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py index ebd276fbee..872b6fa225 100644 --- a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py +++ b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py @@ -3,7 +3,7 @@ User Agent Analytics Endpoints This module provides optimized endpoints for tracking user agent activity metrics including: - Daily Active Users (DAU) by tags for configurable number of days -- Weekly Active Users (WAU) by tags for configurable number of weeks +- Weekly Active Users (WAU) by tags for configurable number of weeks - Monthly Active Users (MAU) by tags for configurable number of months - Summary analytics by tags diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py index e7696e5a18..a104f96263 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py @@ -18,6 +18,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.types.utils import StandardPassThroughResponseObject + CURSOR_AGENT_ENDPOINTS: Dict[str, str] = { "POST /v0/agents": "cursor:agent:create", "GET /v0/agents": "cursor:agent:list", diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index d7e04aea70..90bbfdd25a 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -292,7 +292,9 @@ class ProxyInitializationHelpers: _endpoint_str = ( f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\" ) - curl_command = _endpoint_str + """ + curl_command = ( + _endpoint_str + + """ --header 'Content-Type: application/json' \\ --data ' { "model": "gpt-3.5-turbo", @@ -305,6 +307,7 @@ class ProxyInitializationHelpers: }' \n """ + ) print() # noqa print( # noqa '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n' @@ -380,9 +383,11 @@ class ProxyInitializationHelpers: with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - print(f""" + print( # noqa + f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """) # noqa # noqa + """ + ) # noqa @staticmethod def _is_port_in_use(port): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0614db4fb1..c96d0acb00 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2688,9 +2688,11 @@ def run_ollama_serve(): with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - verbose_proxy_logger.debug(f""" + verbose_proxy_logger.debug( + f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """) + """ + ) def _get_process_rss_mb() -> Optional[float]: diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index e3019801aa..d030fabe8b 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -3184,14 +3184,16 @@ async def provider_budgets() -> ProviderBudgetResponse: async def get_spend_by_tags( prisma_client: PrismaClient, start_date=None, end_date=None ): - response = await prisma_client.db.query_raw(""" + response = await prisma_client.db.query_raw( + """ SELECT jsonb_array_elements_text(request_tags) AS individual_request_tag, COUNT(*) AS log_count, SUM(spend) AS total_spend FROM "LiteLLM_SpendLogs" GROUP BY individual_request_tag; - """) + """ + ) return response diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index e0a956c166..a52dc8e55f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -2712,7 +2712,8 @@ class PrismaClient: required_view = "LiteLLM_VerificationTokenView" expected_views_str = ", ".join(f"'{view}'" for view in expected_views) pg_schema = os.getenv("DATABASE_SCHEMA", "public") - ret = await self.db.query_raw(f""" + ret = await self.db.query_raw( + f""" WITH existing_views AS ( SELECT viewname FROM pg_views @@ -2724,7 +2725,8 @@ class PrismaClient: (SELECT COUNT(*) FROM existing_views) AS view_count, ARRAY_AGG(viewname) AS view_names FROM existing_views - """) + """ + ) expected_total_views = len(expected_views) if ret[0]["view_count"] == expected_total_views: verbose_proxy_logger.info("All necessary views exist!") @@ -2733,7 +2735,8 @@ class PrismaClient: ## check if required view exists ## if ret[0]["view_names"] and required_view not in ret[0]["view_names"]: await self.health_check() # make sure we can connect to db - await self.db.execute_raw(""" + await self.db.execute_raw( + """ CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -2743,7 +2746,8 @@ class PrismaClient: t.rpm_limit AS team_rpm_limit FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; - """) + """ + ) verbose_proxy_logger.info( "LiteLLM_VerificationTokenView Created in DB!" diff --git a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py index b47f6a747d..8ce1bedcf9 100644 --- a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py @@ -1,5 +1,5 @@ """ -What is this? +What is this? Logging Pass-Through Endpoints """ diff --git a/litellm/router.py b/litellm/router.py index c3537ffd64..7512ee387d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -826,7 +826,7 @@ class Router: @staticmethod def _normalize_strategy( - strategy: Union[RoutingStrategy, str, None], + strategy: Union[RoutingStrategy, str, None] ) -> Optional[str]: if strategy is None: return None diff --git a/litellm/router_strategy/adaptive_router/hooks.py b/litellm/router_strategy/adaptive_router/hooks.py index 99fe5e26f7..9e346006ac 100644 --- a/litellm/router_strategy/adaptive_router/hooks.py +++ b/litellm/router_strategy/adaptive_router/hooks.py @@ -103,7 +103,7 @@ def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str def _recent_tool_results( - messages: Optional[List[Dict[str, Any]]], + messages: Optional[List[Dict[str, Any]]] ) -> List[Dict[str, Any]]: """Extract the current turn's tool result payloads from the request messages. diff --git a/litellm/router_strategy/adaptive_router/signals.py b/litellm/router_strategy/adaptive_router/signals.py index 5e33a64d27..a48bdea1eb 100644 --- a/litellm/router_strategy/adaptive_router/signals.py +++ b/litellm/router_strategy/adaptive_router/signals.py @@ -24,6 +24,7 @@ from litellm.router_strategy.adaptive_router.config import ( TOOL_CALL_HISTORY_MAX, ) + # ---- Public types --------------------------------------------------------- diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index da41577e99..be27b85247 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -10,11 +10,11 @@ This means you can use this with weighted-pick, lowest-latency, simple-shuffle, Example: ``` openai: - budget_limit: 0.000000000001 - time_period: 1d + budget_limit: 0.000000000001 + time_period: 1d anthropic: - budget_limit: 100 - time_period: 7d + budget_limit: 100 + time_period: 7d ``` """ diff --git a/litellm/router_utils/get_retry_from_policy.py b/litellm/router_utils/get_retry_from_policy.py index 162d6428f8..ec326ebb50 100644 --- a/litellm/router_utils/get_retry_from_policy.py +++ b/litellm/router_utils/get_retry_from_policy.py @@ -1,5 +1,5 @@ """ -Get num retries for an exception. +Get num retries for an exception. - Account for retry policy by exception type. """ diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index 48f85a8341..17b453d603 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -34,7 +34,7 @@ class PatternUtils: @staticmethod def sorted_patterns( - patterns: Dict[str, List[Dict]], + patterns: Dict[str, List[Dict]] ) -> List[Tuple[str, List[Dict]]]: """ Cached property for patterns sorted by specificity. diff --git a/litellm/router_utils/router_callbacks/track_deployment_metrics.py b/litellm/router_utils/router_callbacks/track_deployment_metrics.py index 9039b0df8e..1f226879d0 100644 --- a/litellm/router_utils/router_callbacks/track_deployment_metrics.py +++ b/litellm/router_utils/router_callbacks/track_deployment_metrics.py @@ -1,5 +1,5 @@ """ -Helper functions to get/set num success and num failures per deployment +Helper functions to get/set num success and num failures per deployment set_deployment_failures_for_current_minute diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py index 60d0a713ef..fbe951e649 100644 --- a/litellm/secret_managers/aws_secret_manager.py +++ b/litellm/secret_managers/aws_secret_manager.py @@ -4,7 +4,7 @@ This is a file for the AWS Secret Manager Integration Relevant issue: https://github.com/BerriAI/litellm/issues/1883 Requires: -* `os.environ["AWS_REGION_NAME"], +* `os.environ["AWS_REGION_NAME"], * `pip install boto3>=1.28.57` """ diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index 4461e34396..c1b4d019dc 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -10,7 +10,7 @@ Handles Async Operations for: Relevant issue: https://github.com/BerriAI/litellm/issues/1883 Requires: -* `os.environ["AWS_REGION_NAME"], +* `os.environ["AWS_REGION_NAME"], * `pip install boto3>=1.28.57` """ diff --git a/litellm/vector_store_files/utils.py b/litellm/vector_store_files/utils.py index 1ee5b47e30..ffe73516bd 100644 --- a/litellm/vector_store_files/utils.py +++ b/litellm/vector_store_files/utils.py @@ -21,7 +21,7 @@ class VectorStoreFileRequestUtils: @staticmethod def get_create_request_params( - params: Dict[str, Any], + params: Dict[str, Any] ) -> VectorStoreFileCreateRequest: filtered = VectorStoreFileRequestUtils._filter_params( params=params, model=VectorStoreFileCreateRequest @@ -37,7 +37,7 @@ class VectorStoreFileRequestUtils: @staticmethod def get_update_request_params( - params: Dict[str, Any], + params: Dict[str, Any] ) -> VectorStoreFileUpdateRequest: filtered = VectorStoreFileRequestUtils._filter_params( params=params, model=VectorStoreFileUpdateRequest diff --git a/pyproject.toml b/pyproject.toml index 9b48c841e9..7ff388f184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,11 +76,7 @@ utils = [ # Not in Docker or PyPI proxy extra. "numpydoc==1.8.0", ] -# diskcache intentionally unpinned: CVE-2025-69872 (pickle RCE) has no -# upstream fix. Stub kept so `pip install litellm[caching]` doesn't warn; -# DiskCache loads diskcache lazily and forces JSONDisk for safety. See -# litellm/caching/disk_cache.py. -caching = [] +caching = ["diskcache==5.6.3"] semantic-router = [ "semantic-router==0.1.12; python_version < '3.14'", "aurelio-sdk==0.0.19; python_version < '3.14'", @@ -130,7 +126,7 @@ litellm-proxy = "litellm.proxy.client.cli:cli" dev = [ "diff-cover==9.7.2", "flake8==7.3.0", - "black==26.3.1", + "black==24.10.0", "mypy==1.19.0", "pytest==9.0.3", "pytest-mock==3.15.1", diff --git a/scripts/adaptive_router_demo/eval.py b/scripts/adaptive_router_demo/eval.py index 66b91c7193..b02e4a37d3 100644 --- a/scripts/adaptive_router_demo/eval.py +++ b/scripts/adaptive_router_demo/eval.py @@ -35,7 +35,7 @@ import httpx class EvalCase: category: str prompt: str - ideal: str # criteria the judge checks the response against + ideal: str # criteria the judge checks the response against EVAL_CASES: List[EvalCase] = [ @@ -177,19 +177,14 @@ async def evaluate( async with httpx.AsyncClient() as client: for i, case in enumerate(EVAL_CASES, 1): print(f"\n[{i}/{len(EVAL_CASES)}] category={case.category}") - print( - f" prompt : {case.prompt[:80]}{'…' if len(case.prompt) > 80 else ''}" - ) + print(f" prompt : {case.prompt[:80]}{'…' if len(case.prompt) > 80 else ''}") session_id = f"eval-{uuid.uuid4()}" # Round 1: single-turn real request — get the actual LLM response to judge. try: response, chosen = await _chat( - client, - proxy_url, - api_key, - router, + client, proxy_url, api_key, router, [{"role": "user", "content": case.prompt}], session_id=session_id, ) @@ -199,25 +194,16 @@ async def evaluate( continue print(f" model : {chosen or router}") - print( - f" response : {response[:120].replace(chr(10), ' ')}{'…' if len(response) > 120 else ''}" - ) + print(f" response : {response[:120].replace(chr(10), ' ')}{'…' if len(response) > 120 else ''}") # Judge the real response. judge_msgs = [ {"role": "system", "content": JUDGE_SYSTEM}, - { - "role": "user", - "content": _judge_user(case.prompt, case.ideal, response), - }, + {"role": "user", "content": _judge_user(case.prompt, case.ideal, response)}, ] try: verdict, _ = await _chat( - client, - proxy_url, - api_key, - judge_model, - judge_msgs, + client, proxy_url, api_key, judge_model, judge_msgs, ) except Exception as exc: # noqa: BLE001 print(f" ERROR calling judge: {exc}", file=sys.stderr) @@ -241,19 +227,15 @@ async def evaluate( # On PASS → satisfaction follow-up (+alpha). On FAIL → neutral (no signal). follow_up = SATISFY_FOLLOWUP if is_pass else NEUTRAL_FOLLOWUP bandit_msgs = [ - {"role": "user", "content": case.prompt}, + {"role": "user", "content": case.prompt}, {"role": "assistant", "content": response}, - {"role": "user", "content": "ok continue"}, + {"role": "user", "content": "ok continue"}, {"role": "assistant", "content": FAB_ASSISTANT}, - {"role": "user", "content": follow_up}, + {"role": "user", "content": follow_up}, ] try: await _chat( - client, - proxy_url, - api_key, - router, - bandit_msgs, + client, proxy_url, api_key, router, bandit_msgs, session_id=session_id, ) except Exception as exc: # noqa: BLE001 @@ -275,17 +257,11 @@ async def evaluate( # Entry point # --------------------------------------------------------------------------- def main() -> None: - ap = argparse.ArgumentParser( - description="Evaluate the adaptive router with LLM-as-judge." - ) - ap.add_argument("--proxy-url", default="http://localhost:4000") - ap.add_argument("--api-key", required=True, help="proxy API key") - ap.add_argument( - "--router", default="smart-cheap-router", help="adaptive router model name" - ) - ap.add_argument( - "--judge-model", default="smart", help="model name for the judge (via proxy)" - ) + ap = argparse.ArgumentParser(description="Evaluate the adaptive router with LLM-as-judge.") + ap.add_argument("--proxy-url", default="http://localhost:4000") + ap.add_argument("--api-key", required=True, help="proxy API key") + ap.add_argument("--router", default="smart-cheap-router", help="adaptive router model name") + ap.add_argument("--judge-model", default="smart", help="model name for the judge (via proxy)") args = ap.parse_args() asyncio.run(evaluate(args.proxy_url, args.api_key, args.router, args.judge_model)) diff --git a/scripts/adaptive_router_demo/traffic.py b/scripts/adaptive_router_demo/traffic.py index dd64d95c90..eae5506eae 100644 --- a/scripts/adaptive_router_demo/traffic.py +++ b/scripts/adaptive_router_demo/traffic.py @@ -72,8 +72,8 @@ PROMPTS: Dict[str, List[str]] = { # so that signals attribute to the right (type, model) bandit cell. SATISFY: Dict[str, str] = { "code_generation": "thanks, that works! now write me a python function that does the inverse", - "factual_lookup": "perfect, thanks! who is the current prime minister?", - "writing": "great, thanks! now write a follow-up email confirming attendance", + "factual_lookup": "perfect, thanks! who is the current prime minister?", + "writing": "great, thanks! now write a follow-up email confirming attendance", } # Neutral follow-up — does not match any signal regex, does not move the bandit. @@ -83,8 +83,8 @@ NEUTRAL_FOLLOWUP = "ok, noted" # Defaults: smart dominates code/writing; both are fine for factual_lookup. ORACLE: Dict[str, Dict[str, float]] = { "code_generation": {"smart": 0.92, "fast": 0.35}, - "factual_lookup": {"smart": 0.90, "fast": 0.85}, - "writing": {"smart": 0.85, "fast": 0.55}, + "factual_lookup": {"smart": 0.90, "fast": 0.85}, + "writing": {"smart": 0.85, "fast": 0.55}, } # Fabricated assistant turn — content doesn't matter for the hook, only the role. @@ -94,11 +94,11 @@ FAB_ASSISTANT = "Got it. Working on that now." def _build_messages(prompt: str, last_user: str) -> List[Dict[str, str]]: """5-message conversation that passes the SIGNAL_GATE_MIN_MESSAGES=4 gate.""" return [ - {"role": "user", "content": prompt}, + {"role": "user", "content": prompt}, {"role": "assistant", "content": FAB_ASSISTANT}, - {"role": "user", "content": "ok continue"}, + {"role": "user", "content": "ok continue"}, {"role": "assistant", "content": FAB_ASSISTANT}, - {"role": "user", "content": last_user}, + {"role": "user", "content": last_user}, ] @@ -155,11 +155,7 @@ async def _drive_one_session( # # Round 1: neutral follow-up → no signal fires, but we learn the pick. ok, chosen = await _send( - client, - proxy_url, - api_key, - router, - session_id, + client, proxy_url, api_key, router, session_id, _build_messages(prompt, NEUTRAL_FOLLOWUP), mock_response=FAB_ASSISTANT, ) @@ -175,15 +171,10 @@ async def _drive_one_session( # follow-up matches satisfaction → +alpha for (request_type, chosen). history = _build_messages(prompt, NEUTRAL_FOLLOWUP) + [ {"role": "assistant", "content": FAB_ASSISTANT}, - {"role": "user", "content": follow_up}, + {"role": "user", "content": follow_up}, ] await _send( - client, - proxy_url, - api_key, - router, - session_id, - history, + client, proxy_url, api_key, router, session_id, history, mock_response=FAB_ASSISTANT, ) return chosen @@ -192,22 +183,13 @@ async def _drive_one_session( async def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--proxy-url", default="http://localhost:4000") - ap.add_argument( - "--api-key", required=True, help="proxy key with /v1/chat/completions perms" - ) - ap.add_argument("--router", default="smart-cheap-router") - ap.add_argument("--rounds", type=int, default=100) - ap.add_argument( - "--rate", - type=float, - default=0.5, - help="seconds between sessions; lower = faster", - ) - ap.add_argument( - "--types", - default="code_generation,factual_lookup,writing", - help="comma-separated subset of request types to drive", - ) + ap.add_argument("--api-key", required=True, help="proxy key with /v1/chat/completions perms") + ap.add_argument("--router", default="smart-cheap-router") + ap.add_argument("--rounds", type=int, default=100) + ap.add_argument("--rate", type=float, default=0.5, + help="seconds between sessions; lower = faster") + ap.add_argument("--types", default="code_generation,factual_lookup,writing", + help="comma-separated subset of request types to drive") args = ap.parse_args() types = [t.strip() for t in args.types.split(",") if t.strip() in PROMPTS] @@ -225,12 +207,7 @@ async def main() -> None: rt = random.choice(types) prompt = random.choice(PROMPTS[rt]) chosen = await _drive_one_session( - client, - args.proxy_url, - args.api_key, - args.router, - rt, - prompt, + client, args.proxy_url, args.api_key, args.router, rt, prompt, ) if chosen: counts[(rt, chosen)] = counts.get((rt, chosen), 0) + 1 diff --git a/scripts/benchmark_mock.py b/scripts/benchmark_mock.py index 6c3a6ad894..55dbb1d413 100644 --- a/scripts/benchmark_mock.py +++ b/scripts/benchmark_mock.py @@ -8,6 +8,7 @@ import statistics import aiohttp + REQUEST_BODY = { "model": "db-openai-endpoint", "messages": [{"role": "user", "content": "hi"}], diff --git a/scripts/benchmark_proxy_vs_provider.py b/scripts/benchmark_proxy_vs_provider.py index 8e1289e3c2..6196580b23 100755 --- a/scripts/benchmark_proxy_vs_provider.py +++ b/scripts/benchmark_proxy_vs_provider.py @@ -11,7 +11,7 @@ USAGE EXAMPLES: export PROVIDER_URL='https://api.openai.com/v1/chat/completions' export LITELLM_PROXY_API_KEY='sk-1234' export PROVIDER_API_KEY='sk-openai-key' - + # Run from scripts directory cd scripts python benchmark_proxy_vs_provider.py diff --git a/scripts/eval_compression.py b/scripts/eval_compression.py index 44265eb151..a169cc02d7 100644 --- a/scripts/eval_compression.py +++ b/scripts/eval_compression.py @@ -42,7 +42,8 @@ from litellm.types.utils import CallTypes PROBLEMS = [ { "id": "has_close_elements", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def has_close_elements(numbers: List[float], threshold: float) -> bool: @@ -53,8 +54,10 @@ PROBLEMS = [ >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) True \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True @@ -62,11 +65,13 @@ PROBLEMS = [ assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0], 2.0) == True assert has_close_elements([], 0.5) == False print("PASSED") - """), + """ + ), }, { "id": "separate_paren_groups", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def separate_paren_groups(paren_string: str) -> List[str]: @@ -77,18 +82,22 @@ PROBLEMS = [ >>> separate_paren_groups('( ) (( )) (( )( ))') ['()', '(())', '(()())'] \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert separate_paren_groups('(()()) ((())) () ((())()())') == ['(()())', '((()))', '()', '((())()())'] assert separate_paren_groups('() (()) ((())) (((())))') == ['()', '(())', '((()))', '(((())))'] assert separate_paren_groups('(()(()))') == ['(()(()))'] assert separate_paren_groups('( ) (( )) (( )( ))') == ['()', '(())', '(()())'] print("PASSED") - """), + """ + ), }, { "id": "truncate_number", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ def truncate_number(number: float) -> float: \"\"\"Given a positive floating point number, it can be decomposed into an integer part (largest integer smaller than given number) and decimals @@ -97,17 +106,21 @@ PROBLEMS = [ >>> truncate_number(3.5) 0.5 \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert truncate_number(3.5) == 0.5 assert abs(truncate_number(1.33) - 0.33) < 1e-6 assert abs(truncate_number(123.456) - 0.456) < 1e-6 print("PASSED") - """), + """ + ), }, { "id": "below_zero", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def below_zero(operations: List[int]) -> bool: @@ -119,8 +132,10 @@ PROBLEMS = [ >>> below_zero([1, 2, -4, 5]) True \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert below_zero([]) == False assert below_zero([1, 2, -3, 1, 2, -3]) == False assert below_zero([1, 2, -4, 5, 6]) == True @@ -128,11 +143,13 @@ PROBLEMS = [ assert below_zero([1, -1, 2, -2, 5, -5, 4, -5]) == True assert below_zero([1, -2]) == True print("PASSED") - """), + """ + ), }, { "id": "mean_absolute_deviation", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def mean_absolute_deviation(numbers: List[float]) -> float: @@ -144,17 +161,21 @@ PROBLEMS = [ >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) 1.0 \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6 assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0, 5.0]) - 1.2) < 1e-6 assert abs(mean_absolute_deviation([1.0, 1.0, 1.0, 1.0]) - 0.0) < 1e-6 print("PASSED") - """), + """ + ), }, { "id": "intersperse", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def intersperse(numbers: List[int], delimiter: int) -> List[int]: @@ -164,17 +185,21 @@ PROBLEMS = [ >>> intersperse([1, 2, 3], 4) [1, 4, 2, 4, 3] \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert intersperse([], 7) == [] assert intersperse([5, 6, 3, 2], 8) == [5, 8, 6, 8, 3, 8, 2] assert intersperse([2, 2, 2], 2) == [2, 2, 2, 2, 2] print("PASSED") - """), + """ + ), }, { "id": "parse_nested_parens", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def parse_nested_parens(paren_string: str) -> List[int]: @@ -184,17 +209,21 @@ PROBLEMS = [ >>> parse_nested_parens('(()()) ((())) () ((())())') [2, 3, 1, 3] \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert parse_nested_parens('(()()) ((())) () ((())())') == [2, 3, 1, 3] assert parse_nested_parens('() (()) ((())) (((())))') == [1, 2, 3, 4] assert parse_nested_parens('(()(())((())))') == [4] print("PASSED") - """), + """ + ), }, { "id": "filter_by_substring", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def filter_by_substring(strings: List[str], substring: str) -> List[str]: @@ -204,18 +233,22 @@ PROBLEMS = [ >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a') ['abc', 'bacd', 'array'] \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert filter_by_substring([], 'john') == [] assert filter_by_substring(['xxx', 'asd', 'xxy', 'john doe', 'xxxuj', 'xxx'], 'xxx') == ['xxx', 'xxxuj', 'xxx'] assert filter_by_substring(['xxx', 'asd', 'aaber', 'john doe', 'xxxuj', 'xxx'], 'xx') == ['xxx', 'xxxuj', 'xxx'] assert filter_by_substring(['grunt', 'hierarchial', 'abc', 'hierarchial'], 'hi') == ['hierarchial', 'hierarchial'] print("PASSED") - """), + """ + ), }, { "id": "sum_product", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List, Tuple def sum_product(numbers: List[int]) -> Tuple[int, int]: @@ -226,19 +259,23 @@ PROBLEMS = [ >>> sum_product([1, 2, 3, 4]) (10, 24) \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert sum_product([]) == (0, 1) assert sum_product([1, 1, 1]) == (3, 1) assert sum_product([100, 0]) == (100, 0) assert sum_product([3, 5, 7]) == (15, 105) assert sum_product([10]) == (10, 10) print("PASSED") - """), + """ + ), }, { "id": "max_element", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def max_element(l: List[int]) -> int: @@ -248,17 +285,21 @@ PROBLEMS = [ >>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10]) 123 \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert max_element([1, 2, 3]) == 3 assert max_element([5, 3, -5, 2, -3, 3, 9, 0, 124, 1, -10]) == 124 assert max_element([-1, -2, -3]) == -1 print("PASSED") - """), + """ + ), }, { "id": "fizz_buzz", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ def fizz_buzz(n: int) -> int: \"\"\"Return the number of times the digit 7 appears in integers less than n which are divisible by 11 or 13. >>> fizz_buzz(50) @@ -268,8 +309,10 @@ PROBLEMS = [ >>> fizz_buzz(79) 3 \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert fizz_buzz(50) == 0 assert fizz_buzz(78) == 2 assert fizz_buzz(79) == 3 @@ -277,11 +320,13 @@ PROBLEMS = [ assert fizz_buzz(200) == 6 assert fizz_buzz(4000) == 192 print("PASSED") - """), + """ + ), }, { "id": "sort_by_binary_len", - "prompt": textwrap.dedent("""\ + "prompt": textwrap.dedent( + """\ from typing import List def sort_array(arr: List[int]) -> List[int]: @@ -294,8 +339,10 @@ PROBLEMS = [ >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 4, 3] \"\"\" - """), - "tests": textwrap.dedent("""\ + """ + ), + "tests": textwrap.dedent( + """\ assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 4, 3, 5] assert sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2] assert sort_array([1, 0, 2, 3, 4]) == [0, 1, 2, 4, 3] @@ -303,7 +350,8 @@ PROBLEMS = [ assert sort_array([2, 5, 77, 4, 5, 3, 5, 7, 2, 3, 4]) == [2, 2, 4, 4, 3, 3, 5, 5, 5, 7, 77] assert sort_array([3, 6, 44, 12, 32, 5]) == [32, 3, 5, 6, 12, 44] print("PASSED") - """), + """ + ), }, ] @@ -312,7 +360,8 @@ PROBLEMS = [ # compressor to identify and drop them. DISTRACTOR_SNIPPETS = [ # distractor 0 — database connection pool - textwrap.dedent("""\ + textwrap.dedent( + """\ # db_pool.py import threading from contextlib import contextmanager @@ -359,9 +408,11 @@ DISTRACTOR_SNIPPETS = [ for conn in self._pool: conn.close() self._pool.clear() - """), + """ + ), # distractor 1 — HTTP retry logic - textwrap.dedent("""\ + textwrap.dedent( + """\ # http_retry.py import time import random @@ -405,9 +456,11 @@ DISTRACTOR_SNIPPETS = [ resp = requests.get(url, params=params, timeout=30) resp.raise_for_status() return resp.json() - """), + """ + ), # distractor 2 — LRU cache implementation - textwrap.dedent("""\ + textwrap.dedent( + """\ # lru_cache.py from collections import OrderedDict from threading import RLock @@ -456,9 +509,11 @@ DISTRACTOR_SNIPPETS = [ def __contains__(self, key): return key in self._cache - """), + """ + ), # distractor 3 — CSV report generator - textwrap.dedent("""\ + textwrap.dedent( + """\ # report_gen.py import csv import io @@ -511,9 +566,11 @@ DISTRACTOR_SNIPPETS = [ except (ValueError, KeyError): return False return self.filter_rows(in_range) - """), + """ + ), # distractor 4 — async task queue - textwrap.dedent("""\ + textwrap.dedent( + """\ # task_queue.py import asyncio import logging @@ -583,9 +640,11 @@ DISTRACTOR_SNIPPETS = [ async def shutdown(self): for w in self._workers: w.cancel() - """), + """ + ), # distractor 5 — config parser with env var interpolation - textwrap.dedent("""\ + textwrap.dedent( + """\ # config_parser.py import os import re @@ -648,7 +707,8 @@ DISTRACTOR_SNIPPETS = [ if val is None: raise ConfigError(f"Required config key missing: {key}") return val - """), + """ + ), ] diff --git a/scripts/health_check/benchmark_get_all_latest_health_checks.py b/scripts/health_check/benchmark_get_all_latest_health_checks.py index 5d0957d4f7..45845554c8 100644 --- a/scripts/health_check/benchmark_get_all_latest_health_checks.py +++ b/scripts/health_check/benchmark_get_all_latest_health_checks.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Bench LiteLLM_HealthCheckTable + PrismaClient +Bench LiteLLM_HealthCheckTable + PrismaClient - set DATABASE_URL to your Postgres - Run ```prisma generate``` to install prisma client before running test ) - This test writes to the default "public" database. Make sure to run cleanup after testing diff --git a/tests/batches_tests/test_batch_custom_pricing.py b/tests/batches_tests/test_batch_custom_pricing.py index 911fd8e311..f4e84b46be 100644 --- a/tests/batches_tests/test_batch_custom_pricing.py +++ b/tests/batches_tests/test_batch_custom_pricing.py @@ -18,6 +18,7 @@ from litellm.batches.batch_utils import ( from litellm.cost_calculator import batch_cost_calculator from litellm.types.utils import Usage + # --- helpers --- diff --git a/tests/batches_tests/test_hosted_vllm_batches_and_files.py b/tests/batches_tests/test_hosted_vllm_batches_and_files.py index b7f21f6cab..c7a25c71c5 100644 --- a/tests/batches_tests/test_hosted_vllm_batches_and_files.py +++ b/tests/batches_tests/test_hosted_vllm_batches_and_files.py @@ -20,6 +20,7 @@ sys.path.insert(0, os.path.abspath("../..")) import litellm + SERVER_URL = "https://exampleopenaiendpoint-production-0ee2.up.railway.app/v1" diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py index d5ade58980..123dad93e1 100644 --- a/tests/benchmarks/test_benchmarks.py +++ b/tests/benchmarks/test_benchmarks.py @@ -12,6 +12,7 @@ import litellm from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.litellm_core_utils.token_counter import token_counter + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- diff --git a/tests/code_coverage_tests/check_fastuuid_usage.py b/tests/code_coverage_tests/check_fastuuid_usage.py index bb3176b57a..e043345437 100644 --- a/tests/code_coverage_tests/check_fastuuid_usage.py +++ b/tests/code_coverage_tests/check_fastuuid_usage.py @@ -2,6 +2,7 @@ import ast import os from typing import List, Dict, Any + ALLOWED_FILE = os.path.normpath("litellm/_uuid.py") diff --git a/tests/enterprise/conftest.py b/tests/enterprise/conftest.py index 524ab85b93..0365bbbcfa 100644 --- a/tests/enterprise/conftest.py +++ b/tests/enterprise/conftest.py @@ -23,6 +23,8 @@ def event_loop(): loop.close() + + @pytest.fixture(scope="function", autouse=True) def setup_and_teardown(): """ diff --git a/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py b/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py index 6bdc808c4f..212c5d4a32 100644 --- a/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py +++ b/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py @@ -944,8 +944,8 @@ def test_callback_failure_metric_different_callbacks(prometheus_logger): async def test_langfuse_callback_failure_metric(prometheus_logger): """ Test that Langfuse callback failures are properly tracked in Prometheus metrics. - - This test verifies that when Langfuse logging fails, the + + This test verifies that when Langfuse logging fails, the litellm_callback_logging_failures_metric is incremented with callback_name="langfuse". """ from unittest.mock import MagicMock, patch @@ -957,20 +957,16 @@ async def test_langfuse_callback_failure_metric(prometheus_logger): # Get initial value initial_value = 0 try: - initial_value = ( - prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="langfuse" - )._value.get() - ) + initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( + callback_name="langfuse" + )._value.get() except Exception: initial_value = 0 - + # Create Langfuse logger with mocked initialization - with patch( - "litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init" - ): + with patch("litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"): langfuse_logger = LangfusePromptManagement() - + # Mock the log_event_on_langfuse to raise an exception with patch( "litellm.integrations.langfuse.langfuse_prompt_management.LangFuseHandler.get_langfuse_logger_for_request" @@ -978,16 +974,14 @@ async def test_langfuse_callback_failure_metric(prometheus_logger): mock_logger = MagicMock() mock_logger.log_event_on_langfuse.side_effect = Exception("Langfuse API error") mock_get_logger.return_value = mock_logger - + # Mock handle_callback_failure to track calls - with patch.object( - prometheus_logger, "increment_callback_logging_failure" - ) as mock_increment: + with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment: # Inject prometheus logger into the langfuse logger - langfuse_logger.handle_callback_failure = ( - lambda callback_name: mock_increment(callback_name=callback_name) + langfuse_logger.handle_callback_failure = lambda callback_name: mock_increment( + callback_name=callback_name ) - + # Call async_log_success_event - should catch exception and increment metric await langfuse_logger.async_log_success_event( kwargs={}, @@ -995,10 +989,10 @@ async def test_langfuse_callback_failure_metric(prometheus_logger): start_time=None, end_time=None, ) - + # Verify that increment was called with correct callback name mock_increment.assert_called_once_with(callback_name="langfuse") - + print("✓ Langfuse callback failure metric test passed") @@ -1006,8 +1000,8 @@ async def test_langfuse_callback_failure_metric(prometheus_logger): async def test_langfuse_otel_callback_failure_metric(prometheus_logger): """ Test that Langfuse OTEL callback failures are properly tracked in Prometheus metrics. - - This test verifies that when Langfuse OTEL logging fails, the + + This test verifies that when Langfuse OTEL logging fails, the litellm_callback_logging_failures_metric is incremented with callback_name="langfuse_otel". """ from unittest.mock import MagicMock, patch @@ -1017,58 +1011,50 @@ async def test_langfuse_otel_callback_failure_metric(prometheus_logger): # Get initial value initial_value = 0 try: - initial_value = ( - prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="langfuse_otel" - )._value.get() - ) + initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( + callback_name="langfuse_otel" + )._value.get() except Exception: initial_value = 0 - + # Create Langfuse OTEL logger with mocked initialization - with patch( - "litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None - ): + with patch("litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None): langfuse_otel_logger = LangfuseOtelLogger(callback_name="langfuse_otel") langfuse_otel_logger.callback_name = "langfuse_otel" - + # Mock handle_callback_failure to track calls - with patch.object( - prometheus_logger, "increment_callback_logging_failure" - ) as mock_increment: + with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment: # Inject prometheus logger into the langfuse otel logger - langfuse_otel_logger.handle_callback_failure = ( - lambda callback_name: mock_increment(callback_name=callback_name) + langfuse_otel_logger.handle_callback_failure = lambda callback_name: mock_increment( + callback_name=callback_name ) - + # Test that the OpenTelemetry base class set_attributes exception handler works # This is where langfuse_otel failures are caught and tracked - with patch.object( - langfuse_otel_logger, "set_attributes" - ) as mock_set_attributes: + with patch.object(langfuse_otel_logger, "set_attributes") as mock_set_attributes: # Simulate the exception handling in set_attributes def set_attributes_with_error(*args, **kwargs): # This simulates what happens in the real set_attributes method try: raise Exception("Attribute error") except Exception as e: - langfuse_otel_logger.handle_callback_failure( - callback_name=langfuse_otel_logger.callback_name - ) - + langfuse_otel_logger.handle_callback_failure(callback_name=langfuse_otel_logger.callback_name) + mock_set_attributes.side_effect = set_attributes_with_error - + # Call set_attributes try: langfuse_otel_logger.set_attributes( - span=MagicMock(), kwargs={}, response_obj={} + span=MagicMock(), + kwargs={}, + response_obj={} ) except Exception: pass - + # Verify that increment was called with correct callback name mock_increment.assert_called_with(callback_name="langfuse_otel") - + print("✓ Langfuse OTEL callback failure metric test passed") diff --git a/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py b/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py index ab5c557662..a45df5df00 100644 --- a/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py +++ b/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py @@ -49,13 +49,10 @@ async def test_enterprise_custom_auth_returns_string(): mock_user_auth = AsyncMock(return_value="sk-test-key") request = MagicMock(spec=Request) - with ( - patch( - "litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth", - mock_user_auth, - ), - patch("litellm.proxy.proxy_server.master_key", "sk-1234"), - patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), + with patch( + "litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth", mock_user_auth + ), patch("litellm.proxy.proxy_server.master_key", "sk-1234"), patch( + "litellm.proxy.proxy_server.prisma_client", MagicMock() ): # Verify the key is correctly handled in _user_api_key_auth_builder with patch( diff --git a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py index b7fd665c41..0d27df50d1 100644 --- a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py +++ b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py @@ -105,9 +105,7 @@ async def test_apply_guardrail_endpoint_with_presidio_guardrail(): mock_guardrail = Mock(spec=CustomGuardrail) # Simulate masking PII entities - returns GenericGuardrailAPIInputs (dict with texts key) mock_guardrail.apply_guardrail = AsyncMock( - return_value={ - "texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"] - } + return_value={"texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"]} ) # Configure the registry to return our mock guardrail diff --git a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py b/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py index 3f9c2486d9..9f4ca4ed10 100644 --- a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py +++ b/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py @@ -95,9 +95,9 @@ async def test_async_pre_call_deployment_hook_resolves_model_id_from_litellm_met kwargs=kwargs, call_type=CallTypes.acreate_batch ) - assert ( - result["input_file_id"] == provider_file_id - ), f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'" + assert result["input_file_id"] == provider_file_id, ( + f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'" + ) @pytest.mark.asyncio @@ -134,9 +134,9 @@ async def test_async_pre_call_deployment_hook_prefers_top_level_model_info(): kwargs=kwargs, call_type=CallTypes.acreate_batch ) - assert ( - result["input_file_id"] == top_level_provider_file - ), "Should prefer top-level model_info over litellm_metadata" + assert result["input_file_id"] == top_level_provider_file, ( + "Should prefer top-level model_info over litellm_metadata" + ) @pytest.mark.asyncio @@ -162,9 +162,9 @@ async def test_async_pre_call_deployment_hook_no_model_info_leaves_file_id_uncha kwargs=kwargs, call_type=CallTypes.acreate_batch ) - assert ( - result["input_file_id"] == managed_file_id - ), "File ID should remain unchanged when model_info is not available" + assert result["input_file_id"] == managed_file_id, ( + "File ID should remain unchanged when model_info is not available" + ) # def test_list_managed_files(): @@ -341,9 +341,7 @@ async def test_async_pre_call_hook_for_unified_finetuning_job(): @pytest.mark.asyncio -@pytest.mark.parametrize( - "call_type", ["afile_content", "afile_delete", "afile_retrieve"] -) +@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete", "afile_retrieve"]) async def test_can_user_call_unified_file_id(call_type): """ Test that on file retrieve, delete, and content we check if the user has access to the file @@ -603,7 +601,7 @@ async def test_error_file_id_for_failed_batch(): "litellm_model_name": "gpt-4o", "unified_batch_id": "litellm_proxy;model_id:test-model-id;llm_batch_id:batch_abc123", } - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=AsyncMock() ) @@ -622,11 +620,12 @@ async def test_error_file_id_for_failed_batch(): # Mock the afile_retrieve to simulate retrieving error file metadata with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_retrieve: mock_retrieve.return_value = error_file_object - + user_api_key_dict = UserAPIKeyAuth( - user_id="test-user-123", parent_otel_span=MagicMock() + user_id="test-user-123", + parent_otel_span=MagicMock() ) - + response = await proxy_managed_files.async_post_call_success_hook( data={}, user_api_key_dict=user_api_key_dict, @@ -637,9 +636,7 @@ async def test_error_file_id_for_failed_batch(): assert cast(LiteLLMBatch, response).error_file_id is not None assert not cast(LiteLLMBatch, response).error_file_id.startswith("error-") # Verify it's a base64 encoded managed file ID - assert _is_base64_encoded_unified_file_id( - cast(LiteLLMBatch, response).error_file_id - ) + assert _is_base64_encoded_unified_file_id(cast(LiteLLMBatch, response).error_file_id) @pytest.mark.asyncio @@ -653,7 +650,7 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation(): # Use AsyncMock instead of real database connection prisma_client = AsyncMock() - + batch = LiteLLMBatch( id="bGl0ZWxsbV9wcm94eTttb2RlbF9pZDoxMjM0NTY3OTtsbG1fYmF0Y2hfaWQ6YmF0Y2hfNjg1YzVlNWQ2Mzk4ODE5MGI4NWJkYjIxNDdiYTEzMWQ", completion_window="24h", @@ -681,10 +678,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation(): # first retrieve batch tasks = [] first_create_task = asyncio.create_task - with patch("asyncio.create_task") as mock_create_task: - mock_create_task.side_effect = ( - lambda coro: tasks.append(first_create_task(coro)) or tasks[-1] - ) + with patch('asyncio.create_task') as mock_create_task: + mock_create_task.side_effect = lambda coro: tasks.append(first_create_task(coro)) or tasks[-1] response = await proxy_managed_files.async_post_call_success_hook( data={}, @@ -705,10 +700,8 @@ async def test_async_post_call_success_hook_twice_assert_no_unique_violation(): # second retrieve batch tasks = [] second_create_task = asyncio.create_task - with patch("asyncio.create_task") as mock_create_task: - mock_create_task.side_effect = ( - lambda coro: tasks.append(second_create_task(coro)) or tasks[-1] - ) + with patch('asyncio.create_task') as mock_create_task: + mock_create_task.side_effect = lambda coro: tasks.append(second_create_task(coro)) or tasks[-1] await proxy_managed_files.async_post_call_success_hook( data={}, @@ -735,7 +728,7 @@ def test_update_responses_input_with_unified_file_id(): # Create a base64-encoded unified file ID # This decodes to: litellm_proxy:application/pdf;unified_id,6c0b5890-8914-48e0-b8f4-0ae5ed3c14a5;target_model_names,gpt-4o;llm_output_file_id,file-ECBPW7ML9g7XHdwGgUPZaM;llm_output_file_model_id,e26453f9e76e7993680d0068d98c1f4cc205bbad0967a33c664893568ca743c2 unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - + # Test input with unified file ID in content array input_data = [ { @@ -752,18 +745,15 @@ def test_update_responses_input_with_unified_file_id(): ], } ] - + # Update the input updated_input = update_responses_input_with_model_file_ids(input=input_data) - + # Verify the file_id was updated to the provider-specific file ID assert updated_input[0]["content"][0]["type"] == "input_file" assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM" assert updated_input[0]["content"][1]["type"] == "input_text" - assert ( - updated_input[0]["content"][1]["text"] - == "What is the first dragon in the book?" - ) + assert updated_input[0]["content"][1]["text"] == "What is the first dragon in the book?" def test_update_responses_input_with_regular_file_id(): @@ -777,7 +767,7 @@ def test_update_responses_input_with_regular_file_id(): # Regular OpenAI file ID (not a unified file ID) regular_file_id = "file-abc123xyz" - + input_data = [ { "role": "user", @@ -793,10 +783,10 @@ def test_update_responses_input_with_regular_file_id(): ], } ] - + # Update the input updated_input = update_responses_input_with_model_file_ids(input=input_data) - + # Verify the file_id was kept unchanged (regular OpenAI file ID) assert updated_input[0]["content"][0]["type"] == "input_file" assert updated_input[0]["content"][0]["file_id"] == regular_file_id @@ -810,11 +800,11 @@ def test_update_responses_input_with_string_input(): from litellm.litellm_core_utils.prompt_templates.common_utils import ( update_responses_input_with_model_file_ids, ) - + input_data = "What is AI?" - + updated_input = update_responses_input_with_model_file_ids(input=input_data) - + assert updated_input == input_data assert isinstance(updated_input, str) @@ -832,7 +822,7 @@ def test_update_responses_input_with_multiple_file_ids(): unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" # Regular OpenAI file ID regular_file_id = "file-regular123" - + input_data = [ { "role": "user", @@ -852,9 +842,9 @@ def test_update_responses_input_with_multiple_file_ids(): ], } ] - + updated_input = update_responses_input_with_model_file_ids(input=input_data) - + # Verify unified file ID was updated assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM" # Verify regular file ID was kept unchanged @@ -874,7 +864,7 @@ def test_update_responses_input_with_model_file_id_mapping(): # Managed file ID (unified) managed_file_id = "litellm_proxy_file_123" - + # Model file ID mapping model_file_id_mapping = { managed_file_id: { @@ -882,7 +872,7 @@ def test_update_responses_input_with_model_file_id_mapping(): "model_id_2": "azure_file_xyz", } } - + input_data = [ { "role": "user", @@ -898,24 +888,24 @@ def test_update_responses_input_with_model_file_id_mapping(): ], } ] - + # Update input with model_id_1 mapping updated_input = update_responses_input_with_model_file_ids( input=input_data, model_id="model_id_1", model_file_id_mapping=model_file_id_mapping, ) - + # Verify the file_id was mapped to the correct provider-specific file ID assert updated_input[0]["content"][0]["file_id"] == "openai_file_abc" - + # Test with different model_id updated_input_2 = update_responses_input_with_model_file_ids( input=input_data, model_id="model_id_2", model_file_id_mapping=model_file_id_mapping, ) - + assert updated_input_2[0]["content"][0]["file_id"] == "azure_file_xyz" @@ -923,7 +913,7 @@ def test_update_responses_tools_with_model_file_id_mapping(): """ Test that update_responses_tools_with_model_file_ids correctly maps file IDs in code_interpreter tools with container.file_ids. - + This is a regression test for the issue where managed file IDs in tools.container.file_ids were not being replaced with provider-specific file IDs, causing "string too long" errors from OpenAI. @@ -935,7 +925,7 @@ def test_update_responses_tools_with_model_file_id_mapping(): # Managed file IDs managed_file_id_1 = "litellm_proxy_file_123" managed_file_id_2 = "litellm_proxy_file_456" - + # Model file ID mapping model_file_id_mapping = { managed_file_id_1: { @@ -945,7 +935,7 @@ def test_update_responses_tools_with_model_file_id_mapping(): "model_id_1": "openai_file_def", }, } - + tools = [ { "type": "code_interpreter", @@ -955,20 +945,17 @@ def test_update_responses_tools_with_model_file_id_mapping(): }, } ] - + # Update tools with model mapping updated_tools = update_responses_tools_with_model_file_ids( tools=tools, model_id="model_id_1", model_file_id_mapping=model_file_id_mapping, ) - + # Verify the file IDs were mapped to provider-specific file IDs assert updated_tools[0]["type"] == "code_interpreter" - assert updated_tools[0]["container"]["file_ids"] == [ - "openai_file_abc", - "openai_file_def", - ] + assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", "openai_file_def"] def test_update_responses_tools_without_mapping(): @@ -981,7 +968,7 @@ def test_update_responses_tools_without_mapping(): ) regular_file_id = "file-abc123" - + tools = [ { "type": "code_interpreter", @@ -991,14 +978,14 @@ def test_update_responses_tools_without_mapping(): }, } ] - + # Update tools without mapping updated_tools = update_responses_tools_with_model_file_ids( tools=tools, model_id=None, model_file_id_mapping=None, ) - + # Verify the file ID was kept unchanged assert updated_tools[0]["container"]["file_ids"] == [regular_file_id] @@ -1014,13 +1001,13 @@ def test_update_responses_tools_with_mixed_file_ids(): managed_file_id = "litellm_proxy_file_123" regular_file_id = "file-abc123" - + model_file_id_mapping = { managed_file_id: { "model_id_1": "openai_file_abc", }, } - + tools = [ { "type": "code_interpreter", @@ -1030,19 +1017,16 @@ def test_update_responses_tools_with_mixed_file_ids(): }, } ] - + # Update tools updated_tools = update_responses_tools_with_model_file_ids( tools=tools, model_id="model_id_1", model_file_id_mapping=model_file_id_mapping, ) - + # Verify managed file ID was mapped and regular file ID was kept - assert updated_tools[0]["container"]["file_ids"] == [ - "openai_file_abc", - regular_file_id, - ] + assert updated_tools[0]["container"]["file_ids"] == ["openai_file_abc", regular_file_id] def test_get_file_ids_from_responses_tools(): @@ -1053,7 +1037,7 @@ def test_get_file_ids_from_responses_tools(): proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=MagicMock() ) - + tools = [ { "type": "code_interpreter", @@ -1063,9 +1047,9 @@ def test_get_file_ids_from_responses_tools(): }, } ] - + file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) - + assert file_ids == ["file-123", "file-456"] @@ -1076,7 +1060,7 @@ def test_get_file_ids_from_responses_tools_multiple_tools(): proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=MagicMock() ) - + tools = [ { "type": "code_interpreter", @@ -1096,9 +1080,9 @@ def test_get_file_ids_from_responses_tools_multiple_tools(): }, }, ] - + file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) - + # Should extract file IDs only from code_interpreter tools assert file_ids == ["file-123", "file-456", "file-789"] @@ -1110,15 +1094,15 @@ def test_get_file_ids_from_responses_tools_empty(): proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=MagicMock() ) - + # Test with None file_ids = proxy_managed_files.get_file_ids_from_responses_tools(None) assert file_ids == [] - + # Test with empty list file_ids = proxy_managed_files.get_file_ids_from_responses_tools([]) assert file_ids == [] - + # Test with tools without file_ids tools = [{"type": "file_search"}] file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) @@ -1135,30 +1119,30 @@ async def test_check_file_ids_access_with_unified_file_ids(): # Create a unified file ID unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" regular_file_id = "file-abc123" - + # Mock the access check to return True prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock can_user_call_unified_file_id to return True proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - + user_api_key_dict = UserAPIKeyAuth( user_id="test_user_123", parent_otel_span=MagicMock(), ) - + # Should not raise an exception for accessible files await proxy_managed_files.check_file_ids_access( [unified_file_id, regular_file_id], user_api_key_dict, ) - + # Verify can_user_call_unified_file_id was called for the unified file ID proxy_managed_files.can_user_call_unified_file_id.assert_called_once_with( unified_file_id, user_api_key_dict @@ -1171,32 +1155,32 @@ async def test_check_file_ids_access_denied(): Test that check_file_ids_access raises HTTPException when user doesn't have access. """ from litellm.proxy._types import UserAPIKeyAuth - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - + prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock can_user_call_unified_file_id to return False (access denied) proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=False) - + user_api_key_dict = UserAPIKeyAuth( user_id="test_user_123", parent_otel_span=MagicMock(), ) - + # Should raise HTTPException with 403 status code with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.check_file_ids_access( [unified_file_id], user_api_key_dict, ) - + assert exc_info.value.status_code == 403 assert "does not have access to the file" in exc_info.value.detail @@ -1207,32 +1191,32 @@ async def test_check_file_ids_access_with_regular_files_only(): Test that check_file_ids_access doesn't check access for regular (non-unified) file IDs. """ from litellm.proxy._types import UserAPIKeyAuth - + regular_file_id_1 = "file-abc123" regular_file_id_2 = "file-xyz789" - + prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock can_user_call_unified_file_id (should not be called for regular files) proxy_managed_files.can_user_call_unified_file_id = AsyncMock() - + user_api_key_dict = UserAPIKeyAuth( user_id="test_user_123", parent_otel_span=MagicMock(), ) - + # Should not raise exception and should not call can_user_call_unified_file_id await proxy_managed_files.check_file_ids_access( [regular_file_id_1, regular_file_id_2], user_api_key_dict, ) - + # Verify can_user_call_unified_file_id was NOT called proxy_managed_files.can_user_call_unified_file_id.assert_not_called() @@ -1243,31 +1227,31 @@ async def test_completion_with_file_access_check(): Test that completion call type checks file access before processing. """ from litellm.proxy._types import UserAPIKeyAuth - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - + prisma_client = AsyncMock() prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) - + internal_usage_cache = MagicMock() internal_usage_cache.async_get_cache = AsyncMock(return_value=None) - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock the get_model_file_id_mapping to return empty dict proxy_managed_files.get_model_file_id_mapping = AsyncMock(return_value={}) - + # Mock access check to allow access proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - + user_api_key_dict = UserAPIKeyAuth( user_id="test_user_123", parent_otel_span=MagicMock(), ) - + data = { "messages": [ { @@ -1283,7 +1267,7 @@ async def test_completion_with_file_access_check(): ], "model": "gpt-4", } - + # Should not raise exception result = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=user_api_key_dict, @@ -1291,7 +1275,7 @@ async def test_completion_with_file_access_check(): data=data, call_type="acompletion", ) - + # Verify access check was called proxy_managed_files.can_user_call_unified_file_id.assert_called_once() @@ -1302,32 +1286,32 @@ async def test_responses_with_file_access_check(): Test that responses API checks file access for files in both input and tools. """ from litellm.proxy._types import UserAPIKeyAuth - + unified_file_id_1 = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" unified_file_id_2 = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsNzc3Nzc3Nzc7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1YWVo7bGxtX291dHB1dF9maWxlX21vZGVsX2lkLG1vZGVsXzEyMw" - + prisma_client = AsyncMock() prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) - + internal_usage_cache = MagicMock() internal_usage_cache.async_get_cache = AsyncMock(return_value=None) - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock the get_model_file_id_mapping to return empty dict proxy_managed_files.get_model_file_id_mapping = AsyncMock(return_value={}) - + # Mock access check to allow access proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - + user_api_key_dict = UserAPIKeyAuth( user_id="test_user_123", parent_otel_span=MagicMock(), ) - + data = { "input": [ { @@ -1349,7 +1333,7 @@ async def test_responses_with_file_access_check(): ], "model": "gpt-4", } - + # Should not raise exception result = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=user_api_key_dict, @@ -1357,7 +1341,7 @@ async def test_responses_with_file_access_check(): data=data, call_type="aresponses", ) - + # Verify access check was called for both file IDs assert proxy_managed_files.can_user_call_unified_file_id.call_count == 2 @@ -1369,19 +1353,17 @@ async def test_store_unified_file_id_with_none_file_object(): (e.g., for batch output files that are stored before file metadata is available). """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - prisma_client.db.litellm_managedfiletable.create = AsyncMock( - return_value=MagicMock() - ) + prisma_client.db.litellm_managedfiletable.create = AsyncMock(return_value=MagicMock()) internal_usage_cache = MagicMock() internal_usage_cache.async_set_cache = AsyncMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Store with file_object=None (simulating batch output file storage) await proxy_managed_files.store_unified_file_id( file_id="test-unified-file-id", @@ -1390,7 +1372,7 @@ async def test_store_unified_file_id_with_none_file_object(): model_mappings={"model-123": "file-provider-xyz"}, user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), ) - + # Verify DB create was called with expected data (without file_object) prisma_client.db.litellm_managedfiletable.create.assert_called_once() call_args = prisma_client.db.litellm_managedfiletable.create.call_args @@ -1405,38 +1387,34 @@ async def test_afile_delete_returns_provider_response_when_stored_file_object_no stored file_object is None (e.g., for batch output files). """ from litellm.types.llms.openai import OpenAIFileObject - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsdGVzdC1pZDt0YXJnZXRfbW9kZWxfbmFtZXMsZ3B0LTRvO2xsbV9vdXRwdXRfZmlsZV9pZCxmaWxlLXByb3ZpZGVyLXh5ejtsbG1fb3V0cHV0X2ZpbGVfbW9kZWxfaWQsbW9kZWwtMTIz" - + prisma_client = AsyncMock() db_record = MagicMock() db_record.model_mappings = '{"model-123": "file-provider-xyz"}' - prisma_client.db.litellm_managedfiletable.find_first = AsyncMock( - return_value=db_record - ) + prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=db_record) prisma_client.db.litellm_managedfiletable.delete = AsyncMock() - + internal_usage_cache = MagicMock() - internal_usage_cache.async_get_cache = AsyncMock( - return_value={ - "unified_file_id": unified_file_id, - "model_mappings": {"model-123": "file-provider-xyz"}, - "flat_model_file_ids": ["file-provider-xyz"], - "file_object": None, - "created_by": "test-user", - "updated_by": "test-user", - } - ) + internal_usage_cache.async_get_cache = AsyncMock(return_value={ + "unified_file_id": unified_file_id, + "model_mappings": {"model-123": "file-provider-xyz"}, + "flat_model_file_ids": ["file-provider-xyz"], + "file_object": None, + "created_by": "test-user", + "updated_by": "test-user", + }) internal_usage_cache.async_set_cache = AsyncMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock the delete_unified_file_id to return None (simulating file_object=None) proxy_managed_files.delete_unified_file_id = AsyncMock(return_value=None) - + # Mock router response provider_delete_response = OpenAIFileObject( id="file-provider-xyz", @@ -1446,16 +1424,16 @@ async def test_afile_delete_returns_provider_response_when_stored_file_object_no filename="test.jsonl", purpose="batch", ) - + mock_router = MagicMock() mock_router.afile_delete = AsyncMock(return_value=provider_delete_response) - + result = await proxy_managed_files.afile_delete( file_id=unified_file_id, litellm_parent_otel_span=None, llm_router=mock_router, ) - + # Should return the provider response with the unified file ID assert result is not None assert result.id == unified_file_id @@ -1468,21 +1446,21 @@ async def test_afile_retrieve_fetches_from_provider_when_file_object_none(): file_object is None (e.g., for batch output files). """ from litellm.types.llms.openai import OpenAIFileObject - + prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock get_unified_file_id to return a stored object with file_object=None stored_file = MagicMock() stored_file.file_object = None stored_file.model_mappings = {"model-123": "file-provider-xyz"} proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - + # Mock the router and provider response provider_file_response = OpenAIFileObject( id="file-provider-xyz", @@ -1492,25 +1470,23 @@ async def test_afile_retrieve_fetches_from_provider_when_file_object_none(): filename="output.jsonl", purpose="batch_output", ) - + mock_router = MagicMock() - mock_router.get_deployment_credentials_with_provider = MagicMock( - return_value={ - "api_key": "test-key", - "api_base": "https://api.openai.com", - } - ) - + mock_router.get_deployment_credentials_with_provider = MagicMock(return_value={ + "api_key": "test-key", + "api_base": "https://api.openai.com", + }) + with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_afile_retrieve: mock_afile_retrieve.return_value = provider_file_response - + unified_file_id = "test-unified-file-id" result = await proxy_managed_files.afile_retrieve( file_id=unified_file_id, litellm_parent_otel_span=None, llm_router=mock_router, ) - + # Should return the provider response with the unified file ID assert result is not None assert result.id == unified_file_id @@ -1525,27 +1501,27 @@ async def test_afile_retrieve_raises_error_when_no_router_and_file_object_none() """ prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock get_unified_file_id to return a stored object with file_object=None stored_file = MagicMock() stored_file.file_object = None stored_file.model_mappings = {"model-123": "file-provider-xyz"} proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - + unified_file_id = "test-unified-file-id" - + with pytest.raises(Exception) as exc_info: await proxy_managed_files.afile_retrieve( file_id=unified_file_id, litellm_parent_otel_span=None, llm_router=None, ) - + assert "llm_router is required" in str(exc_info.value) @@ -1556,15 +1532,15 @@ async def test_afile_retrieve_returns_stored_file_object_when_exists(): (the normal case for user-uploaded files). """ from litellm.types.llms.openai import OpenAIFileObject - + prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock get_unified_file_id to return a stored object WITH file_object stored_file_object = OpenAIFileObject( id="test-unified-file-id", @@ -1577,13 +1553,13 @@ async def test_afile_retrieve_returns_stored_file_object_when_exists(): stored_file = MagicMock() stored_file.file_object = stored_file_object proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - + result = await proxy_managed_files.afile_retrieve( file_id="test-unified-file-id", litellm_parent_otel_span=None, llm_router=None, ) - + # Should return the stored file object directly assert result == stored_file_object @@ -1596,21 +1572,21 @@ async def test_afile_retrieve_raises_error_for_non_managed_file(): """ prisma_client = AsyncMock() internal_usage_cache = MagicMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( internal_usage_cache=internal_usage_cache, prisma_client=prisma_client, ) - + # Mock get_unified_file_id to return None (file not found) proxy_managed_files.get_unified_file_id = AsyncMock(return_value=None) - + with pytest.raises(Exception) as exc_info: await proxy_managed_files.afile_retrieve( file_id="non-existent-file-id", litellm_parent_otel_span=None, ) - + assert "not found" in str(exc_info.value) @@ -1621,58 +1597,54 @@ async def test_list_batches_from_managed_objects_table(): from litellm.proxy._types import UserAPIKeyAuth prisma_client = AsyncMock() - + batch_record_1 = MagicMock() batch_record_1.unified_object_id = "unified-batch-id-1" - batch_record_1.file_object = json.dumps( - { - "id": "batch_abc123", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-input-1", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - + batch_record_1.file_object = json.dumps({ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "completed", + "created_at": 1234567890, + "input_file_id": "file-input-1", + "request_counts": {"total": 1, "completed": 1, "failed": 0}, + }) + batch_record_2 = MagicMock() batch_record_2.unified_object_id = "unified-batch-id-2" - batch_record_2.file_object = json.dumps( - { - "id": "batch_xyz789", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "in_progress", - "created_at": 1234567891, - "input_file_id": "file-input-2", - "request_counts": {"total": 5, "completed": 2, "failed": 0}, - } - ) - + batch_record_2.file_object = json.dumps({ + "id": "batch_xyz789", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "in_progress", + "created_at": 1234567891, + "input_file_id": "file-input-2", + "request_counts": {"total": 5, "completed": 2, "failed": 0}, + }) + prisma_client.db.litellm_managedobjecttable.find_many.return_value = [ batch_record_1, batch_record_2, ] - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + result = await proxy_managed_files.list_user_batches( user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), limit=10, ) - + assert result["object"] == "list" assert len(result["data"]) == 2 assert result["data"][0].id == "unified-batch-id-1" assert result["data"][1].id == "unified-batch-id-2" assert result["first_id"] == "unified-batch-id-1" assert result["last_id"] == "unified-batch-id-2" - + # Should filter by user_id (created_by) prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( where={"file_purpose": "batch", "created_by": "test-user"}, @@ -1687,21 +1659,21 @@ async def test_list_batches_from_managed_objects_table_empty_list(): prisma_client = AsyncMock() prisma_client.db.litellm_managedobjecttable.find_many.return_value = [] - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + result = await proxy_managed_files.list_user_batches( user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), ) - + assert result["object"] == "list" assert len(result["data"]) == 0 assert result["first_id"] is None assert result["last_id"] is None assert result["has_more"] is False - + # Verify where clause includes created_by filter # Default take is 20 when no limit is provided prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( @@ -1713,7 +1685,6 @@ async def test_list_batches_from_managed_objects_table_empty_list(): def _create_unified_batch_id(model_id: str, batch_id: str) -> str: import base64 - unified_str = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}" return base64.urlsafe_b64encode(unified_str.encode()).decode().rstrip("=") @@ -1723,11 +1694,11 @@ async def test_list_batches_from_managed_objects_table_provider_filter_raises_ex from litellm.proxy._types import UserAPIKeyAuth prisma_client = AsyncMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # Filtering by provider should raise Exception with pytest.raises(Exception) as exc_info: await proxy_managed_files.list_user_batches( @@ -1735,11 +1706,11 @@ async def test_list_batches_from_managed_objects_table_provider_filter_raises_ex limit=10, provider="openai", ) - + assert str(exc_info.value) == ( "Filtering by 'provider' is not supported when using managed batches." ) - + # Verify find_many was NOT called since exception is raised before database query prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called() @@ -1749,7 +1720,7 @@ async def test_list_batches_from_managed_objects_table_target_model_name_filter_ from litellm.proxy._types import UserAPIKeyAuth prisma_client = AsyncMock() - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) @@ -1761,64 +1732,59 @@ async def test_list_batches_from_managed_objects_table_target_model_name_filter_ limit=10, target_model_names="gpt-4o,gpt-3.5", ) - + assert str(exc_info.value) == ( "Filtering by 'target_model_names' is not supported when using managed batches." ) - + # Verify find_many was NOT called since exception is raised before database query prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called() - @pytest.mark.asyncio async def test_list_batches_from_managed_objects_table_filters_by_created_by(): from litellm.proxy._types import UserAPIKeyAuth prisma_client = AsyncMock() - + # Create batch for user1 batch_user1 = MagicMock() batch_user1.unified_object_id = "unified-batch-user1" - batch_user1.file_object = json.dumps( - { - "id": "batch_user1_abc", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-input-user1", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - + batch_user1.file_object = json.dumps({ + "id": "batch_user1_abc", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "completed", + "created_at": 1234567890, + "input_file_id": "file-input-user1", + "request_counts": {"total": 1, "completed": 1, "failed": 0}, + }) + # Create batch for user2 batch_user2 = MagicMock() batch_user2.unified_object_id = "unified-batch-user2" - batch_user2.file_object = json.dumps( - { - "id": "batch_user2_xyz", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567891, - "input_file_id": "file-input-user2", - "request_counts": {"total": 2, "completed": 2, "failed": 0}, - } - ) - + batch_user2.file_object = json.dumps({ + "id": "batch_user2_xyz", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "completed", + "created_at": 1234567891, + "input_file_id": "file-input-user2", + "request_counts": {"total": 2, "completed": 2, "failed": 0}, + }) + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # Query with user1's API key - should only return user1's batch prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user1] result_user1 = await proxy_managed_files.list_user_batches( user_api_key_dict=UserAPIKeyAuth(user_id="user1"), limit=10, ) - + assert len(result_user1["data"]) == 1 assert result_user1["data"][0].id == "unified-batch-user1" prisma_client.db.litellm_managedobjecttable.find_many.assert_called_with( @@ -1826,14 +1792,14 @@ async def test_list_batches_from_managed_objects_table_filters_by_created_by(): take=10, order={"created_at": "desc"}, ) - + # Query with user2's API key - should only return user2's batch prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user2] result_user2 = await proxy_managed_files.list_user_batches( user_api_key_dict=UserAPIKeyAuth(user_id="user2"), limit=10, ) - + assert len(result_user2["data"]) == 1 assert result_user2["data"][0].id == "unified-batch-user2" prisma_client.db.litellm_managedobjecttable.find_many.assert_called_with( @@ -1856,7 +1822,7 @@ async def test_return_unified_file_id_includes_expires_at(): filename="test.jsonl", purpose="batch", status="uploaded", - expires_at=1234657890, + expires_at=1234657890, ) file_object._hidden_params = {"model_id": "test-model-id"} @@ -1896,27 +1862,25 @@ async def test_return_unified_file_id_includes_expires_at(): async def test_user_b_cannot_retrieve_user_a_batch(): """ Test that User B cannot retrieve a batch created by User A. - + This verifies batch isolation between users at the database/hook level. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator batch_record = MagicMock() batch_record.created_by = "user_a_id" prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # User B tries to retrieve User A's batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -1926,7 +1890,7 @@ async def test_user_b_cannot_retrieve_user_a_batch(): data={"batch_id": unified_batch_id}, call_type="aretrieve_batch", ) - + # Should raise 403 Permission Denied assert exc_info.value.status_code == 403 @@ -1937,23 +1901,21 @@ async def test_user_b_cannot_cancel_user_a_batch(): Test that User B cannot cancel a batch created by User A. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator batch_record = MagicMock() batch_record.created_by = "user_a_id" prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # User B tries to cancel User A's batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -1963,7 +1925,7 @@ async def test_user_b_cannot_cancel_user_a_batch(): data={"batch_id": unified_batch_id}, call_type="acancel_batch", ) - + # Should raise 403 Permission Denied assert exc_info.value.status_code == 403 @@ -1972,28 +1934,26 @@ async def test_user_b_cannot_cancel_user_a_batch(): async def test_user_a_can_retrieve_own_batch(): """ Test that User A can successfully retrieve their own batch. - + This is a positive test case to ensure permission checks don't block legitimate access. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator batch_record = MagicMock() batch_record.created_by = "user_a_id" prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # User A retrieves their own batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + # Should not raise an exception result = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -2003,7 +1963,7 @@ async def test_user_a_can_retrieve_own_batch(): data={"batch_id": unified_batch_id}, call_type="aretrieve_batch", ) - + # Should successfully return the decoded batch_id assert "batch_id" in result assert result["model"] == "my-model" @@ -2015,23 +1975,21 @@ async def test_user_b_cannot_retrieve_user_a_file(): Test that User B cannot retrieve a file created by User A. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator file_record = MagicMock() file_record.created_by = "user_a_id" prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( MagicMock(), prisma_client=prisma_client ) - + # User B tries to retrieve User A's file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -2041,7 +1999,7 @@ async def test_user_b_cannot_retrieve_user_a_file(): data={"file_id": unified_file_id}, call_type="afile_retrieve", ) - + # Should raise 403 Permission Denied assert exc_info.value.status_code == 403 @@ -2052,23 +2010,21 @@ async def test_user_b_cannot_download_user_a_file_content(): Test that User B cannot download file content for User A's file. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator file_record = MagicMock() file_record.created_by = "user_a_id" prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( MagicMock(), prisma_client=prisma_client ) - + # User B tries to download User A's file content - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -2078,7 +2034,7 @@ async def test_user_b_cannot_download_user_a_file_content(): data={"file_id": unified_file_id}, call_type="afile_content", ) - + # Should raise 403 Permission Denied assert exc_info.value.status_code == 403 @@ -2089,23 +2045,21 @@ async def test_user_b_cannot_delete_user_a_file(): Test that User B cannot delete a file created by User A. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator file_record = MagicMock() file_record.created_by = "user_a_id" prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( MagicMock(), prisma_client=prisma_client ) - + # User B tries to delete User A's file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + with pytest.raises(HTTPException) as exc_info: await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -2115,7 +2069,7 @@ async def test_user_b_cannot_delete_user_a_file(): data={"file_id": unified_file_id}, call_type="afile_delete", ) - + # Should raise 403 Permission Denied assert exc_info.value.status_code == 403 @@ -2124,38 +2078,34 @@ async def test_user_b_cannot_delete_user_a_file(): async def test_user_a_can_retrieve_own_file(): """ Test that User A can successfully retrieve their own file. - + Positive test case to ensure permission checks work correctly for the owner. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return User A as the creator file_record = MagicMock() file_record.created_by = "user_a_id" file_record.model_mappings = '{"model-123": "file-abc123"}' - file_record.file_object = json.dumps( - { - "id": "file-abc123", - "object": "file", - "bytes": 1234, - "created_at": 1234567890, - "filename": "test.jsonl", - "purpose": "batch", - } - ) + file_record.file_object = json.dumps({ + "id": "file-abc123", + "object": "file", + "bytes": 1234, + "created_at": 1234567890, + "filename": "test.jsonl", + "purpose": "batch", + }) prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( MagicMock(), prisma_client=prisma_client ) - + # User A retrieves their own file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + # Should not raise an exception result = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( @@ -2165,7 +2115,7 @@ async def test_user_a_can_retrieve_own_file(): data={"file_id": unified_file_id}, call_type="afile_retrieve", ) - + # Should successfully return the decoded file_id assert "file_id" in result @@ -2174,46 +2124,44 @@ async def test_user_a_can_retrieve_own_file(): async def test_list_batches_only_returns_user_own_batches(): """ Test that list_user_batches only returns batches created by the requesting user. - + This ensures users cannot see other users' batches in list operations. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Create batches for User A batch_user_a = MagicMock() batch_user_a.unified_object_id = "batch-user-a" - batch_user_a.file_object = json.dumps( - { - "id": "batch_a", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-a", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - + batch_user_a.file_object = json.dumps({ + "id": "batch_a", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "completed", + "created_at": 1234567890, + "input_file_id": "file-a", + "request_counts": {"total": 1, "completed": 1, "failed": 0}, + }) + # Mock database to only return User A's batches prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user_a] - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - + # User A requests their batches result = await proxy_managed_files.list_user_batches( user_api_key_dict=UserAPIKeyAuth(user_id="user_a_id"), limit=10, ) - + # Should only return User A's batches assert len(result["data"]) == 1 assert result["data"][0].id == "batch-user-a" - + # Verify the database query filtered by user_id prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( where={"file_purpose": "batch", "created_by": "user_a_id"}, @@ -2226,49 +2174,51 @@ async def test_list_batches_only_returns_user_own_batches(): async def test_same_user_different_keys_can_access_batch(): """ Test that different API keys for the same user can access the same batch. - + This verifies that permission checks are based on user_id, not API key, allowing users to have multiple keys that can all access their resources. """ from litellm.proxy._types import UserAPIKeyAuth - + prisma_client = AsyncMock() - + # Mock database to return the user_id as creator batch_record = MagicMock() batch_record.created_by = "user_a_id" prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - + proxy_managed_files = _PROXY_LiteLLMManagedFiles( DualCache(), prisma_client=prisma_client ) - - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - + + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + # First API key for User A retrieves the batch result1 = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", api_key="key-1", parent_otel_span=MagicMock() + user_id="user_a_id", + api_key="key-1", + parent_otel_span=MagicMock() ), cache=MagicMock(), data={"batch_id": unified_batch_id}, call_type="aretrieve_batch", ) - + assert "batch_id" in result1 - + # Second API key for the same User A retrieves the batch result2 = await proxy_managed_files.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", api_key="key-2", parent_otel_span=MagicMock() + user_id="user_a_id", + api_key="key-2", + parent_otel_span=MagicMock() ), cache=MagicMock(), data={"batch_id": unified_batch_id}, call_type="aretrieve_batch", ) - + assert "batch_id" in result2 # Both keys should get the same result assert result1["batch_id"] == result2["batch_id"] diff --git a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py index 699ab8d1c4..69c3b4cb59 100644 --- a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py +++ b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py @@ -31,16 +31,12 @@ class TestAvailableEnterpriseUsers: self, client, mock_user_api_key_auth ): """Test when max_users is set and user count is within limit""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - {"max_users": 10}, - ), + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.proxy_server.premium_user", + True, + ), patch( + "litellm.proxy.proxy_server.premium_user_data", + {"max_users": 10}, ): # Mock database count mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=5) @@ -70,16 +66,12 @@ class TestAvailableEnterpriseUsers: self, client, mock_user_api_key_auth ): """Test when max_users is not set (premium_user_data is None or doesn't contain max_users)""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - None, - ), + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.proxy_server.premium_user", + True, + ), patch( + "litellm.proxy.proxy_server.premium_user_data", + None, ): # Mock database count mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=3) @@ -107,16 +99,12 @@ class TestAvailableEnterpriseUsers: self, client, mock_user_api_key_auth ): """Test the current bug where total_users_remaining can be negative""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - {"key": "value"}, - ), + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.proxy_server.premium_user", + True, + ), patch( + "litellm.proxy.proxy_server.premium_user_data", + {"key": "value"}, ): # Mock database count higher than max_users to trigger the bug mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=8) @@ -152,15 +140,12 @@ class TestAvailableEnterpriseUsers: """Test when prisma_client is None (no database connection)""" from litellm.proxy._types import CommonProxyErrors - with ( - patch( - "litellm.proxy.proxy_server.prisma_client", - None, - ), - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), + with patch( + "litellm.proxy.proxy_server.prisma_client", + None, + ), patch( + "litellm.proxy.proxy_server.premium_user", + True, ): # Override the dependency client.app.dependency_overrides[mock_user_api_key_auth] = lambda: { diff --git a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py index 8e0a3655b5..52cb94ff34 100644 --- a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py +++ b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py @@ -801,9 +801,7 @@ async def test_list_projects_returns_timestamps(): from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - list_projects, - ) + from litellm_enterprise.proxy.management_endpoints.project_endpoints import list_projects from litellm.proxy._types import LiteLLM_ProjectTable now = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) diff --git a/tests/guardrails_tests/test_akto_guardrails.py b/tests/guardrails_tests/test_akto_guardrails.py index 488899d831..83421f1313 100644 --- a/tests/guardrails_tests/test_akto_guardrails.py +++ b/tests/guardrails_tests/test_akto_guardrails.py @@ -14,6 +14,7 @@ from litellm.proxy.guardrails.guardrail_registry import ( ) from litellm.proxy.guardrails.guardrail_hooks.akto.akto import AktoGuardrail + # --------------------------------------------------------------------------- # Registry tests # --------------------------------------------------------------------------- diff --git a/tests/guardrails_tests/test_custom_guardrail.py b/tests/guardrails_tests/test_custom_guardrail.py index 95ac82b8d0..af1270756f 100644 --- a/tests/guardrails_tests/test_custom_guardrail.py +++ b/tests/guardrails_tests/test_custom_guardrail.py @@ -6,6 +6,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/guardrails_tests/test_eu_ai_act_article5.py b/tests/guardrails_tests/test_eu_ai_act_article5.py index d602b206b5..bda7bf6f51 100644 --- a/tests/guardrails_tests/test_eu_ai_act_article5.py +++ b/tests/guardrails_tests/test_eu_ai_act_article5.py @@ -21,6 +21,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor ContentFilterCategoryConfig, ) + # Test cases: (sentence, expected_result, reason) TEST_CASES = [ # ALWAYS BLOCK - Explicit prohibited practices (1-10) diff --git a/tests/guardrails_tests/test_sg_mas_ai_guardrails.py b/tests/guardrails_tests/test_sg_mas_ai_guardrails.py index 47bfd1b8a1..668ee70469 100644 --- a/tests/guardrails_tests/test_sg_mas_ai_guardrails.py +++ b/tests/guardrails_tests/test_sg_mas_ai_guardrails.py @@ -23,6 +23,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor ContentFilterCategoryConfig, ) + # ── helpers ────────────────────────────────────────────────────────────── POLICY_DIR = os.path.abspath( diff --git a/tests/guardrails_tests/test_sg_pdpa_guardrails.py b/tests/guardrails_tests/test_sg_pdpa_guardrails.py index 03f55777f2..fd7133bc74 100644 --- a/tests/guardrails_tests/test_sg_pdpa_guardrails.py +++ b/tests/guardrails_tests/test_sg_pdpa_guardrails.py @@ -28,6 +28,7 @@ from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter impor ContentFilterCategoryConfig, ) + # ── helpers ────────────────────────────────────────────────────────────── POLICY_DIR = os.path.abspath( diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index bd39206fd0..5152e3e012 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -7,6 +7,7 @@ import sys import traceback from unittest.mock import AsyncMock, MagicMock, patch + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path diff --git a/tests/image_gen_tests/test_image_variation.py b/tests/image_gen_tests/test_image_variation.py index ef528c391e..301835057a 100644 --- a/tests/image_gen_tests/test_image_variation.py +++ b/tests/image_gen_tests/test_image_variation.py @@ -6,6 +6,7 @@ import os import sys import traceback + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path diff --git a/tests/litellm/llms/bedrock/embed/test_embedding.py b/tests/litellm/llms/bedrock/embed/test_embedding.py index 163c143d82..261448842f 100644 --- a/tests/litellm/llms/bedrock/embed/test_embedding.py +++ b/tests/litellm/llms/bedrock/embed/test_embedding.py @@ -11,6 +11,7 @@ import pytest from litellm.types.utils import Embedding from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage + _mock_model_id = ( "arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123" ) diff --git a/tests/litellm/llms/bedrock/test_nova_imported_models.py b/tests/litellm/llms/bedrock/test_nova_imported_models.py index b5fd2630cb..e3677aaf9e 100644 --- a/tests/litellm/llms/bedrock/test_nova_imported_models.py +++ b/tests/litellm/llms/bedrock/test_nova_imported_models.py @@ -11,6 +11,7 @@ from litellm.llms.bedrock.common_utils import ( ) from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig + NOVA_ARN = "arn:aws:bedrock:us-east-1:123456789012:custom-model-deployment/a1b2c3d4e5f6" NOVA_MODEL = f"bedrock/nova/{NOVA_ARN}" NOVA2_MODEL = f"bedrock/nova-2/{NOVA_ARN}" diff --git a/tests/litellm/llms/vertex_ai/test_gemini_batch_embeddings.py b/tests/litellm/llms/vertex_ai/test_gemini_batch_embeddings.py index fde468618e..54ea41a645 100644 --- a/tests/litellm/llms/vertex_ai/test_gemini_batch_embeddings.py +++ b/tests/litellm/llms/vertex_ai/test_gemini_batch_embeddings.py @@ -581,21 +581,17 @@ def test_vertex_ai_text_only_embedding_uses_embed_content(): def test_filter_embed_params_drops_unsupported(): """Unsupported params like max_tokens should be filtered out.""" - result = _filter_embed_params( - {"dimensions": 768, "max_tokens": 256, "temperature": 0.5} - ) + result = _filter_embed_params({"dimensions": 768, "max_tokens": 256, "temperature": 0.5}) assert result == {"outputDimensionality": 768} def test_filter_embed_params_keeps_supported(): """All supported Gemini embedding params should pass through.""" - result = _filter_embed_params( - { - "dimensions": 768, - "task_type": "RETRIEVAL_DOCUMENT", - "title": "My doc", - } - ) + result = _filter_embed_params({ + "dimensions": 768, + "task_type": "RETRIEVAL_DOCUMENT", + "title": "My doc", + }) assert result == { "outputDimensionality": 768, "taskType": "RETRIEVAL_DOCUMENT", diff --git a/tests/litellm/test_bedrock_nemotron_super.py b/tests/litellm/test_bedrock_nemotron_super.py index 7b4efd86b2..8b081f10d1 100644 --- a/tests/litellm/test_bedrock_nemotron_super.py +++ b/tests/litellm/test_bedrock_nemotron_super.py @@ -11,6 +11,7 @@ import pytest from litellm import get_model_info + MODEL_NAME = "nvidia.nemotron-super-3-120b" diff --git a/tests/litellm_core_utils/test_bedrock_converse_dedup_factory.py b/tests/litellm_core_utils/test_bedrock_converse_dedup_factory.py index 87af6a35f4..c32917efe8 100644 --- a/tests/litellm_core_utils/test_bedrock_converse_dedup_factory.py +++ b/tests/litellm_core_utils/test_bedrock_converse_dedup_factory.py @@ -12,6 +12,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import ( BedrockConverseMessagesProcessor, ) + MODEL = "anthropic.claude-v2" PROVIDER = "bedrock_converse" @@ -545,9 +546,9 @@ def test_bedrock_converse_sorts_text_before_tooluse_sync(): tool_indices = [i for i, b in enumerate(content) if "toolUse" in b] # All text blocks must come before all toolUse blocks - assert max(text_indices) < min( - tool_indices - ), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}" + assert max(text_indices) < min(tool_indices), ( + f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}" + ) @pytest.mark.asyncio @@ -566,9 +567,9 @@ async def test_bedrock_converse_sorts_text_before_tooluse_async(): text_indices = [i for i, b in enumerate(content) if "text" in b] tool_indices = [i for i, b in enumerate(content) if "toolUse" in b] - assert max(text_indices) < min( - tool_indices - ), f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}" + assert max(text_indices) < min(tool_indices), ( + f"text blocks at {text_indices} should all precede toolUse blocks at {tool_indices}" + ) @pytest.mark.asyncio @@ -576,9 +577,7 @@ async def test_bedrock_converse_content_ordering_sync_async_parity(): """Sync and async paths should produce identical content block ordering.""" messages = _make_tooluse_before_text_messages() sync_result = _bedrock_converse_messages_pt(messages, MODEL, PROVIDER) - async_result = ( - await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async( - messages, MODEL, PROVIDER - ) + async_result = await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async( + messages, MODEL, PROVIDER ) assert sync_result == async_result diff --git a/tests/litellm_utils_tests/test_aws_secret_manager.py b/tests/litellm_utils_tests/test_aws_secret_manager.py index 00417ea1b2..674f9b3ca8 100644 --- a/tests/litellm_utils_tests/test_aws_secret_manager.py +++ b/tests/litellm_utils_tests/test_aws_secret_manager.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv import litellm.types import litellm.types.utils + load_dotenv() import io diff --git a/tests/litellm_utils_tests/test_litellm_overhead.py b/tests/litellm_utils_tests/test_litellm_overhead.py index d5aa7eb671..3a428e9d58 100644 --- a/tests/litellm_utils_tests/test_litellm_overhead.py +++ b/tests/litellm_utils_tests/test_litellm_overhead.py @@ -14,6 +14,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm + # Fake Vertex AI Gemini response for mocking FAKE_VERTEX_GEMINI_RESPONSE = { "candidates": [ diff --git a/tests/llm_responses_api_testing/test_anthropic_responses_api.py b/tests/llm_responses_api_testing/test_anthropic_responses_api.py index 746d54b4d4..6537f67acb 100644 --- a/tests/llm_responses_api_testing/test_anthropic_responses_api.py +++ b/tests/llm_responses_api_testing/test_anthropic_responses_api.py @@ -12,6 +12,7 @@ from litellm.responses.litellm_completion_transformation.transformation import ( ) from litellm.types.utils import ModelResponse + sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm.integrations.custom_logger import CustomLogger diff --git a/tests/llm_responses_api_testing/test_anthropic_tool_result_empty_call_id.py b/tests/llm_responses_api_testing/test_anthropic_tool_result_empty_call_id.py index 3d256794cb..08b1c1784e 100644 --- a/tests/llm_responses_api_testing/test_anthropic_tool_result_empty_call_id.py +++ b/tests/llm_responses_api_testing/test_anthropic_tool_result_empty_call_id.py @@ -2,7 +2,7 @@ Test to reproduce and verify fix for Anthropic tool_result issue with empty call_id. This test reproduces the exact error: -"messages.0.content.0: unexpected `tool_use_id` found in `tool_result` blocks: tool_use_id. +"messages.0.content.0: unexpected `tool_use_id` found in `tool_result` blocks: tool_use_id. Each `tool_result` block must have a corresponding `tool_use` block in the previous message." The issue occurs when: diff --git a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py index 8ce234dbcb..8f5278698b 100644 --- a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py +++ b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py @@ -2,12 +2,12 @@ Unit tests for BaseResponsesAPIStreamingIterator Tests core functionality including: -1. Processing chunks and handling ResponseCompletedEvent +1. Processing chunks and handling ResponseCompletedEvent 2. Ensuring _update_responses_api_response_id_with_model_id is called for final chunk 3. Verifying ID update is NOT called for non-final chunks (delta events) 4. Edge case handling for invalid JSON, empty chunks, and [DONE] markers -These tests ensure the streaming iterator correctly processes response chunks +These tests ensure the streaming iterator correctly processes response chunks and applies model ID updates only to completed responses, as required for proper response tracking and logging. """ diff --git a/tests/llm_responses_api_testing/test_responses_hooks.py b/tests/llm_responses_api_testing/test_responses_hooks.py index 09ac9c7a9e..3799a0b912 100644 --- a/tests/llm_responses_api_testing/test_responses_hooks.py +++ b/tests/llm_responses_api_testing/test_responses_hooks.py @@ -387,7 +387,9 @@ def test_process_chunk_completed_response_updates_id_and_usage_cost(monkeypatch) # Chunk must include a top-level "response" key so BaseResponsesAPIStreamingIterator # runs _update_responses_api_response_id_with_model_id (see streaming_iterator.py). event = iterator._process_chunk( - json.dumps({"type": "response.completed", "response": {"id": "resp_live"}}) + json.dumps( + {"type": "response.completed", "response": {"id": "resp_live"}} + ) ) finally: litellm.include_cost_in_streaming_usage = original_include_cost diff --git a/tests/llm_translation/test_bedrock_anthropic_regression.py b/tests/llm_translation/test_bedrock_anthropic_regression.py index d81e76cd4b..8b8ce0a6cc 100644 --- a/tests/llm_translation/test_bedrock_anthropic_regression.py +++ b/tests/llm_translation/test_bedrock_anthropic_regression.py @@ -22,8 +22,10 @@ sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm import completion + # Large document for caching tests (needs 1024+ tokens for Claude models) -LARGE_DOCUMENT_FOR_CACHING = """ +LARGE_DOCUMENT_FOR_CACHING = ( + """ This is a comprehensive legal agreement between Party A and Party B. ARTICLE 1: DEFINITIONS @@ -75,7 +77,9 @@ ARTICLE 9: GENERAL PROVISIONS 9.5 Waiver of any provision shall not constitute ongoing waiver. IN WITNESS WHEREOF, the parties have executed this Agreement. -""" * 8 # Repeat to ensure we have enough tokens (need 1024+ for Claude models) +""" + * 8 +) # Repeat to ensure we have enough tokens (need 1024+ for Claude models) class TestBedrockAnthropicPromptCachingRegression: diff --git a/tests/llm_translation/test_bedrock_invoke_tests.py b/tests/llm_translation/test_bedrock_invoke_tests.py index 0258ab15c6..23f436d5b2 100644 --- a/tests/llm_translation/test_bedrock_invoke_tests.py +++ b/tests/llm_translation/test_bedrock_invoke_tests.py @@ -3,6 +3,7 @@ import pytest import sys import os + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path diff --git a/tests/llm_translation/test_crusoe.py b/tests/llm_translation/test_crusoe.py index 158152d8ae..56aa4e4cd4 100644 --- a/tests/llm_translation/test_crusoe.py +++ b/tests/llm_translation/test_crusoe.py @@ -1,7 +1,6 @@ """ Tests for Crusoe provider integration """ - import os from unittest import mock @@ -97,9 +96,9 @@ def test_crusoe_models_configuration(): for model in crusoe_models: model_info = get_model_info(model) assert model_info is not None, f"Model info not found for {model}" - assert ( - model_info.get("litellm_provider") == "crusoe" - ), f"{model} should have crusoe as provider" + assert model_info.get("litellm_provider") == "crusoe", ( + f"{model} should have crusoe as provider" + ) assert model_info.get("mode") == "chat", f"{model} should be in chat mode" finally: litellm.model_cost = original_model_cost diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py index 3e322e1a6b..97b0aaee86 100644 --- a/tests/llm_translation/test_gemini.py +++ b/tests/llm_translation/test_gemini.py @@ -1362,12 +1362,8 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults(): ) # For Gemini 3, should not force thinkingLevel by default - assert ( - "thinkingLevel" not in result - ), "Should not force thinkingLevel for Gemini 3" - assert ( - "thinkingBudget" not in result - ), "Should NOT have thinkingBudget for Gemini 3" + assert "thinkingLevel" not in result, "Should not force thinkingLevel for Gemini 3" + assert "thinkingBudget" not in result, "Should NOT have thinkingBudget for Gemini 3" assert result["includeThoughts"] is True # Test 2: Anthropic thinking disabled for Gemini 3 @@ -1399,10 +1395,7 @@ def test_anthropic_thinking_param_to_gemini_3_provider_defaults(): ) assert result_zero["includeThoughts"] is False - assert ( - "thinkingLevel" not in result_zero - or result_zero.get("thinkingLevel") is None - ) + assert "thinkingLevel" not in result_zero or result_zero.get("thinkingLevel") is None # Test 4: Gemini 3 flash-preview should also follow provider defaults by default result_gemini3flashpreview = VertexGeminiConfig._map_thinking_param( @@ -1532,12 +1525,8 @@ def test_anthropic_thinking_param_via_map_openai_params(): # Check that thinkingConfig was created without forced thinkingLevel assert "thinkingConfig" in result, "Should have thinkingConfig in optional_params" thinking_config = result["thinkingConfig"] - assert ( - "thinkingLevel" not in thinking_config - ), "Should not force thinkingLevel for Gemini 3 by default" - assert ( - "thinkingBudget" not in thinking_config - ), "Should NOT have thinkingBudget for Gemini 3" + assert "thinkingLevel" not in thinking_config, "Should not force thinkingLevel for Gemini 3 by default" + assert "thinkingBudget" not in thinking_config, "Should NOT have thinkingBudget for Gemini 3" assert thinking_config["includeThoughts"] is True # Test with Gemini 2 model diff --git a/tests/llm_translation/test_gemini_image_usage.py b/tests/llm_translation/test_gemini_image_usage.py index 75ce8c8146..096f9c4796 100644 --- a/tests/llm_translation/test_gemini_image_usage.py +++ b/tests/llm_translation/test_gemini_image_usage.py @@ -1,7 +1,7 @@ """ Test for Gemini image generation usage metadata extraction. -This test verifies the fix for issue #18323 where image_generation() +This test verifies the fix for issue #18323 where image_generation() was returning usage=0 while completion() returned proper token usage. """ diff --git a/tests/local_testing/create_mock_standard_logging_payload.py b/tests/local_testing/create_mock_standard_logging_payload.py index 828faba0b0..2fd6a4ffa8 100644 --- a/tests/local_testing/create_mock_standard_logging_payload.py +++ b/tests/local_testing/create_mock_standard_logging_payload.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 8ca13369ff..9782bf3c2a 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -38,6 +38,7 @@ from litellm.llms.vertex_ai.gemini.transformation import ( ) from litellm.llms.vertex_ai.vertex_llm_base import VertexBase + litellm.num_retries = 3 litellm.cache = None user_message = "Write a short poem about the sky" @@ -1103,7 +1104,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs): { "content": { "role": "model", - "parts": [{"text": """{ + "parts": [ + { + "text": """{ "recipes": [ {"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, @@ -1111,7 +1114,9 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs): {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"} ] - }"""}], + }""" + } + ], }, "finishReason": "STOP", "safetyRatings": [ diff --git a/tests/local_testing/test_cache_preset_key.py b/tests/local_testing/test_cache_preset_key.py index cf496fe4b8..d6518c5a07 100644 --- a/tests/local_testing/test_cache_preset_key.py +++ b/tests/local_testing/test_cache_preset_key.py @@ -4,7 +4,7 @@ Test for preset_cache_key multiple values bug fix. This test verifies that get_cache_key doesn't raise TypeError when kwargs already contains preset_cache_key. -Issue: When get_cache_key(**kwargs) is called with kwargs containing +Issue: When get_cache_key(**kwargs) is called with kwargs containing preset_cache_key, the call to _set_preset_cache_key_in_kwargs() would fail with: TypeError: got multiple values for keyword argument 'preset_cache_key' """ diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index a68e8c9158..0c7c015765 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -450,9 +450,12 @@ def c(): litellm.enable_caching_on_provider_specific_optional_params = False -embedding_large_text = """ +embedding_large_text = ( + """ small text -""" * 5 +""" + * 5 +) # # test_caching_with_models() diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index df1b6d7a8e..6341fa7800 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1638,13 +1638,15 @@ def custom_callback( ################################################# - print(f""" + print( + f""" Model: {model}, Messages: {messages}, User: {user}, Seed: {kwargs["seed"]}, temperature: {kwargs["temperature"]}, - """) + """ + ) assert kwargs["user"] == "ishaans app" assert kwargs["model"] == "gpt-3.5-turbo-1106" diff --git a/tests/local_testing/test_configs/custom_callbacks.py b/tests/local_testing/test_configs/custom_callbacks.py index 7ef6c1aaad..42f88b5d19 100644 --- a/tests/local_testing/test_configs/custom_callbacks.py +++ b/tests/local_testing/test_configs/custom_callbacks.py @@ -93,7 +93,8 @@ class testCustomCallbackProxy(CustomLogger): print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger)) - print(f""" + print( + f""" Model: {model}, Messages: {messages}, User: {user}, @@ -101,7 +102,8 @@ class testCustomCallbackProxy(CustomLogger): Cost: {cost}, Response: {response} Proxy Metadata: {metadata} - """) + """ + ) return async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index 010a071f73..14b9e8cd13 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -477,3 +477,4 @@ def test_get_llm_provider_use_proxy_arg_true_with_direct_args(): assert provider == "litellm_proxy" assert key == arg_api_key # Should use the argument key assert base == arg_api_base # Should use the argument base + diff --git a/tests/local_testing/test_langsmith.py b/tests/local_testing/test_langsmith.py index 6f09b2701b..af7ac46a1c 100644 --- a/tests/local_testing/test_langsmith.py +++ b/tests/local_testing/test_langsmith.py @@ -21,6 +21,7 @@ verbose_logger.setLevel(logging.DEBUG) litellm.set_verbose = True import time + # test_langsmith_logging() diff --git a/tests/local_testing/test_router_caching.py b/tests/local_testing/test_router_caching.py index ad358699a4..cb223b661b 100644 --- a/tests/local_testing/test_router_caching.py +++ b/tests/local_testing/test_router_caching.py @@ -16,6 +16,7 @@ import litellm from litellm import Router from litellm.caching import RedisCache, RedisClusterCache + ## Scenarios ## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo", ## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group diff --git a/tests/logging_callback_tests/create_mock_standard_logging_payload.py b/tests/logging_callback_tests/create_mock_standard_logging_payload.py index 828faba0b0..2fd6a4ffa8 100644 --- a/tests/logging_callback_tests/create_mock_standard_logging_payload.py +++ b/tests/logging_callback_tests/create_mock_standard_logging_payload.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py index ce2b05665c..3e8d59b299 100644 --- a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py +++ b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_datadog_llm_obs.py b/tests/logging_callback_tests/test_datadog_llm_obs.py index b192288dae..74f642e6fa 100644 --- a/tests/logging_callback_tests/test_datadog_llm_obs.py +++ b/tests/logging_callback_tests/test_datadog_llm_obs.py @@ -6,6 +6,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_gcs_pub_sub.py b/tests/logging_callback_tests/test_gcs_pub_sub.py index 6c99e724af..f4cc973517 100644 --- a/tests/logging_callback_tests/test_gcs_pub_sub.py +++ b/tests/logging_callback_tests/test_gcs_pub_sub.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_generic_api_callback.py b/tests/logging_callback_tests/test_generic_api_callback.py index a2fdf0b2b0..6984b6fa00 100644 --- a/tests/logging_callback_tests/test_generic_api_callback.py +++ b/tests/logging_callback_tests/test_generic_api_callback.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_langsmith_unit_test.py b/tests/logging_callback_tests/test_langsmith_unit_test.py index d408c1caef..155b1f396f 100644 --- a/tests/logging_callback_tests/test_langsmith_unit_test.py +++ b/tests/logging_callback_tests/test_langsmith_unit_test.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_log_db_redis_services.py b/tests/logging_callback_tests/test_log_db_redis_services.py index 08e45590a8..fa0c3b595a 100644 --- a/tests/logging_callback_tests/test_log_db_redis_services.py +++ b/tests/logging_callback_tests/test_log_db_redis_services.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/logging_callback_tests/test_view_request_resp_logs.py b/tests/logging_callback_tests/test_view_request_resp_logs.py index 66463e315c..ea778a44e6 100644 --- a/tests/logging_callback_tests/test_view_request_resp_logs.py +++ b/tests/logging_callback_tests/test_view_request_resp_logs.py @@ -25,6 +25,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import ( ) from litellm.types.utils import StandardCallbackDynamicParams + # This is the response payload that GCS would return. mock_response_data = { "id": "chatcmpl-9870a859d6df402795f75dc5fca5b2e0", diff --git a/tests/mcp_tests/test_mcp_logging.py b/tests/mcp_tests/test_mcp_logging.py index 73de99ac89..55b49aa0d2 100644 --- a/tests/mcp_tests/test_mcp_logging.py +++ b/tests/mcp_tests/test_mcp_logging.py @@ -5,6 +5,7 @@ import asyncio from typing import Optional from unittest.mock import AsyncMock, patch + sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index e931ba4ab9..6af0758579 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -18,6 +18,7 @@ from litellm.proxy._types import LiteLLM_ObjectPermissionTable from mcp.types import Tool as MCPTool, CallToolResult, ListToolsResult from mcp.types import TextContent + mcp_server_manager = MCPServerManager() diff --git a/tests/mcp_tests/test_proxy_mcp_e2e.py b/tests/mcp_tests/test_proxy_mcp_e2e.py index 62fc7ad3ff..2dd57e13d3 100644 --- a/tests/mcp_tests/test_proxy_mcp_e2e.py +++ b/tests/mcp_tests/test_proxy_mcp_e2e.py @@ -20,6 +20,7 @@ from litellm.proxy.proxy_server import ( initialize, ) + CONFIG_TEMPLATE_PATH = Path("tests/mcp_tests/test_configs/test_config_mcp_e2e.yaml") MCP_SERVER_SCRIPT = Path("tests/mcp_tests/mcp_server.py") PROJECT_ROOT = Path(__file__).resolve().parents[2] diff --git a/tests/ocr_tests/base_ocr_unit_tests.py b/tests/ocr_tests/base_ocr_unit_tests.py index a09158ce18..ae65efd952 100644 --- a/tests/ocr_tests/base_ocr_unit_tests.py +++ b/tests/ocr_tests/base_ocr_unit_tests.py @@ -9,6 +9,7 @@ import litellm import os from abc import ABC, abstractmethod + # Test resources TEST_IMAGE_PATH = "test_image_edit.png" # Tiny in-repo PDF served via jsdelivr (sha-pinned, immutable). The arxiv diff --git a/tests/ocr_tests/test_ocr_azure_document_intelligence.py b/tests/ocr_tests/test_ocr_azure_document_intelligence.py index 27428f96a2..7269890b7b 100644 --- a/tests/ocr_tests/test_ocr_azure_document_intelligence.py +++ b/tests/ocr_tests/test_ocr_azure_document_intelligence.py @@ -110,7 +110,9 @@ class TestAzureDocumentIntelligencePagesParam: model="azure_ai/doc-intelligence/prebuilt-layout", optional_params={"pages": "1-3,5"}, ) - assert f"api-version={AZURE_DOCUMENT_INTELLIGENCE_API_VERSION}" in url, url + assert ( + f"api-version={AZURE_DOCUMENT_INTELLIGENCE_API_VERSION}" in url + ), url assert "pages=1-3,5" in url, url assert "/documentintelligence/documentModels/prebuilt-layout:analyze" in url @@ -166,3 +168,4 @@ class TestAzureDocumentIntelligencePagesParam: assert "pages=3,4,5,6,7,8,9" in url assert req.data == {"urlSource": "https://example.com/x.pdf"} + diff --git a/tests/old_proxy_tests/tests/bursty_load_test_completion.py b/tests/old_proxy_tests/tests/bursty_load_test_completion.py index 41944c03aa..642bd9f14d 100644 --- a/tests/old_proxy_tests/tests/bursty_load_test_completion.py +++ b/tests/old_proxy_tests/tests/bursty_load_test_completion.py @@ -3,6 +3,7 @@ from openai import AsyncOpenAI from litellm._uuid import uuid import traceback + litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000") diff --git a/tests/old_proxy_tests/tests/load_test_embedding_100.py b/tests/old_proxy_tests/tests/load_test_embedding_100.py index bfb4e137e8..8cd4d25024 100644 --- a/tests/old_proxy_tests/tests/load_test_embedding_100.py +++ b/tests/old_proxy_tests/tests/load_test_embedding_100.py @@ -3,6 +3,7 @@ from openai import AsyncOpenAI from litellm._uuid import uuid import traceback + litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000") diff --git a/tests/old_proxy_tests/tests/test_openai_request_with_traceparent.py b/tests/old_proxy_tests/tests/test_openai_request_with_traceparent.py index cde68002a7..2f8455dcbe 100644 --- a/tests/old_proxy_tests/tests/test_openai_request_with_traceparent.py +++ b/tests/old_proxy_tests/tests/test_openai_request_with_traceparent.py @@ -8,6 +8,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + trace.set_tracer_provider(TracerProvider()) memory_exporter = InMemorySpanExporter() span_processor = SimpleSpanProcessor(memory_exporter) diff --git a/tests/openai_endpoints_tests/test_openai_batches_endpoint.py b/tests/openai_endpoints_tests/test_openai_batches_endpoint.py index c75416edf1..ad28e8da3d 100644 --- a/tests/openai_endpoints_tests/test_openai_batches_endpoint.py +++ b/tests/openai_endpoints_tests/test_openai_batches_endpoint.py @@ -11,6 +11,7 @@ import sys import time from unittest.mock import patch, MagicMock, AsyncMock + BASE_URL = "http://localhost:4000" # Replace with your actual base URL API_KEY = "sk-1234" # Replace with your actual API key diff --git a/tests/openai_endpoints_tests/test_openai_files_endpoints.py b/tests/openai_endpoints_tests/test_openai_files_endpoints.py index 547bb55483..6be692b278 100644 --- a/tests/openai_endpoints_tests/test_openai_files_endpoints.py +++ b/tests/openai_endpoints_tests/test_openai_files_endpoints.py @@ -6,6 +6,7 @@ import aiohttp, openai from openai import OpenAI, AsyncOpenAI from typing import Optional, List, Union + BASE_URL = "http://localhost:4000" # Replace with your actual base URL API_KEY = "sk-1234" # Replace with your actual API key diff --git a/tests/otel_tests/test_e2e_model_access.py b/tests/otel_tests/test_e2e_model_access.py index 485031d4e9..7ea75a9d61 100644 --- a/tests/otel_tests/test_e2e_model_access.py +++ b/tests/otel_tests/test_e2e_model_access.py @@ -5,6 +5,7 @@ import json from httpx import AsyncClient from typing import Any, Optional, List, Literal + # The proxy strips client-supplied `mock_response` unless the calling key or # team has this admin-metadata flag set. See `_UNTRUSTED_ROOT_CONTROL_FIELDS` # in litellm/proxy/litellm_pre_call_utils.py. diff --git a/tests/otel_tests/test_team_member_permissions.py b/tests/otel_tests/test_team_member_permissions.py index 36926b9f52..ddb8b741c4 100644 --- a/tests/otel_tests/test_team_member_permissions.py +++ b/tests/otel_tests/test_team_member_permissions.py @@ -4,13 +4,13 @@ Invalid Permissions: - - User tries creating a key with team_id = team_id -> expect to fail. Invalid Permissions - - User tries editing a key with team_id = team_id -> expect to fail. Invalid Permissions - - User tries deleting a key with team_id = team_id -> expect to fail. Invalid Permissions - - User tries regenerating a key with team_id = team_id -> expect to fail. Invalid Permissions + - User tries creating a key with team_id = team_id -> expect to fail. Invalid Permissions + - User tries editing a key with team_id = team_id -> expect to fail. Invalid Permissions + - User tries deleting a key with team_id = team_id -> expect to fail. Invalid Permissions + - User tries regenerating a key with team_id = team_id -> expect to fail. Invalid Permissions Valid Permissions: - - User tries calling /key/info with team_id, expect to get valid response + - User tries calling /key/info with team_id, expect to get valid response @@ -26,7 +26,7 @@ Invalid Permissions: - User tries creating a key with team_id = team_id -> expect to fail. Invalid Permissions - - User tries calling /key/info with team_id, expect to get valid response + - User tries calling /key/info with team_id, expect to get valid response diff --git a/tests/pass_through_tests/test_openai_assistants_passthrough.py b/tests/pass_through_tests/test_openai_assistants_passthrough.py index c8e9a9ef0f..28568005fd 100644 --- a/tests/pass_through_tests/test_openai_assistants_passthrough.py +++ b/tests/pass_through_tests/test_openai_assistants_passthrough.py @@ -6,6 +6,7 @@ import tempfile from typing_extensions import override from openai import AssistantEventHandler + client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234") diff --git a/tests/pass_through_tests/test_vertex_ai.py b/tests/pass_through_tests/test_vertex_ai.py index 668ca017f7..73bf03c500 100644 --- a/tests/pass_through_tests/test_vertex_ai.py +++ b/tests/pass_through_tests/test_vertex_ai.py @@ -12,6 +12,7 @@ import os import pytest import asyncio + # Path to your service account JSON file SERVICE_ACCOUNT_FILE = "path/to/your/service-account.json" diff --git a/tests/pass_through_unit_tests/base_anthropic_messages_prompt_caching_test.py b/tests/pass_through_unit_tests/base_anthropic_messages_prompt_caching_test.py index b218776239..5fc4ecefb3 100644 --- a/tests/pass_through_unit_tests/base_anthropic_messages_prompt_caching_test.py +++ b/tests/pass_through_unit_tests/base_anthropic_messages_prompt_caching_test.py @@ -22,8 +22,10 @@ sys.path.insert(0, os.path.abspath("../../..")) import pytest import litellm + # Large document for caching tests (needs 1024+ tokens for Claude models) -LARGE_DOCUMENT_FOR_CACHING = """ +LARGE_DOCUMENT_FOR_CACHING = ( + """ This is a comprehensive legal agreement between Party A and Party B. ARTICLE 1: DEFINITIONS @@ -75,7 +77,9 @@ ARTICLE 9: GENERAL PROVISIONS 9.5 Waiver of any provision shall not constitute ongoing waiver. IN WITNESS WHEREOF, the parties have executed this Agreement. -""" * 8 # Repeat to ensure we have enough tokens (need 1024+ for Claude models) +""" + * 8 +) # Repeat to ensure we have enough tokens (need 1024+ for Claude models) class BaseAnthropicMessagesPromptCachingTest(ABC): diff --git a/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py b/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py index 3b1db1c327..f0cc6985e6 100644 --- a/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py +++ b/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py @@ -14,6 +14,7 @@ import io import os import time + # this file is to test litellm/proxy sys.path.insert( diff --git a/tests/proxy_admin_ui_tests/test_usage_endpoints.py b/tests/proxy_admin_ui_tests/test_usage_endpoints.py index d8ef87f933..54ad136f08 100644 --- a/tests/proxy_admin_ui_tests/test_usage_endpoints.py +++ b/tests/proxy_admin_ui_tests/test_usage_endpoints.py @@ -1,5 +1,5 @@ """ -Tests the following endpoints used by the UI +Tests the following endpoints used by the UI /global/spend/logs /global/spend/keys @@ -9,7 +9,7 @@ Tests the following endpoints used by the UI For all tests - test the following: -- Response is valid +- Response is valid - Response for Admin User is different from response from Internal User """ diff --git a/tests/proxy_e2e_anthropic_messages_tests/test_claude_agent_sdk.py b/tests/proxy_e2e_anthropic_messages_tests/test_claude_agent_sdk.py index 7f5fcd38c8..48eb7d85ec 100644 --- a/tests/proxy_e2e_anthropic_messages_tests/test_claude_agent_sdk.py +++ b/tests/proxy_e2e_anthropic_messages_tests/test_claude_agent_sdk.py @@ -12,6 +12,7 @@ import pytest import asyncio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions + # Test models from test_config.yaml # Note: bedrock-converse-claude-sonnet-4.5 removed temporarily as the Bedrock Converse API # for Claude Sonnet 4.5 may not be available in all regions/accounts diff --git a/tests/proxy_unit_tests/conftest.py b/tests/proxy_unit_tests/conftest.py index 6ff9ffe84c..a0326f64ed 100644 --- a/tests/proxy_unit_tests/conftest.py +++ b/tests/proxy_unit_tests/conftest.py @@ -16,6 +16,7 @@ sys.path.insert( import litellm import litellm.proxy.proxy_server + # Top-level assignments of these types are the ones importlib.reload(litellm) # would have effectively reset. We snapshot them at conftest import time and # deep-copy the snapshot back before every test. diff --git a/tests/proxy_unit_tests/test_configs/custom_callbacks.py b/tests/proxy_unit_tests/test_configs/custom_callbacks.py index ac705cd28a..c7d66c068e 100644 --- a/tests/proxy_unit_tests/test_configs/custom_callbacks.py +++ b/tests/proxy_unit_tests/test_configs/custom_callbacks.py @@ -93,7 +93,8 @@ class testCustomCallbackProxy(CustomLogger): print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger)) - print(f""" + print( + f""" Model: {model}, Messages: {messages}, User: {user}, @@ -101,7 +102,8 @@ class testCustomCallbackProxy(CustomLogger): Cost: {cost}, Response: {response} Proxy Metadata: {metadata} - """) + """ + ) return async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): diff --git a/tests/proxy_unit_tests/test_jwt_key_mapping.py b/tests/proxy_unit_tests/test_jwt_key_mapping.py index 8e3331d17f..bf1c4a3f6c 100644 --- a/tests/proxy_unit_tests/test_jwt_key_mapping.py +++ b/tests/proxy_unit_tests/test_jwt_key_mapping.py @@ -27,6 +27,7 @@ from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import ( from litellm.caching.caching import DualCache from fastapi import HTTPException + # ────────────────────────────────────────────── # Tests: _resolve_jwt_to_virtual_key # ────────────────────────────────────────────── diff --git a/tests/proxy_unit_tests/test_search_api_logging.py b/tests/proxy_unit_tests/test_search_api_logging.py index 3298389083..71bbe5351a 100644 --- a/tests/proxy_unit_tests/test_search_api_logging.py +++ b/tests/proxy_unit_tests/test_search_api_logging.py @@ -2,7 +2,7 @@ Test search API logging and cost tracking in proxy. Tests that search API requests are properly logged to LiteLLM_SpendLogs -with correct fields populated (call_type, model, custom_llm_provider, +with correct fields populated (call_type, model, custom_llm_provider, model_group, spend, etc.) """ diff --git a/tests/router_unit_tests/create_mock_standard_logging_payload.py b/tests/router_unit_tests/create_mock_standard_logging_payload.py index 828faba0b0..2fd6a4ffa8 100644 --- a/tests/router_unit_tests/create_mock_standard_logging_payload.py +++ b/tests/router_unit_tests/create_mock_standard_logging_payload.py @@ -2,6 +2,7 @@ import io import os import sys + sys.path.insert(0, os.path.abspath("../..")) import asyncio diff --git a/tests/router_unit_tests/test_router_handle_error.py b/tests/router_unit_tests/test_router_handle_error.py index 552b004883..660b388512 100644 --- a/tests/router_unit_tests/test_router_handle_error.py +++ b/tests/router_unit_tests/test_router_handle_error.py @@ -15,6 +15,7 @@ from collections import defaultdict from dotenv import load_dotenv from unittest.mock import AsyncMock, MagicMock + load_dotenv() diff --git a/tests/search_tests/test_google_pse_search.py b/tests/search_tests/test_google_pse_search.py index 9ce55b428d..21d58a9549 100644 --- a/tests/search_tests/test_google_pse_search.py +++ b/tests/search_tests/test_google_pse_search.py @@ -10,6 +10,7 @@ sys.path.insert(0, os.path.abspath("../..")) from tests.search_tests.base_search_unit_tests import BaseSearchTest + # class TestGooglePSESearch(BaseSearchTest): # """ # Tests for Google PSE Search functionality. diff --git a/tests/store_model_in_db_tests/test_callbacks_in_db.py b/tests/store_model_in_db_tests/test_callbacks_in_db.py index 51545333c6..4a851251a3 100644 --- a/tests/store_model_in_db_tests/test_callbacks_in_db.py +++ b/tests/store_model_in_db_tests/test_callbacks_in_db.py @@ -1,9 +1,9 @@ """ PROD TEST - DO NOT Delete this Test -e2e test for langfuse callback in DB +e2e test for langfuse callback in DB - Add langfuse callback to DB - with /config/update -- wait 20 seconds for the callback to be loaded into the instance +- wait 20 seconds for the callback to be loaded into the instance - Make a /chat/completions request to the proxy - Check if the request is logged in Langfuse """ diff --git a/tests/test_litellm/a2a_protocol/providers/bedrock_agentcore/test_bedrock_agentcore_a2a.py b/tests/test_litellm/a2a_protocol/providers/bedrock_agentcore/test_bedrock_agentcore_a2a.py index 694806c179..a4f7f8187c 100644 --- a/tests/test_litellm/a2a_protocol/providers/bedrock_agentcore/test_bedrock_agentcore_a2a.py +++ b/tests/test_litellm/a2a_protocol/providers/bedrock_agentcore/test_bedrock_agentcore_a2a.py @@ -14,6 +14,7 @@ import json import pytest from unittest.mock import AsyncMock, MagicMock, patch + SAMPLE_ARN = "arn:aws:bedrock-agentcore:us-west-2:123456789:runtime/my_agent" SAMPLE_MODEL = f"bedrock/agentcore/{SAMPLE_ARN}" SAMPLE_PARAMS = { diff --git a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py index 23a5d9ec99..af5e234140 100644 --- a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py +++ b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py @@ -30,7 +30,6 @@ def no_invitation_wait(monkeypatch): monkeypatch.setattr(BaseEmailLogger, "_wait_for_invitation_creation", _noop) - @pytest.fixture def base_email_logger(): return BaseEmailLogger() @@ -284,10 +283,7 @@ async def test_send_key_created_email_without_key( mock_send_email.assert_called_once() call_args = mock_send_email.call_args[1] assert "sk-secret-key-456" not in call_args["html_body"] - assert ( - "[Key hidden for security - retrieve from dashboard]" - in call_args["html_body"] - ) + assert "[Key hidden for security - retrieve from dashboard]" in call_args["html_body"] @pytest.mark.asyncio @@ -321,10 +317,7 @@ async def test_send_key_rotated_email_without_key( mock_send_email.assert_called_once() call_args = mock_send_email.call_args[1] assert "sk-secret-rotated-789" not in call_args["html_body"] - assert ( - "[Key hidden for security - retrieve from dashboard]" - in call_args["html_body"] - ) + assert "[Key hidden for security - retrieve from dashboard]" in call_args["html_body"] @pytest.mark.asyncio @@ -378,52 +371,52 @@ async def test_get_invitation_link_creates_new_when_none_exist(base_email_logger """Test that _get_invitation_link creates a new invitation when none exist""" # Mock prisma client with no existing invitation rows mock_prisma = mock.MagicMock() - + # Mock find_many to return empty list (no existing invitations) async def mock_find_many_empty(*args, **kwargs): return [] - + mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_empty - + # Mock the create_invitation_for_user function mock_created_invitation = mock.MagicMock() mock_created_invitation.id = "new-invitation-id" - + with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): with mock.patch( "litellm.proxy.management_helpers.user_invitation.create_invitation_for_user", - return_value=mock_created_invitation, + return_value=mock_created_invitation ) as mock_create_invitation: # Execute result = await base_email_logger._get_invitation_link( user_id="test-user", base_url="http://test.com" ) - + # Verify that create_invitation_for_user was called mock_create_invitation.assert_called_once() call_args = mock_create_invitation.call_args[1] assert call_args["data"].user_id == "test-user" assert call_args["user_api_key_dict"].user_id == "test-user" - + # Verify the returned link uses the new invitation ID assert result == "http://test.com/ui?invitation_id=new-invitation-id" -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_get_invitation_link_uses_existing_when_available(base_email_logger): """Test that _get_invitation_link uses existing invitation when available""" # Mock prisma client with existing invitation row mock_invitation_row = mock.MagicMock() mock_invitation_row.id = "existing-invitation-id" - + mock_prisma = mock.MagicMock() - + # Mock find_many to return existing invitation async def mock_find_many_existing(*args, **kwargs): return [mock_invitation_row] - + mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_existing - + with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): with mock.patch( "litellm.proxy.management_helpers.user_invitation.create_invitation_for_user" @@ -432,10 +425,10 @@ async def test_get_invitation_link_uses_existing_when_available(base_email_logge result = await base_email_logger._get_invitation_link( user_id="test-user", base_url="http://test.com" ) - + # Verify that create_invitation_for_user was NOT called mock_create_invitation.assert_not_called() - + # Verify the returned link uses the existing invitation ID assert result == "http://test.com/ui?invitation_id=existing-invitation-id" @@ -445,33 +438,33 @@ async def test_get_invitation_link_creates_new_when_list_is_none(base_email_logg """Test that _get_invitation_link creates a new invitation when invitation_rows is None""" # Mock prisma client to return None mock_prisma = mock.MagicMock() - + # Mock find_many to return None async def mock_find_many_none(*args, **kwargs): return None - + mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_none - + # Mock the create_invitation_for_user function mock_created_invitation = mock.MagicMock() mock_created_invitation.id = "new-invitation-from-none" - + with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): with mock.patch( "litellm.proxy.management_helpers.user_invitation.create_invitation_for_user", - return_value=mock_created_invitation, + return_value=mock_created_invitation ) as mock_create_invitation: # Execute result = await base_email_logger._get_invitation_link( user_id="test-user", base_url="http://test.com" ) - + # Verify that create_invitation_for_user was called mock_create_invitation.assert_called_once() call_args = mock_create_invitation.call_args[1] assert call_args["data"].user_id == "test-user" assert call_args["user_api_key_dict"].user_id == "test-user" - + # Verify the returned link uses the new invitation ID assert result == "http://test.com/ui?invitation_id=new-invitation-from-none" @@ -502,15 +495,13 @@ async def test_get_email_params_user_invitation( user_email="test@example.com", ) - assert ( - result.logo_url - == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" - ) + assert result.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" assert result.support_contact == "support@berri.ai" assert result.base_url == "http://test.com/ui?invitation_id=test-id" assert result.recipient_email == "test@example.com" + @pytest.fixture def mock_env_vars(monkeypatch): """Set up test environment variables""" @@ -522,74 +513,69 @@ def mock_env_vars(monkeypatch): monkeypatch.setenv("PROXY_BASE_URL", "http://test.com") monkeypatch.setenv("PROXY_API_URL", "https://test.com") - @pytest.mark.asyncio async def test_get_email_params_custom_templates_premium_user(mock_env_vars): """Test that _get_email_params returns correct values with custom templates for premium users""" # Mock premium_user as True with patch("litellm.proxy.proxy_server.premium_user", True): email_logger = BaseEmailLogger() - + # Test invitation email params invitation_params = await email_logger._get_email_params( email_event=EmailEvent.new_user_invitation, user_id="testid", user_email="test@example.com", - event_message="New User Invitation", + event_message="New User Invitation" ) - + assert invitation_params.subject == "Welcome to Test Company!" assert invitation_params.signature == "Best regards,\nTest Company Team" assert invitation_params.logo_url == "https://test-company.com/logo.png" assert invitation_params.support_contact == "support@test-company.com" assert invitation_params.base_url == "http://test.com" - + # Test key created email params key_params = await email_logger._get_email_params( email_event=EmailEvent.virtual_key_created, user_id="testid", user_email="test@example.com", - event_message="API Key Created", + event_message="API Key Created" ) - + assert key_params.subject == "Your Test Company API Key" assert key_params.signature == "Best regards,\nTest Company Team" - @pytest.mark.asyncio async def test_get_email_params_non_premium_user(mock_env_vars): """Test that non-premium users get default templates even when custom ones are provided""" # Mock premium_user as False with patch("litellm.proxy.proxy_server.premium_user", False): email_logger = BaseEmailLogger() - + # Test invitation email params email_params = await email_logger._get_email_params( email_event=EmailEvent.new_user_invitation, user_email="test@example.com", - event_message="New User Invitation", + event_message="New User Invitation" ) - + # Should use default values even though custom values are set in env assert email_params.subject == "LiteLLM: New User Invitation" assert email_params.signature == EMAIL_FOOTER - assert ( - email_params.logo_url - == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" - ) + assert email_params.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" assert email_params.support_contact == "support@berri.ai" + # Test key created email params key_params = await email_logger._get_email_params( email_event=EmailEvent.virtual_key_created, user_email="test@example.com", - event_message="API Key Created", + event_message="API Key Created" ) - + assert key_params.subject == "LiteLLM: API Key Created" assert key_params.signature == EMAIL_FOOTER - @pytest.mark.asyncio async def test_get_email_params_default_templates(monkeypatch): """Test that _get_email_params uses default templates when custom ones aren't provided""" @@ -597,28 +583,28 @@ async def test_get_email_params_default_templates(monkeypatch): monkeypatch.delenv("EMAIL_SUBJECT_INVITATION", raising=False) monkeypatch.delenv("EMAIL_SUBJECT_KEY_CREATED", raising=False) monkeypatch.delenv("EMAIL_SIGNATURE", raising=False) - + # Mock premium_user as True (shouldn't matter since no custom values are set) with patch("litellm.proxy.proxy_server.premium_user", True): email_logger = BaseEmailLogger() - + # Test invitation email params with default template invitation_params = await email_logger._get_email_params( email_event=EmailEvent.new_user_invitation, user_email="test@example.com", - event_message="New User Invitation", + event_message="New User Invitation" ) - + assert invitation_params.subject == "LiteLLM: New User Invitation" assert invitation_params.signature == EMAIL_FOOTER - + # Test key created email params with default template key_params = await email_logger._get_email_params( email_event=EmailEvent.virtual_key_created, user_email="test@example.com", - event_message="API Key Created", + event_message="API Key Created" ) - + assert key_params.subject == "LiteLLM: API Key Created" assert key_params.signature == EMAIL_FOOTER @@ -653,10 +639,7 @@ async def test_send_soft_budget_alert_email( call_args = mock_send_email.call_args[1] assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL assert call_args["to_email"] == ["test@example.com"] - assert ( - call_args["subject"] - == "LiteLLM: Soft Budget Crossed - Total Soft Budget: $100.0" - ) + assert call_args["subject"] == "LiteLLM: Soft Budget Crossed - Total Soft Budget: $100.0" assert "$100.0" in call_args["html_body"] # soft_budget assert "$105.0" in call_args["html_body"] # spend assert "$200.0" in call_args["html_body"] # max_budget @@ -690,13 +673,13 @@ async def test_send_soft_budget_alert_email_no_max_budget( call_args = mock_send_email.call_args[1] assert "$100.0" in call_args["html_body"] # soft_budget assert "$105.0" in call_args["html_body"] # spend - assert ( - "Maximum Budget" not in call_args["html_body"] - ) # max_budget should not be shown + assert "Maximum Budget" not in call_args["html_body"] # max_budget should not be shown @pytest.mark.asyncio -async def test_budget_alerts_soft_budget_crossed(base_email_logger, mock_send_email): +async def test_budget_alerts_soft_budget_crossed( + base_email_logger, mock_send_email +): """Test that budget_alerts sends email when soft budget is crossed""" user_info = CallInfo( user_id="test_user", @@ -725,14 +708,11 @@ async def test_budget_alerts_soft_budget_crossed(base_email_logger, mock_send_em mock_send_email.assert_called_once() call_args = mock_send_email.call_args[1] assert call_args["to_email"] == ["test@example.com"] - + # Verify cache was set to prevent duplicate alerts mock_cache.async_set_cache.assert_called_once() cache_call_args = mock_cache.async_set_cache.call_args[1] - assert ( - cache_call_args["key"] - == "email_budget_alerts:soft_budget_crossed:test_user" - ) + assert cache_call_args["key"] == "email_budget_alerts:soft_budget_crossed:test_user" assert cache_call_args["value"] == "SENT" assert cache_call_args["ttl"] == EMAIL_BUDGET_ALERT_TTL @@ -786,7 +766,9 @@ async def test_budget_alerts_soft_budget_duplicate_prevention( @pytest.mark.asyncio -async def test_budget_alerts_no_budgets(base_email_logger, mock_send_email): +async def test_budget_alerts_no_budgets( + base_email_logger, mock_send_email +): """Test that budget_alerts returns early when no budgets are set""" user_info = CallInfo( user_id="test_user", @@ -835,10 +817,7 @@ async def test_budget_alerts_uses_token_for_cache_key( # Verify cache key uses token instead of user_id mock_cache.async_set_cache.assert_called_once() cache_call_args = mock_cache.async_set_cache.call_args[1] - assert ( - cache_call_args["key"] - == "email_budget_alerts:soft_budget_crossed:hashed_token_123" - ) + assert cache_call_args["key"] == "email_budget_alerts:soft_budget_crossed:hashed_token_123" @pytest.mark.asyncio @@ -859,9 +838,7 @@ async def test_get_email_params_soft_budget_crossed( ) # Should use default subject template for soft_budget_crossed - assert ( - result.subject == "LiteLLM: Soft Budget Crossed - Total Soft Budget: $100.0" - ) + assert result.subject == "LiteLLM: Soft Budget Crossed - Total Soft Budget: $100.0" assert result.recipient_email == "test@example.com" assert result.base_url == "http://test.com" @@ -890,20 +867,16 @@ async def test_budget_alerts_max_budget_alert_crossed( "PROXY_BASE_URL": "http://test.com", }, ): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) mock_send_email.assert_called_once() call_args = mock_send_email.call_args[1] assert call_args["to_email"] == ["test@example.com"] assert "Max Budget Alert" in call_args["subject"] - + mock_cache.async_set_cache.assert_called_once() cache_call_args = mock_cache.async_set_cache.call_args[1] - assert ( - cache_call_args["key"] == "email_budget_alerts:max_budget_alert:test_user" - ) + assert cache_call_args["key"] == "email_budget_alerts:max_budget_alert:test_user" assert cache_call_args["value"] == "SENT" assert cache_call_args["ttl"] == EMAIL_BUDGET_ALERT_TTL @@ -933,15 +906,15 @@ async def test_multi_threshold_sends_crossed_thresholds( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) # spend=80 crosses 50% ($50) and 75% ($75), but not 100% ($100) assert mock_send_email.call_count == 2 # Check cache keys include threshold percentage - cache_keys = [c[1]["key"] for c in mock_cache.async_set_cache.call_args_list] + cache_keys = [ + c[1]["key"] for c in mock_cache.async_set_cache.call_args_list + ] assert "email_budget_alerts:max_budget_alert:50:hashed_key_1" in cache_keys assert "email_budget_alerts:max_budget_alert:75:hashed_key_1" in cache_keys @@ -976,9 +949,7 @@ async def test_multi_threshold_dedup_cache_prevents_resend( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) # Only 75% should fire assert mock_send_email.call_count == 1 @@ -1009,9 +980,7 @@ async def test_multi_threshold_owner_email_auto_included( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) mock_send_email.assert_called_once() to_emails = mock_send_email.call_args[1]["to_email"] @@ -1033,7 +1002,7 @@ async def test_multi_threshold_malformed_keys_skipped( event_group=Litellm_EntityType.KEY, max_budget_alert_emails={ "fifty": ["finance@co.com"], # invalid - "50": ["finance@co.com"], # valid, crossed + "50": ["finance@co.com"], # valid, crossed }, ) @@ -1043,9 +1012,7 @@ async def test_multi_threshold_malformed_keys_skipped( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) # Only the valid "50" threshold should fire assert mock_send_email.call_count == 1 @@ -1074,9 +1041,7 @@ async def test_multi_threshold_empty_emails_only_owner( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) mock_send_email.assert_called_once() to_emails = mock_send_email.call_args[1]["to_email"] @@ -1102,13 +1067,11 @@ async def test_no_map_preserves_old_single_threshold( base_email_logger.internal_usage_cache = mock_cache with mock.patch.dict(os.environ, {"PROXY_BASE_URL": "http://test.com"}): - await base_email_logger.budget_alerts( - type="max_budget_alert", user_info=user_info - ) + await base_email_logger.budget_alerts(type="max_budget_alert", user_info=user_info) mock_send_email.assert_called_once() call_args = mock_send_email.call_args[1] assert call_args["to_email"] == ["test@example.com"] # Old path cache key has no threshold percentage cache_key = mock_cache.async_set_cache.call_args[1]["key"] - assert cache_key == "email_budget_alerts:max_budget_alert:test_user" + assert cache_key == "email_budget_alerts:max_budget_alert:test_user" \ No newline at end of file diff --git a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py index dfb51c8bbc..b07216921e 100644 --- a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py +++ b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py @@ -87,7 +87,7 @@ async def test_send_email_success(mock_env_vars): async def test_send_email_missing_api_key(): # Remove the API key from environment before initializing logger original_key = os.environ.pop("RESEND_API_KEY", None) - + try: # Initialize the logger after removing the API key logger = ResendEmailLogger() @@ -104,19 +104,16 @@ async def test_send_email_missing_api_key(): mock_response.raise_for_status.return_value = None mock_response.status_code = 200 mock_response.json.return_value = {"id": "test_email_id"} - + mock_async_client = mock.AsyncMock() mock_async_client.post.return_value = mock_response - + # Directly inject the mock client to bypass any caching logger.async_httpx_client = mock_async_client # Send email await logger.send_email( - from_email=from_email, - to_email=to_email, - subject=subject, - html_body=html_body, + from_email=from_email, to_email=to_email, subject=subject, html_body=html_body ) # Verify the HTTP client was called with None as the API key diff --git a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_sendgrid_email.py b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_sendgrid_email.py index d247b02074..40439a78a4 100644 --- a/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_sendgrid_email.py +++ b/tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_sendgrid_email.py @@ -32,12 +32,12 @@ def mock_env_vars(): # Store original values original_api_key = os.environ.get("SENDGRID_API_KEY") original_sender_email = os.environ.get("SENDGRID_SENDER_EMAIL") - + # Set test API key and remove SENDGRID_SENDER_EMAIL to ensure isolation os.environ["SENDGRID_API_KEY"] = "test_api_key" if "SENDGRID_SENDER_EMAIL" in os.environ: del os.environ["SENDGRID_SENDER_EMAIL"] - + try: yield finally: @@ -46,7 +46,7 @@ def mock_env_vars(): os.environ["SENDGRID_API_KEY"] = original_api_key elif "SENDGRID_API_KEY" in os.environ: del os.environ["SENDGRID_API_KEY"] - + if original_sender_email is not None: os.environ["SENDGRID_SENDER_EMAIL"] = original_sender_email diff --git a/tests/test_litellm/enterprise/enterprise_callbacks/test_callback_controls.py b/tests/test_litellm/enterprise/enterprise_callbacks/test_callback_controls.py index 9743e7bedc..b160ca5130 100644 --- a/tests/test_litellm/enterprise/enterprise_callbacks/test_callback_controls.py +++ b/tests/test_litellm/enterprise/enterprise_callbacks/test_callback_controls.py @@ -18,282 +18,168 @@ from litellm.types.utils import StandardCallbackDynamicParams class TestEnterpriseCallbackControls: - + @pytest.fixture def mock_premium_user(self): """Fixture to mock premium user check as True""" - with patch.object( - EnterpriseCallbackControls, - "_should_allow_dynamic_callback_disabling", - return_value=True, - ): + with patch.object(EnterpriseCallbackControls, '_should_allow_dynamic_callback_disabling', return_value=True): yield - - @pytest.fixture + + @pytest.fixture def mock_non_premium_user(self): """Fixture to mock premium user check as False""" - with patch.object( - EnterpriseCallbackControls, - "_should_allow_dynamic_callback_disabling", - return_value=False, - ): + with patch.object(EnterpriseCallbackControls, '_should_allow_dynamic_callback_disabling', return_value=False): yield @pytest.fixture def mock_request_headers(self): """Fixture to mock get_proxy_server_request_headers""" - with patch( - "enterprise.litellm_enterprise.enterprise_callbacks.callback_controls.get_proxy_server_request_headers" - ) as mock_headers: + with patch('enterprise.litellm_enterprise.enterprise_callbacks.callback_controls.get_proxy_server_request_headers') as mock_headers: yield mock_headers - def test_callback_disabled_langfuse_string( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_langfuse_string(self, mock_premium_user, mock_request_headers): """Test that 'langfuse' string callback is disabled when specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_langfuse_customlogger( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_langfuse_customlogger(self, mock_premium_user, mock_request_headers): """Test that LangfusePromptManagement CustomLogger instance is disabled when 'langfuse' specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + langfuse_logger = LangfusePromptManagement() - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - langfuse_logger, litellm_params, standard_callback_dynamic_params - ) + result = EnterpriseCallbackControls.is_callback_disabled_dynamically(langfuse_logger, litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_s3_v2_string( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_s3_v2_string(self, mock_premium_user, mock_request_headers): """Test that 's3_v2' string callback is disabled when specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "s3_v2", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_s3_v2_customlogger( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_s3_v2_customlogger(self, mock_premium_user, mock_request_headers): """Test that S3Logger CustomLogger instance is disabled when 's3_v2' specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Mock S3Logger to avoid async initialization issues - with patch("litellm.integrations.s3_v2.S3Logger.__init__", return_value=None): + with patch('litellm.integrations.s3_v2.S3Logger.__init__', return_value=None): s3_logger = S3Logger() - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - s3_logger, litellm_params, standard_callback_dynamic_params - ) + result = EnterpriseCallbackControls.is_callback_disabled_dynamically(s3_logger, litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_datadog_string( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_datadog_string(self, mock_premium_user, mock_request_headers): """Test that 'datadog' string callback is disabled when specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_datadog_customlogger( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_datadog_customlogger(self, mock_premium_user, mock_request_headers): """Test that DataDogLogger CustomLogger instance is disabled when 'datadog' specified in headers""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Mock DataDogLogger to avoid async initialization issues - with patch( - "litellm.integrations.datadog.datadog.DataDogLogger.__init__", - return_value=None, - ): + with patch('litellm.integrations.datadog.datadog.DataDogLogger.__init__', return_value=None): datadog_logger = DataDogLogger() - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - datadog_logger, litellm_params, standard_callback_dynamic_params - ) + result = EnterpriseCallbackControls.is_callback_disabled_dynamically(datadog_logger, litellm_params, standard_callback_dynamic_params) assert result is True def test_multiple_callbacks_disabled(self, mock_premium_user, mock_request_headers): """Test that multiple callbacks can be disabled with comma-separated list""" - mock_request_headers.return_value = { - X_LITELLM_DISABLE_CALLBACKS: "langfuse,datadog,s3_v2" - } + mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse,datadog,s3_v2"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Test each callback is disabled - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "s3_v2", litellm_params, standard_callback_dynamic_params - ) - is True - ) - + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True + # Test non-disabled callback is not disabled - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "prometheus", litellm_params, standard_callback_dynamic_params - ) - is False - ) + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False - def test_callback_not_disabled_when_not_in_list( - self, mock_premium_user, mock_request_headers - ): + def test_callback_not_disabled_when_not_in_list(self, mock_premium_user, mock_request_headers): """Test that callbacks not in the disabled list are not disabled""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) assert result is False - def test_callback_not_disabled_when_no_header( - self, mock_premium_user, mock_request_headers - ): + def test_callback_not_disabled_when_no_header(self, mock_premium_user, mock_request_headers): """Test that callbacks are not disabled when the header is not present""" mock_request_headers.return_value = {} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is False - def test_callback_not_disabled_when_header_none( - self, mock_premium_user, mock_request_headers - ): + def test_callback_not_disabled_when_header_none(self, mock_premium_user, mock_request_headers): """Test that callbacks are not disabled when the header value is None""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: None} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is False - def test_non_premium_user_cannot_disable_callbacks( - self, mock_non_premium_user, mock_request_headers - ): + def test_non_premium_user_cannot_disable_callbacks(self, mock_non_premium_user, mock_request_headers): """Test that non-premium users cannot disable callbacks even with the header""" mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is False - def test_case_insensitive_callback_matching( - self, mock_premium_user, mock_request_headers - ): + def test_case_insensitive_callback_matching(self, mock_premium_user, mock_request_headers): """Test that callback matching is case insensitive""" - mock_request_headers.return_value = { - X_LITELLM_DISABLE_CALLBACKS: "LANGFUSE,DataDog" - } + mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "LANGFUSE,DataDog"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Test lowercase callbacks are disabled - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) - is True - ) + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True - def test_whitespace_handling_in_disabled_callbacks( - self, mock_premium_user, mock_request_headers - ): + def test_whitespace_handling_in_disabled_callbacks(self, mock_premium_user, mock_request_headers): """Test that whitespace around callback names is handled correctly""" - mock_request_headers.return_value = { - X_LITELLM_DISABLE_CALLBACKS: " langfuse , datadog , s3_v2 " - } + mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: " langfuse , datadog , s3_v2 "} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() + + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "s3_v2", litellm_params, standard_callback_dynamic_params - ) - is True - ) - - def test_custom_logger_not_in_registry( - self, mock_premium_user, mock_request_headers - ): + def test_custom_logger_not_in_registry(self, mock_premium_user, mock_request_headers): """Test that CustomLogger not in registry is not disabled""" - mock_request_headers.return_value = { - X_LITELLM_DISABLE_CALLBACKS: "unknown_logger" - } + mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "unknown_logger"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Create a mock CustomLogger that's not in the registry class UnknownLogger(CustomLogger): pass - + unknown_logger = UnknownLogger() - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - unknown_logger, litellm_params, standard_callback_dynamic_params - ) + result = EnterpriseCallbackControls.is_callback_disabled_dynamically(unknown_logger, litellm_params, standard_callback_dynamic_params) assert result is False def test_exception_handling(self, mock_premium_user, mock_request_headers): @@ -302,64 +188,32 @@ class TestEnterpriseCallbackControls: mock_request_headers.side_effect = Exception("Test exception") litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is False - def test_callback_disabled_via_request_body_langfuse( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_via_request_body_langfuse(self, mock_premium_user, mock_request_headers): """Test that callbacks can be disabled via request body litellm_disabled_callbacks""" mock_request_headers.return_value = {} # No headers litellm_params = {"proxy_server_request": {"url": "test"}} - standard_callback_dynamic_params = StandardCallbackDynamicParams( - litellm_disabled_callbacks=["langfuse"] - ) - - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse"]) + + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is True - def test_callback_disabled_via_request_body_multiple( - self, mock_premium_user, mock_request_headers - ): + def test_callback_disabled_via_request_body_multiple(self, mock_premium_user, mock_request_headers): """Test that multiple callbacks can be disabled via request body""" mock_request_headers.return_value = {} # No headers litellm_params = {"proxy_server_request": {"url": "test"}} - standard_callback_dynamic_params = StandardCallbackDynamicParams( - litellm_disabled_callbacks=["langfuse", "datadog", "s3_v2"] - ) - + standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse", "datadog", "s3_v2"]) + # Test each callback is disabled - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "datadog", litellm_params, standard_callback_dynamic_params - ) - is True - ) - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "s3_v2", litellm_params, standard_callback_dynamic_params - ) - is True - ) - + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True + # Test non-disabled callback is not disabled - assert ( - EnterpriseCallbackControls.is_callback_disabled_dynamically( - "prometheus", litellm_params, standard_callback_dynamic_params - ) - is False - ) + assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False def test_admin_can_disable_dynamic_callback_disabling(self, mock_request_headers): """ @@ -369,13 +223,11 @@ class TestEnterpriseCallbackControls: mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Mock litellm.allow_dynamic_callback_disabling set to False - with patch("litellm.allow_dynamic_callback_disabling", False): - with patch("litellm.proxy.proxy_server.premium_user", True): - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + with patch('litellm.allow_dynamic_callback_disabling', False): + with patch('litellm.proxy.proxy_server.premium_user', True): + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is False def test_admin_can_enable_dynamic_callback_disabling(self, mock_request_headers): @@ -386,18 +238,14 @@ class TestEnterpriseCallbackControls: mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # Mock litellm.allow_dynamic_callback_disabling set to True - with patch("litellm.allow_dynamic_callback_disabling", True): - with patch("litellm.proxy.proxy_server.premium_user", True): - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + with patch('litellm.allow_dynamic_callback_disabling', True): + with patch('litellm.proxy.proxy_server.premium_user', True): + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is True - def test_default_admin_setting_allows_dynamic_callback_disabling( - self, mock_request_headers - ): + def test_default_admin_setting_allows_dynamic_callback_disabling(self, mock_request_headers): """ Test that when allow_dynamic_callback_disabling is not set, it defaults to True and allows dynamic callback disabling for premium users @@ -405,10 +253,8 @@ class TestEnterpriseCallbackControls: mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"} litellm_params = {"proxy_server_request": {"url": "test"}} standard_callback_dynamic_params = StandardCallbackDynamicParams() - + # litellm.allow_dynamic_callback_disabling should default to True - with patch("litellm.proxy.proxy_server.premium_user", True): - result = EnterpriseCallbackControls.is_callback_disabled_dynamically( - "langfuse", litellm_params, standard_callback_dynamic_params - ) + with patch('litellm.proxy.proxy_server.premium_user', True): + result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) assert result is True diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py index 180041b5f7..6e9c3c0354 100644 --- a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py @@ -16,18 +16,13 @@ from litellm.proxy.openai_files_endpoints.common_utils import ( _is_base64_encoded_unified_file_id, ) + DECODED_UNIFIED_INPUT_FILE_ID = "litellm_proxy:application/octet-stream;unified_id,test-uuid;target_model_names,azure-gpt-4" -B64_UNIFIED_INPUT_FILE_ID = ( - base64.urlsafe_b64encode(DECODED_UNIFIED_INPUT_FILE_ID.encode()) - .decode() - .rstrip("=") -) +B64_UNIFIED_INPUT_FILE_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_INPUT_FILE_ID.encode()).decode().rstrip("=") RAW_INPUT_FILE_ID = "file-raw-provider-abc123" DECODED_UNIFIED_BATCH_ID = "litellm_proxy;model_id:model-xyz;llm_batch_id:batch-123" -B64_UNIFIED_BATCH_ID = ( - base64.urlsafe_b64encode(DECODED_UNIFIED_BATCH_ID.encode()).decode().rstrip("=") -) +B64_UNIFIED_BATCH_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_BATCH_ID.encode()).decode().rstrip("=") @pytest.mark.asyncio @@ -60,16 +55,10 @@ async def test_should_resolve_raw_input_file_id_to_unified(): mock_managed_file.unified_file_id = B64_UNIFIED_INPUT_FILE_ID mock_prisma = MagicMock() - mock_prisma.db.litellm_managedobjecttable.find_first = AsyncMock( - return_value=mock_db_object - ) - mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( - return_value=mock_managed_file - ) + mock_prisma.db.litellm_managedobjecttable.find_first = AsyncMock(return_value=mock_db_object) + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock(return_value=mock_managed_file) - from litellm.proxy.openai_files_endpoints.common_utils import ( - get_batch_from_database, - ) + from litellm.proxy.openai_files_endpoints.common_utils import get_batch_from_database _, response = await get_batch_from_database( batch_id=B64_UNIFIED_BATCH_ID, diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py index c80b8a848c..420f5f9789 100644 --- a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py @@ -93,9 +93,7 @@ async def test_should_preserve_already_managed_input_file_id(): unified_batch_id = "bGl0ZWxsbV9wcm94eTpiYXRjaF9pZA" decoded_unified = "litellm_proxy:application/octet-stream;unified_id,test-123" - base64_input_file_id = ( - base64.urlsafe_b64encode(decoded_unified.encode()).decode().rstrip("=") - ) + base64_input_file_id = base64.urlsafe_b64encode(decoded_unified.encode()).decode().rstrip("=") batch_data = { "id": "batch-raw-123", diff --git a/tests/test_litellm/enterprise/proxy/test_enterprise_routes.py b/tests/test_litellm/enterprise/proxy/test_enterprise_routes.py index 34c8c9c5c5..a9bf33a21a 100644 --- a/tests/test_litellm/enterprise/proxy/test_enterprise_routes.py +++ b/tests/test_litellm/enterprise/proxy/test_enterprise_routes.py @@ -14,77 +14,63 @@ import pytest def test_enterprise_routes_all_imports_exist(): """ Validate that all relative imports in enterprise_routes.py exist in the filesystem. - + This catches any import errors from moved/deleted modules without hardcoding specific module names. Works by checking that imported files actually exist. """ # Path to the enterprise_routes.py source file enterprise_routes_path = os.path.join( os.path.dirname(__file__), - "..", - "..", - "..", - "..", - "enterprise", - "litellm_enterprise", - "proxy", - "enterprise_routes.py", + "..", "..", "..", "..", + "enterprise", "litellm_enterprise", "proxy", "enterprise_routes.py" ) - + enterprise_routes_path = os.path.normpath(enterprise_routes_path) enterprise_proxy_dir = os.path.dirname(enterprise_routes_path) - + if not os.path.exists(enterprise_routes_path): pytest.skip(f"Enterprise routes file not found at {enterprise_routes_path}") - + # Read and parse the source file with open(enterprise_routes_path, "r") as f: source_code = f.read() - + try: tree = ast.parse(source_code) except SyntaxError as e: pytest.fail(f"Syntax error in enterprise_routes.py: {e}") - + # Check all relative imports missing_imports = [] - + for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): # level > 0 means it's a relative import (. or .. etc) if node.level and node.level > 0: module = node.module or "" - + # Convert relative import to file path # e.g., "audit_logging_endpoints" -> "audit_logging_endpoints.py" # e.g., "vector_stores.endpoints" -> "vector_stores/endpoints.py" module_path = module.replace(".", os.sep) if module else "" - + # Check both .py file and package directory - file_path = ( - os.path.join(enterprise_proxy_dir, module_path + ".py") - if module_path - else None - ) - package_path = ( - os.path.join(enterprise_proxy_dir, module_path, "__init__.py") - if module_path - else None - ) - + file_path = os.path.join(enterprise_proxy_dir, module_path + ".py") if module_path else None + package_path = os.path.join(enterprise_proxy_dir, module_path, "__init__.py") if module_path else None + # If module is empty (e.g., "from . import something"), skip check if not module: continue - + file_exists = file_path and os.path.exists(file_path) package_exists = package_path and os.path.exists(package_path) - + if not file_exists and not package_exists: missing_imports.append( f"Line {node.lineno}: Cannot find '.{module}' " f"(checked: {file_path} and {package_path})" ) - + if missing_imports: error_msg = "Found imports in enterprise_routes.py that don't exist:\n" error_msg += "\n".join(missing_imports) diff --git a/tests/test_litellm/enterprise/proxy/test_file_deletion_blocking.py b/tests/test_litellm/enterprise/proxy/test_file_deletion_blocking.py index 3c7aace7d3..852077dcf0 100644 --- a/tests/test_litellm/enterprise/proxy/test_file_deletion_blocking.py +++ b/tests/test_litellm/enterprise/proxy/test_file_deletion_blocking.py @@ -61,7 +61,7 @@ def _make_managed_files_instance_with_batches( ): """ Create a _PROXY_LiteLLMManagedFiles instance with mocked DB and batches. - + Args: file_id: The unified file ID batches: List of batch records to return from DB @@ -79,7 +79,7 @@ def _make_managed_files_instance_with_batches( # Mock prisma mock_prisma = MagicMock() - + # Mock file table queries mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( return_value=mock_file_record @@ -87,7 +87,7 @@ def _make_managed_files_instance_with_batches( mock_prisma.db.litellm_managedfiletable.delete = AsyncMock( return_value=mock_file_record ) - + # Mock batch/object table queries mock_prisma.db.litellm_managedobjecttable.find_many = AsyncMock( return_value=batches @@ -95,13 +95,11 @@ def _make_managed_files_instance_with_batches( # Mock cache mock_cache = MagicMock() - mock_cache.async_get_cache = AsyncMock( - return_value={ - "unified_file_id": file_id, - "model_mappings": {"model-123": "provider-file-abc"}, - "flat_model_file_ids": ["provider-file-abc"], - } - ) + mock_cache.async_get_cache = AsyncMock(return_value={ + "unified_file_id": file_id, + "model_mappings": {"model-123": "provider-file-abc"}, + "flat_model_file_ids": ["provider-file-abc"], + }) mock_cache.async_set_cache = AsyncMock() instance = _PROXY_LiteLLMManagedFiles( @@ -119,17 +117,17 @@ def test_is_batch_polling_enabled_when_job_registered(): from litellm_enterprise.proxy.hooks.managed_files import ( _PROXY_LiteLLMManagedFiles, ) - + instance = _PROXY_LiteLLMManagedFiles( internal_usage_cache=MagicMock(), prisma_client=MagicMock(), ) - + # Mock scheduler with registered job mock_scheduler = MagicMock() mock_job = MagicMock() mock_scheduler.get_job.return_value = mock_job - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): assert instance._is_batch_polling_enabled() is True @@ -139,16 +137,16 @@ def test_is_batch_polling_disabled_when_job_not_registered(): from litellm_enterprise.proxy.hooks.managed_files import ( _PROXY_LiteLLMManagedFiles, ) - + instance = _PROXY_LiteLLMManagedFiles( internal_usage_cache=MagicMock(), prisma_client=MagicMock(), ) - + # Mock scheduler without registered job mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = None - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): assert instance._is_batch_polling_enabled() is False @@ -158,12 +156,12 @@ def test_is_batch_polling_disabled_when_no_scheduler(): from litellm_enterprise.proxy.hooks.managed_files import ( _PROXY_LiteLLMManagedFiles, ) - + instance = _PROXY_LiteLLMManagedFiles( internal_usage_cache=MagicMock(), prisma_client=MagicMock(), ) - + with patch("litellm.proxy.proxy_server.scheduler", None): assert instance._is_batch_polling_enabled() is False @@ -176,28 +174,26 @@ async def test_get_batches_referencing_file_finds_batch_with_input_file(): """Test finding a batch that references the file as input_file_id.""" unified_file_id = _make_unified_file_id("file-input-123") unified_batch_id = _make_unified_batch_id("batch-123") - + batch_file_object = { "id": "batch-123", "input_file_id": unified_file_id, # Batch references this file "status": "validating", } - + batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="validating", file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch_record], ) - - referencing_batches = await managed_files._get_batches_referencing_file( - unified_file_id - ) - + + referencing_batches = await managed_files._get_batches_referencing_file(unified_file_id) + assert len(referencing_batches) == 1 assert referencing_batches[0]["batch_id"] == unified_batch_id assert referencing_batches[0]["status"] == "validating" @@ -208,29 +204,27 @@ async def test_get_batches_referencing_file_finds_batch_with_output_file(): """Test finding a batch that references the file as output_file_id.""" unified_file_id = _make_unified_file_id("file-output-456") unified_batch_id = _make_unified_batch_id("batch-456") - + batch_file_object = { "id": "batch-456", "input_file_id": "file-input-different", "output_file_id": unified_file_id, # Batch references this file "status": "in_progress", } - + batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="in_progress", file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch_record], ) - - referencing_batches = await managed_files._get_batches_referencing_file( - unified_file_id - ) - + + referencing_batches = await managed_files._get_batches_referencing_file(unified_file_id) + assert len(referencing_batches) == 1 assert referencing_batches[0]["status"] == "in_progress" @@ -240,29 +234,27 @@ async def test_get_batches_referencing_file_ignores_terminal_batches(): """Test that batches in terminal states are not returned.""" unified_file_id = _make_unified_file_id("file-123") unified_batch_id = _make_unified_batch_id("batch-completed") - + batch_file_object = { "id": "batch-completed", "input_file_id": unified_file_id, "status": "completed", } - + # Batch is in terminal state in DB batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="completed", # Terminal state file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[], # Query returns no batches (terminal states filtered out) ) - - referencing_batches = await managed_files._get_batches_referencing_file( - unified_file_id - ) - + + referencing_batches = await managed_files._get_batches_referencing_file(unified_file_id) + assert len(referencing_batches) == 0 @@ -270,36 +262,26 @@ async def test_get_batches_referencing_file_ignores_terminal_batches(): async def test_get_batches_referencing_file_finds_multiple_batches(): """Test finding multiple batches referencing the same file.""" unified_file_id = _make_unified_file_id("file-shared") - + batch1 = _make_batch_db_record( unified_object_id=_make_unified_batch_id("batch-1"), status="validating", - file_object={ - "id": "batch-1", - "input_file_id": unified_file_id, - "status": "validating", - }, + file_object={"id": "batch-1", "input_file_id": unified_file_id, "status": "validating"}, ) - + batch2 = _make_batch_db_record( unified_object_id=_make_unified_batch_id("batch-2"), status="in_progress", - file_object={ - "id": "batch-2", - "input_file_id": unified_file_id, - "status": "in_progress", - }, + file_object={"id": "batch-2", "input_file_id": unified_file_id, "status": "in_progress"}, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch1, batch2], ) - - referencing_batches = await managed_files._get_batches_referencing_file( - unified_file_id - ) - + + referencing_batches = await managed_files._get_batches_referencing_file(unified_file_id) + assert len(referencing_batches) == 2 statuses = [b["status"] for b in referencing_batches] assert "validating" in statuses @@ -318,32 +300,32 @@ async def test_file_deletion_blocked_when_batch_polling_enabled_and_batch_refere """ unified_file_id = _make_unified_file_id("file-to-delete") unified_batch_id = _make_unified_batch_id("batch-active") - + batch_file_object = { "id": "batch-active", "input_file_id": unified_file_id, "status": "validating", } - + batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="validating", file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch_record], ) - + # Mock scheduler with registered batch cost job mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = MagicMock() # Job exists - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): with pytest.raises(HTTPException) as exc_info: await managed_files._check_file_deletion_allowed(unified_file_id) - + assert exc_info.value.status_code == 400 error_detail = exc_info.value.detail assert "Cannot delete file" in error_detail @@ -360,28 +342,28 @@ async def test_file_deletion_allowed_when_batch_polling_disabled(): """ unified_file_id = _make_unified_file_id("file-to-delete") unified_batch_id = _make_unified_batch_id("batch-active") - + batch_file_object = { "id": "batch-active", "input_file_id": unified_file_id, "status": "validating", } - + batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="validating", file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch_record], ) - + # Mock scheduler without registered job (batch cost tracking disabled) mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = None - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): # Should not raise an exception await managed_files._check_file_deletion_allowed(unified_file_id) @@ -394,16 +376,16 @@ async def test_file_deletion_allowed_when_no_batches_reference_file(): even when batch cost tracking is enabled. """ unified_file_id = _make_unified_file_id("file-to-delete") - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[], # No batches reference this file ) - + # Mock scheduler with registered job (batch cost tracking enabled) mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = MagicMock() - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): # Should not raise an exception await managed_files._check_file_deletion_allowed(unified_file_id) @@ -416,32 +398,32 @@ async def test_afile_delete_calls_check_deletion_allowed(): """ unified_file_id = _make_unified_file_id("file-to-delete") unified_batch_id = _make_unified_batch_id("batch-active") - + batch_file_object = { "id": "batch-active", "input_file_id": unified_file_id, "status": "in_progress", } - + batch_record = _make_batch_db_record( unified_object_id=unified_batch_id, status="in_progress", file_object=batch_file_object, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch_record], ) - + # Mock llm_router mock_router = MagicMock() mock_router.afile_delete = AsyncMock() - + # Mock scheduler with registered job mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = MagicMock() - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): with pytest.raises(HTTPException) as exc_info: await managed_files.afile_delete( @@ -449,7 +431,7 @@ async def test_afile_delete_calls_check_deletion_allowed(): litellm_parent_otel_span=None, llm_router=mock_router, ) - + # Should raise error before calling router delete assert exc_info.value.status_code == 400 mock_router.afile_delete.assert_not_called() @@ -462,7 +444,7 @@ async def test_database_limit_respected(): This is a performance optimization - we only fetch what we need. """ unified_file_id = _make_unified_file_id("file-shared") - + # Create exactly 10 batches (what DB will return with take=10) ten_batches = [] for i in range(10): @@ -472,32 +454,30 @@ async def test_database_limit_respected(): file_object={ "id": f"batch-{i}", "input_file_id": unified_file_id, - "status": "validating", + "status": "validating" }, ) ten_batches.append(batch) - + # Mock will return only 10 batches (as DB would with take=10) managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=ten_batches, ) - - referencing_batches = await managed_files._get_batches_referencing_file( - unified_file_id - ) - + + referencing_batches = await managed_files._get_batches_referencing_file(unified_file_id) + # Should return all 10 that reference the file assert len(referencing_batches) == 10 - + # Verify error message handles "10+" case (since we got exactly 10, might be more in DB) mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = MagicMock() - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): with pytest.raises(HTTPException) as exc_info: await managed_files._check_file_deletion_allowed(unified_file_id) - + error_detail = exc_info.value.detail # When we get exactly 10 matches, show "10+" to indicate there might be more assert "10+ batch(es)" in error_detail @@ -511,40 +491,32 @@ async def test_error_message_includes_batch_details(): unified_file_id = _make_unified_file_id("file-to-delete") batch1_id = _make_unified_batch_id("batch-1") batch2_id = _make_unified_batch_id("batch-2") - + batch1 = _make_batch_db_record( unified_object_id=batch1_id, status="validating", - file_object={ - "id": "batch-1", - "input_file_id": unified_file_id, - "status": "validating", - }, + file_object={"id": "batch-1", "input_file_id": unified_file_id, "status": "validating"}, ) - + batch2 = _make_batch_db_record( unified_object_id=batch2_id, status="in_progress", - file_object={ - "id": "batch-2", - "output_file_id": unified_file_id, - "status": "in_progress", - }, + file_object={"id": "batch-2", "output_file_id": unified_file_id, "status": "in_progress"}, ) - + managed_files = _make_managed_files_instance_with_batches( file_id=unified_file_id, batches=[batch1, batch2], ) - + # Mock scheduler with registered job mock_scheduler = MagicMock() mock_scheduler.get_job.return_value = MagicMock() - + with patch("litellm.proxy.proxy_server.scheduler", mock_scheduler): with pytest.raises(HTTPException) as exc_info: await managed_files._check_file_deletion_allowed(unified_file_id) - + error_detail = exc_info.value.detail assert "2 batch(es)" in error_detail assert "validating" in error_detail diff --git a/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py index d4855d5b13..9526304aff 100644 --- a/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py @@ -110,9 +110,10 @@ async def test_should_pass_credentials_to_afile_retrieve(): mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) - with ( - patch("litellm.afile_retrieve", mock_afile_retrieve), - patch("litellm.proxy.proxy_server.llm_router", mock_router), + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", mock_router ): await managed_files.async_post_call_success_hook( data={}, @@ -127,9 +128,7 @@ async def test_should_pass_credentials_to_afile_retrieve(): f"afile_retrieve must receive api_key from router credentials. " f"Got kwargs: {call_kwargs.kwargs}" ) - assert ( - call_kwargs.kwargs.get("api_base") == "https://my-azure.openai.azure.com/" - ), ( + assert call_kwargs.kwargs.get("api_base") == "https://my-azure.openai.azure.com/", ( f"afile_retrieve must receive api_base from router credentials. " f"Got kwargs: {call_kwargs.kwargs}" ) @@ -151,9 +150,10 @@ async def test_should_fallback_when_no_router(): mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) - with ( - patch("litellm.afile_retrieve", mock_afile_retrieve), - patch("litellm.proxy.proxy_server.llm_router", None), + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", None ): await managed_files.async_post_call_success_hook( data={}, diff --git a/tests/test_litellm/google_genai/test_google_genai_adapter.py b/tests/test_litellm/google_genai/test_google_genai_adapter.py index 90ef89699e..f21564546a 100644 --- a/tests/test_litellm/google_genai/test_google_genai_adapter.py +++ b/tests/test_litellm/google_genai/test_google_genai_adapter.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI generate_content adapter functionality """ - import json import os import sys diff --git a/tests/test_litellm/google_genai/test_google_genai_adapter_fixes.py b/tests/test_litellm/google_genai/test_google_genai_adapter_fixes.py index 8fc3eca9ad..da56b094d9 100644 --- a/tests/test_litellm/google_genai/test_google_genai_adapter_fixes.py +++ b/tests/test_litellm/google_genai/test_google_genai_adapter_fixes.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI adapter fixes """ - import json import os import sys diff --git a/tests/test_litellm/google_genai/test_google_genai_handler.py b/tests/test_litellm/google_genai/test_google_genai_handler.py index da1c0bb611..0dc218d297 100644 --- a/tests/test_litellm/google_genai/test_google_genai_handler.py +++ b/tests/test_litellm/google_genai/test_google_genai_handler.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI generate_content handler functionality """ - import json import os import sys diff --git a/tests/test_litellm/google_genai/test_google_genai_main.py b/tests/test_litellm/google_genai/test_google_genai_main.py index 1801ee10f3..5854e4b55a 100644 --- a/tests/test_litellm/google_genai/test_google_genai_main.py +++ b/tests/test_litellm/google_genai/test_google_genai_main.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI generate_content adapter functionality """ - import json import os import sys diff --git a/tests/test_litellm/google_genai/test_google_genai_transformation.py b/tests/test_litellm/google_genai/test_google_genai_transformation.py index 691e547299..908a68110f 100644 --- a/tests/test_litellm/google_genai/test_google_genai_transformation.py +++ b/tests/test_litellm/google_genai/test_google_genai_transformation.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI transformation logic for generateContent parameters """ - import os import sys diff --git a/tests/test_litellm/integrations/arize/test_arize_otel_coexistence.py b/tests/test_litellm/integrations/arize/test_arize_otel_coexistence.py index 0767b83e8e..fdf56aedbc 100644 --- a/tests/test_litellm/integrations/arize/test_arize_otel_coexistence.py +++ b/tests/test_litellm/integrations/arize/test_arize_otel_coexistence.py @@ -20,6 +20,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/integrations/dotprompt/test_prompt_manager.py b/tests/test_litellm/integrations/dotprompt/test_prompt_manager.py index 78f5259fc5..d849582b3c 100644 --- a/tests/test_litellm/integrations/dotprompt/test_prompt_manager.py +++ b/tests/test_litellm/integrations/dotprompt/test_prompt_manager.py @@ -127,14 +127,16 @@ def test_input_validation(): # Create a temporary directory with a test prompt with tempfile.TemporaryDirectory() as temp_dir: prompt_file = Path(temp_dir) / "test_validation.prompt" - prompt_file.write_text("""--- + prompt_file.write_text( + """--- input: schema: name: string age: integer active: boolean --- -Hello {{name}}, you are {{age}} years old and {'active' if active else 'inactive'}.""") +Hello {{name}}, you are {{age}} years old and {'active' if active else 'inactive'}.""" + ) manager = PromptManager(prompt_directory=str(temp_dir)) @@ -246,14 +248,16 @@ def test_frontmatter_parsing(): with tempfile.TemporaryDirectory() as temp_dir: # Test with frontmatter prompt_with_frontmatter = Path(temp_dir) / "with_frontmatter.prompt" - prompt_with_frontmatter.write_text("""--- + prompt_with_frontmatter.write_text( + """--- model: gpt-4 temperature: 0.8 input: schema: topic: string --- -Write about {{topic}}.""") +Write about {{topic}}.""" + ) # Test without frontmatter prompt_without_frontmatter = Path(temp_dir) / "without_frontmatter.prompt" @@ -366,7 +370,8 @@ def test_prompt_file_to_json_conversion(): # Create a temporary prompt file with frontmatter with tempfile.TemporaryDirectory() as temp_dir: prompt_file = Path(temp_dir) / "test_conversion.prompt" - prompt_file.write_text("""--- + prompt_file.write_text( + """--- model: gpt-4 temperature: 0.7 max_tokens: 200 @@ -379,7 +384,8 @@ output: --- You are an AI assistant. Given the context: {{context}} -Please respond to: {{user_input}}""") +Please respond to: {{user_input}}""" + ) manager = PromptManager() json_data = manager.prompt_file_to_json(prompt_file) diff --git a/tests/test_litellm/integrations/test_custom_guardrail.py b/tests/test_litellm/integrations/test_custom_guardrail.py index ff6c87132a..d09c4ac2c3 100644 --- a/tests/test_litellm/integrations/test_custom_guardrail.py +++ b/tests/test_litellm/integrations/test_custom_guardrail.py @@ -1086,12 +1086,9 @@ class TestCustomGuardrailSpendLogMatchRedaction: ][0]["match"] == "[REDACTED]" ) - assert ( - raw["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][0][ - "match" - ] - == "GG" - ) + assert raw["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][0][ + "match" + ] == "GG" def test_add_standard_logging_redacts_regex_field(self): cg = CustomGuardrail(guardrail_name="test-rail") diff --git a/tests/test_litellm/integrations/websearch_interception/test_websearch_thinking_constraint.py b/tests/test_litellm/integrations/websearch_interception/test_websearch_thinking_constraint.py index 32de9a85cf..a939951c43 100644 --- a/tests/test_litellm/integrations/websearch_interception/test_websearch_thinking_constraint.py +++ b/tests/test_litellm/integrations/websearch_interception/test_websearch_thinking_constraint.py @@ -18,6 +18,7 @@ from litellm.integrations.websearch_interception.handler import ( WebSearchInterceptionLogger, ) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_azure_assistant_cost_tracking.py b/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_azure_assistant_cost_tracking.py index 3b99167c03..e8bf54f7ff 100644 --- a/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_azure_assistant_cost_tracking.py +++ b/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_azure_assistant_cost_tracking.py @@ -3,7 +3,7 @@ Test Azure OpenAI Assistant Features Cost Tracking Tests cost calculation for Azure's new assistant features: - File Search (storage-based pricing) -- Code Interpreter (session-based pricing) +- Code Interpreter (session-based pricing) - Computer Use (token-based pricing) - Vector Store (storage-based pricing) """ diff --git a/tests/test_litellm/litellm_core_utils/test_core_helpers.py b/tests/test_litellm/litellm_core_utils/test_core_helpers.py index fc3308d6e3..b67ea91bb0 100644 --- a/tests/test_litellm/litellm_core_utils/test_core_helpers.py +++ b/tests/test_litellm/litellm_core_utils/test_core_helpers.py @@ -176,11 +176,7 @@ class TestRedactNestedMatchAndRegexKeys: { "sensitiveInformationPolicy": { "piiEntities": [ - { - "type": "NAME", - "match": "secret-name", - "action": "BLOCKED", - } + {"type": "NAME", "match": "secret-name", "action": "BLOCKED"} ] }, "wordPolicy": { @@ -191,22 +187,16 @@ class TestRedactNestedMatchAndRegexKeys: "regex": "should-redact-key-named-regex", } out = redact_nested_match_and_regex_keys(payload) - assert ( - out["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][0][ - "match" - ] - == "[REDACTED]" - ) + assert out["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][0][ + "match" + ] == "[REDACTED]" assert out["assessments"][0]["wordPolicy"]["customWords"][0]["match"] == ( "[REDACTED]" ) assert out["regex"] == "[REDACTED]" - assert ( - payload["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][0][ - "match" - ] - == "secret-name" - ) + assert payload["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"][ + 0 + ]["match"] == "secret-name" def test_passes_through_none_and_str(self): assert redact_nested_match_and_regex_keys(None) is None diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index fe6312e572..49d3c51e34 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -2036,19 +2036,23 @@ async def test_azure_streaming_role_preserved_with_include_usage(sync_mode: bool chunks.append(chunk) # The prompt_filter chunk should be forwarded with choices=[] - assert ( - len(chunks[0].choices) == 0 - ), f"Expected prompt_filter chunk with choices=[], got {len(chunks[0].choices)} choices" + assert len(chunks[0].choices) == 0, ( + f"Expected prompt_filter chunk with choices=[], got {len(chunks[0].choices)} choices" + ) # At least one chunk must have role='assistant' in its delta has_role = any( - len(c.choices) > 0 and getattr(c.choices[0].delta, "role", None) == "assistant" + len(c.choices) > 0 + and getattr(c.choices[0].delta, "role", None) == "assistant" for c in chunks ) assert has_role, ( "No chunk contained role='assistant' in delta (issue #24221). " "Chunk deltas: " - + str([c.choices[0].delta if c.choices else "no choices" for c in chunks]) + + str([ + c.choices[0].delta if c.choices else "no choices" + for c in chunks + ]) ) diff --git a/tests/test_litellm/litellm_core_utils/test_token_counter.py b/tests/test_litellm/litellm_core_utils/test_token_counter.py index fa7c9e9aea..3aa5f01246 100644 --- a/tests/test_litellm/litellm_core_utils/test_token_counter.py +++ b/tests/test_litellm/litellm_core_utils/test_token_counter.py @@ -493,6 +493,7 @@ from unittest.mock import MagicMock, patch from litellm.utils import _select_tokenizer_helper, claude_json_str, encoding + # Clear the cache at module load to ensure clean state _select_tokenizer_helper.cache_clear() diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py index 3fc0dea474..42efde9092 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py @@ -2327,8 +2327,7 @@ class TestAnthropicStreamWrapperToolArgs: def _find_tool_deltas(self, events): return [ - e - for e in events + e for e in events if isinstance(e, dict) and e.get("type") == "content_block_delta" and isinstance(e.get("delta"), dict) @@ -2376,6 +2375,7 @@ class TestAnthropicStreamWrapperToolArgs: assert parsed == {"city": "Tokyo"} + def test_translate_anthropic_tool_choice_none(): """ Regression test for issue #24443. diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_agentic_streaming_iterator.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_agentic_streaming_iterator.py index d88260629e..b9bda07336 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_agentic_streaming_iterator.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_agentic_streaming_iterator.py @@ -22,6 +22,7 @@ from litellm.llms.anthropic.experimental_pass_through.messages.agentic_streaming _parse_sse_events, ) + # --------------------------------------------------------------------------- # Helpers to build SSE byte payloads # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_parallel_tool_calls.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_parallel_tool_calls.py index c2938c5b20..1d25d71938 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_parallel_tool_calls.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_parallel_tool_calls.py @@ -2,6 +2,7 @@ import os import sys from typing import List + sys.path.insert(0, os.path.abspath("../../../../..")) from litellm.llms.anthropic.experimental_pass_through.adapters.streaming_iterator import ( diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_reasoning_auto_summary_messages.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_reasoning_auto_summary_messages.py index dd5d134171..07c0012b04 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_reasoning_auto_summary_messages.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_reasoning_auto_summary_messages.py @@ -31,16 +31,13 @@ def _call_handler_and_capture_optional_params(thinking=None, **extra_kwargs): """ captured = {} - with ( - patch( - "litellm.llms.anthropic.experimental_pass_through.messages.handler." - "base_llm_http_handler" - ) as mock_handler, - patch( - "litellm.llms.anthropic.experimental_pass_through.messages.handler." - "ProviderConfigManager" - ) as mock_pcm, - ): + with patch( + "litellm.llms.anthropic.experimental_pass_through.messages.handler." + "base_llm_http_handler" + ) as mock_handler, patch( + "litellm.llms.anthropic.experimental_pass_through.messages.handler." + "ProviderConfigManager" + ) as mock_pcm: # Make get_provider_anthropic_messages_config return a non-None config # so the handler takes the native Anthropic path mock_pcm.get_provider_anthropic_messages_config.return_value = MagicMock() @@ -122,9 +119,8 @@ class TestReasoningAutoSummaryMessages: def test_env_var_enables_auto_summary(self): """LITELLM_REASONING_AUTO_SUMMARY=true env var enables the feature.""" - with ( - patch.object(litellm, "reasoning_auto_summary", False), - patch.dict(os.environ, {"LITELLM_REASONING_AUTO_SUMMARY": "true"}), + with patch.object(litellm, "reasoning_auto_summary", False), patch.dict( + os.environ, {"LITELLM_REASONING_AUTO_SUMMARY": "true"} ): params = _call_handler_and_capture_optional_params( thinking={"type": "adaptive", "budget_tokens": 5000} diff --git a/tests/test_litellm/llms/azure/test_azure_cost_calculation.py b/tests/test_litellm/llms/azure/test_azure_cost_calculation.py index 89fff01336..53c91032b3 100644 --- a/tests/test_litellm/llms/azure/test_azure_cost_calculation.py +++ b/tests/test_litellm/llms/azure/test_azure_cost_calculation.py @@ -8,6 +8,7 @@ import litellm from litellm.llms.azure.cost_calculation import cost_per_token from litellm.types.utils import Usage + # Register a test model with tier-specific pricing TEST_MODEL = "test-azure-gpt-4.1" TEST_MODEL_COST = { diff --git a/tests/test_litellm/llms/azure_ai/test_azure_ai_cost_calculator.py b/tests/test_litellm/llms/azure_ai/test_azure_ai_cost_calculator.py index b9316fb9af..20260c744f 100644 --- a/tests/test_litellm/llms/azure_ai/test_azure_ai_cost_calculator.py +++ b/tests/test_litellm/llms/azure_ai/test_azure_ai_cost_calculator.py @@ -459,21 +459,18 @@ class TestAzureAIServiceTierCostCalculation: @pytest.fixture(autouse=True) def register_test_model(self): import litellm - - litellm.register_model( - model_cost={ - "test-azure-ai-model": { - "input_cost_per_token": 0.001, - "output_cost_per_token": 0.002, - "input_cost_per_token_priority": 0.01, - "output_cost_per_token_priority": 0.02, - "input_cost_per_token_flex": 0.0005, - "output_cost_per_token_flex": 0.001, - "litellm_provider": "azure_ai", - "max_tokens": 8192, - } + litellm.register_model(model_cost={ + "test-azure-ai-model": { + "input_cost_per_token": 0.001, + "output_cost_per_token": 0.002, + "input_cost_per_token_priority": 0.01, + "output_cost_per_token_priority": 0.02, + "input_cost_per_token_flex": 0.0005, + "output_cost_per_token_flex": 0.001, + "litellm_provider": "azure_ai", + "max_tokens": 8192, } - ) + }) def test_service_tier_priority_higher_cost(self): """Priority tier should cost more than standard for azure_ai.""" diff --git a/tests/test_litellm/llms/base_llm/test_managed_resource_isolation.py b/tests/test_litellm/llms/base_llm/test_managed_resource_isolation.py index 1efd10e921..b5fcd9d821 100644 --- a/tests/test_litellm/llms/base_llm/test_managed_resource_isolation.py +++ b/tests/test_litellm/llms/base_llm/test_managed_resource_isolation.py @@ -10,6 +10,7 @@ from litellm.llms.base_llm.managed_resources.isolation import ( ) from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + # --------------------------------------------------------------------------- # build_owner_filter # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/llms/bedrock/chat/test_invoke_handler.py b/tests/test_litellm/llms/bedrock/chat/test_invoke_handler.py index 100f91c4e5..a415d55021 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_invoke_handler.py +++ b/tests/test_litellm/llms/bedrock/chat/test_invoke_handler.py @@ -1,6 +1,7 @@ import os import sys + sys.path.insert( 0, os.path.abspath("../../../../..") ) # Adds the parent directory to the system path diff --git a/tests/test_litellm/llms/crusoe/test_crusoe.py b/tests/test_litellm/llms/crusoe/test_crusoe.py index 32a2f174f3..0a05126919 100644 --- a/tests/test_litellm/llms/crusoe/test_crusoe.py +++ b/tests/test_litellm/llms/crusoe/test_crusoe.py @@ -39,10 +39,7 @@ def test_crusoe_dynamic_config_env_vars(): with patch.dict( os.environ, - { - "CRUSOE_API_KEY": "test-key", - "CRUSOE_API_BASE": "https://custom.crusoe.com/v1", - }, + {"CRUSOE_API_KEY": "test-key", "CRUSOE_API_BASE": "https://custom.crusoe.com/v1"}, ): api_base, api_key = config._get_openai_compatible_provider_info(None, None) @@ -72,9 +69,7 @@ def test_crusoe_supported_params(): from litellm.llms.openai_like.json_loader import JSONProviderRegistry config = create_config_class(JSONProviderRegistry.get("crusoe"))() - params = config.get_supported_openai_params( - model="meta-llama/Llama-3.3-70B-Instruct" - ) + params = config.get_supported_openai_params(model="meta-llama/Llama-3.3-70B-Instruct") assert isinstance(params, list) assert len(params) > 0 @@ -96,9 +91,7 @@ def test_crusoe_param_mapping_max_completion_tokens(): drop_params=False, ) - assert ( - "max_tokens" in optional_params - ), "max_completion_tokens should be mapped to max_tokens" + assert "max_tokens" in optional_params, "max_completion_tokens should be mapped to max_tokens" assert optional_params["max_tokens"] == 1024 assert "max_completion_tokens" not in optional_params diff --git a/tests/test_litellm/llms/custom_httpx/test_aiohttp_so_keepalive.py b/tests/test_litellm/llms/custom_httpx/test_aiohttp_so_keepalive.py index 660ab88c2b..5515e6ce81 100644 --- a/tests/test_litellm/llms/custom_httpx/test_aiohttp_so_keepalive.py +++ b/tests/test_litellm/llms/custom_httpx/test_aiohttp_so_keepalive.py @@ -158,7 +158,4 @@ def test_socket_factory_uses_tcp_keepalive_when_keepidle_unavailable(monkeypatch assert ( setsockopt_calls[(socket.IPPROTO_TCP, fake_socket_module.TCP_KEEPALIVE)] == 60 ) - assert ( - socket.IPPROTO_TCP, - getattr(socket, "TCP_KEEPIDLE", -1), - ) not in setsockopt_calls + assert (socket.IPPROTO_TCP, getattr(socket, "TCP_KEEPIDLE", -1)) not in setsockopt_calls diff --git a/tests/test_litellm/llms/custom_httpx/test_mock_transport.py b/tests/test_litellm/llms/custom_httpx/test_mock_transport.py index b3b5f82f02..c2d4e14642 100644 --- a/tests/test_litellm/llms/custom_httpx/test_mock_transport.py +++ b/tests/test_litellm/llms/custom_httpx/test_mock_transport.py @@ -10,6 +10,7 @@ import pytest from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport + # --------------------------------------------------------------------------- # Non-streaming # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/llms/databricks/test_databricks_e2e.py b/tests/test_litellm/llms/databricks/test_databricks_e2e.py index 90846f2892..669f9e9463 100644 --- a/tests/test_litellm/llms/databricks/test_databricks_e2e.py +++ b/tests/test_litellm/llms/databricks/test_databricks_e2e.py @@ -17,13 +17,13 @@ Purpose: LiteLLM Integration Tests: This test file includes tests for different ways of calling Databricks via LiteLLM: - + 1. LiteLLM SDK Direct - Using litellm.completion() with user_agent parameter 2. LangChain + LiteLLM - Using ChatLiteLLM wrapper (requires langchain-community) 3. LiteLLM Async - Using litellm.acompletion() async API 4. LiteLLM Streaming - Using litellm.completion() with stream=True 5. LiteLLM Embedding - Using litellm.embedding() with user_agent parameter - + All tests use the CUSTOM_USER_AGENT value from the config file and call Databricks endpoints through LiteLLM's unified interface. @@ -31,7 +31,7 @@ Prerequisites: - Valid Databricks workspace access - Configured credentials (OAuth Service Principal, PAT, or Databricks CLI) - Access to serving endpoints (e.g., databricks-gpt-oss-120b) - + Optional Dependencies (for LiteLLM integration tests): - pip install langchain-litellm # For LangChain tests (recommended) diff --git a/tests/test_litellm/llms/deepgram/audio_transcription/test_deepgram_audio_transcription_transformation.py b/tests/test_litellm/llms/deepgram/audio_transcription/test_deepgram_audio_transcription_transformation.py index a950b4a5f2..d59ab975ef 100644 --- a/tests/test_litellm/llms/deepgram/audio_transcription/test_deepgram_audio_transcription_transformation.py +++ b/tests/test_litellm/llms/deepgram/audio_transcription/test_deepgram_audio_transcription_transformation.py @@ -53,7 +53,7 @@ def test_file(): ) def test_audio_file_handling(fixture_name, request): handler = DeepgramAudioTranscriptionConfig() - audio_file, expected_output = request.getfixturevalue(fixture_name) + (audio_file, expected_output) = request.getfixturevalue(fixture_name) result = handler.transform_audio_transcription_request( model="deepseek-audio-transcription", audio_file=audio_file, diff --git a/tests/test_litellm/llms/github_copilot/test_github_copilot_authenticator.py b/tests/test_litellm/llms/github_copilot/test_github_copilot_authenticator.py index 7c787f20b8..6c846a90c7 100644 --- a/tests/test_litellm/llms/github_copilot/test_github_copilot_authenticator.py +++ b/tests/test_litellm/llms/github_copilot/test_github_copilot_authenticator.py @@ -255,19 +255,12 @@ class TestGitHubCopilotAuthenticator: "user_code": "UC", "verification_uri": "https://example.com", } - with ( - patch.dict(os.environ, {"GITHUB_COPILOT_DEVICE_CODE_URL": custom_url}), - patch( - "litellm.llms.github_copilot.authenticator._get_httpx_client", - return_value=mock_client, - ), - ): + with patch.dict(os.environ, {"GITHUB_COPILOT_DEVICE_CODE_URL": custom_url}), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client): authenticator._get_device_code() assert mock_client.post.call_args[0][0] == custom_url - def test_get_device_code_with_custom_client_id( - self, authenticator, mock_http_client - ): + def test_get_device_code_with_custom_client_id(self, authenticator, mock_http_client): """GITHUB_COPILOT_CLIENT_ID env var must appear as client_id in the device-code request body.""" mock_client, mock_response = mock_http_client custom_id = "custom_client_id" @@ -276,49 +269,30 @@ class TestGitHubCopilotAuthenticator: "user_code": "UC", "verification_uri": "https://example.com", } - with ( - patch.dict(os.environ, {"GITHUB_COPILOT_CLIENT_ID": custom_id}), - patch( - "litellm.llms.github_copilot.authenticator._get_httpx_client", - return_value=mock_client, - ), - ): + with patch.dict(os.environ, {"GITHUB_COPILOT_CLIENT_ID": custom_id}), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client): authenticator._get_device_code() assert mock_client.post.call_args[1]["json"]["client_id"] == custom_id - def test_poll_for_access_token_with_custom_url( - self, authenticator, mock_http_client - ): + def test_poll_for_access_token_with_custom_url(self, authenticator, mock_http_client): """GITHUB_COPILOT_ACCESS_TOKEN_URL env var must be used by _poll_for_access_token at call time.""" mock_client, mock_response = mock_http_client custom_url = "https://custom.example.com/token" mock_response.json.return_value = {"access_token": "tok"} - with ( - patch.dict(os.environ, {"GITHUB_COPILOT_ACCESS_TOKEN_URL": custom_url}), - patch( - "litellm.llms.github_copilot.authenticator._get_httpx_client", - return_value=mock_client, - ), - patch("time.sleep"), - ): + with patch.dict(os.environ, {"GITHUB_COPILOT_ACCESS_TOKEN_URL": custom_url}), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch("time.sleep"): authenticator._poll_for_access_token("dc") assert mock_client.post.call_args[0][0] == custom_url - def test_poll_for_access_token_with_custom_client_id( - self, authenticator, mock_http_client - ): + def test_poll_for_access_token_with_custom_client_id(self, authenticator, mock_http_client): """GITHUB_COPILOT_CLIENT_ID env var must appear as client_id in the polling request body.""" mock_client, mock_response = mock_http_client custom_id = "custom_client_id" mock_response.json.return_value = {"access_token": "tok"} - with ( - patch.dict(os.environ, {"GITHUB_COPILOT_CLIENT_ID": custom_id}), - patch( - "litellm.llms.github_copilot.authenticator._get_httpx_client", - return_value=mock_client, - ), - patch("time.sleep"), - ): + with patch.dict(os.environ, {"GITHUB_COPILOT_CLIENT_ID": custom_id}), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch("time.sleep"): authenticator._poll_for_access_token("dc") assert mock_client.post.call_args[1]["json"]["client_id"] == custom_id @@ -327,13 +301,9 @@ class TestGitHubCopilotAuthenticator: mock_client, mock_response = mock_http_client custom_url = "https://custom.example.com/api-key" mock_response.json.return_value = {"token": "api-tok", "expires_at": 9999999999} - with ( - patch.dict(os.environ, {"GITHUB_COPILOT_API_KEY_URL": custom_url}), - patch( - "litellm.llms.github_copilot.authenticator._get_httpx_client", - return_value=mock_client, - ), - patch.object(authenticator, "get_access_token", return_value="access-tok"), - ): + with patch.dict(os.environ, {"GITHUB_COPILOT_API_KEY_URL": custom_url}), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(authenticator, "get_access_token", return_value="access-tok"): authenticator._refresh_api_key() assert mock_client.get.call_args[0][0] == custom_url + diff --git a/tests/test_litellm/llms/huggingface/embedding/test_huggingface_embedding_handler.py b/tests/test_litellm/llms/huggingface/embedding/test_huggingface_embedding_handler.py index e6acc51a58..560796ea58 100644 --- a/tests/test_litellm/llms/huggingface/embedding/test_huggingface_embedding_handler.py +++ b/tests/test_litellm/llms/huggingface/embedding/test_huggingface_embedding_handler.py @@ -11,6 +11,7 @@ sys.path.insert( import litellm import pytest + MOCK_EMBEDDING_RESPONSE = [[0.1, 0.2, 0.3, 0.4, 0.5]] diff --git a/tests/test_litellm/llms/moonshot/test_moonshot_chat_transformation.py b/tests/test_litellm/llms/moonshot/test_moonshot_chat_transformation.py index e398699bfd..b4744a7ed1 100644 --- a/tests/test_litellm/llms/moonshot/test_moonshot_chat_transformation.py +++ b/tests/test_litellm/llms/moonshot/test_moonshot_chat_transformation.py @@ -666,9 +666,7 @@ class TestKimiK26ModelRegistry: def test_kimi_k26_in_model_cost_map(self, model_cost_map): """kimi-k2.6 should be present in the model cost map.""" - assert ( - "moonshot/kimi-k2.6" in model_cost_map - ), "moonshot/kimi-k2.6 not found in model_cost" + assert "moonshot/kimi-k2.6" in model_cost_map, "moonshot/kimi-k2.6 not found in model_cost" def test_kimi_k26_pricing(self, model_cost_map): """kimi-k2.6 pricing should match official Kimi API rates.""" diff --git a/tests/test_litellm/llms/perplexity/test_perplexity_cost_calculator.py b/tests/test_litellm/llms/perplexity/test_perplexity_cost_calculator.py index c6971cfd55..d408f55c00 100644 --- a/tests/test_litellm/llms/perplexity/test_perplexity_cost_calculator.py +++ b/tests/test_litellm/llms/perplexity/test_perplexity_cost_calculator.py @@ -1,7 +1,7 @@ """ Test file for Perplexity cost calculator functionality. -Tests the cost calculation for Perplexity models including citation tokens, +Tests the cost calculation for Perplexity models including citation tokens, search queries, and reasoning tokens. """ diff --git a/tests/test_litellm/llms/perplexity/test_perplexity_integration.py b/tests/test_litellm/llms/perplexity/test_perplexity_integration.py index cd76fcc93c..1b03fd7df8 100644 --- a/tests/test_litellm/llms/perplexity/test_perplexity_integration.py +++ b/tests/test_litellm/llms/perplexity/test_perplexity_integration.py @@ -1,7 +1,7 @@ """ Integration tests for Perplexity cost calculation and transformation. -Tests the end-to-end functionality of Perplexity cost calculation +Tests the end-to-end functionality of Perplexity cost calculation including integration with the main LiteLLM cost calculator. """ diff --git a/tests/test_litellm/llms/scaleway/test_scaleway_audio_transcription_transformation.py b/tests/test_litellm/llms/scaleway/test_scaleway_audio_transcription_transformation.py index c237e28081..407e1d19fb 100644 --- a/tests/test_litellm/llms/scaleway/test_scaleway_audio_transcription_transformation.py +++ b/tests/test_litellm/llms/scaleway/test_scaleway_audio_transcription_transformation.py @@ -10,6 +10,7 @@ from litellm.llms.scaleway.audio_transcription.transformation import ( ) from litellm.types.utils import TranscriptionResponse + # --------------------------------------------------------------------------- # get_complete_url # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/llms/test_polling_url_origin_match.py b/tests/test_litellm/llms/test_polling_url_origin_match.py index b6b932ad93..f1f910bc73 100644 --- a/tests/test_litellm/llms/test_polling_url_origin_match.py +++ b/tests/test_litellm/llms/test_polling_url_origin_match.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest + # Azure DALL-E sync + async paths route through ``assert_same_origin`` # the same way as the cases below. The helper itself is unit-tested in # ``tests/test_litellm/litellm_core_utils/test_url_utils.py``; the diff --git a/tests/test_litellm/llms/test_predibase_transformation.py b/tests/test_litellm/llms/test_predibase_transformation.py index 7d67996906..1600878a58 100644 --- a/tests/test_litellm/llms/test_predibase_transformation.py +++ b/tests/test_litellm/llms/test_predibase_transformation.py @@ -233,9 +233,7 @@ def test_predibase_transform_response_non_dict_payload(): raw_response.headers = {} raw_response.json.return_value = [] - with pytest.raises( - PredibaseError, match="'completion_response' is not a dictionary" - ): + with pytest.raises(PredibaseError, match="'completion_response' is not a dictionary"): config.transform_response( model="predibase-model", raw_response=raw_response, @@ -377,9 +375,7 @@ def test_predibase_transform_response_best_of_invalid_value_falls_back(monkeypat assert result.choices[0].message.content == "primary-output" -def test_predibase_transform_response_empty_output_sets_completion_tokens_zero( - monkeypatch, -): +def test_predibase_transform_response_empty_output_sets_completion_tokens_zero(monkeypatch): config = PredibaseConfig() logging_obj = Mock() encoding = Mock() @@ -433,10 +429,7 @@ def test_predibase_transform_response_usage_fallbacks(monkeypatch): raw_response = httpx.Response( status_code=200, - json={ - "generated_text": "ok", - "details": {"tokens": [], "finish_reason": "stop"}, - }, + json={"generated_text": "ok", "details": {"tokens": [], "finish_reason": "stop"}}, ) result = config.transform_response( @@ -546,17 +539,13 @@ def test_predibase_completion_sync_returns_transform_response(monkeypatch): def fake_transform_response(self, **kwargs): return expected - monkeypatch.setattr( - PredibaseConfig, "validate_environment", fake_validate_environment - ) + monkeypatch.setattr(PredibaseConfig, "validate_environment", fake_validate_environment) monkeypatch.setattr(PredibaseConfig, "get_complete_url", fake_get_complete_url) monkeypatch.setattr(PredibaseConfig, "transform_request", fake_transform_request) monkeypatch.setattr(PredibaseConfig, "transform_response", fake_transform_response) monkeypatch.setattr( "litellm.module_level_client.post", - lambda *args, **kwargs: httpx.Response( - status_code=200, json={"generated_text": "ok"} - ), + lambda *args, **kwargs: httpx.Response(status_code=200, json={"generated_text": "ok"}), ) result = handler.completion( @@ -597,9 +586,7 @@ def test_predibase_completion_passes_existing_config_to_async_completion(monkeyp captured["async_kwargs"] = kwargs return "async-result" - monkeypatch.setattr( - PredibaseConfig, "validate_environment", fake_validate_environment - ) + monkeypatch.setattr(PredibaseConfig, "validate_environment", fake_validate_environment) monkeypatch.setattr(PredibaseConfig, "get_complete_url", fake_get_complete_url) monkeypatch.setattr(PredibaseConfig, "transform_request", fake_transform_request) monkeypatch.setattr(handler, "async_completion", fake_async_completion) diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py b/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py index ef084dcc00..6d913ad5d1 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py @@ -22,6 +22,7 @@ from litellm.llms.vertex_ai.gemini.transformation import ( ) from litellm.types.llms.vertex_ai import HttpxPartType + # --- Response extraction tests --- diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py index fe97727bb8..353d19b019 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py @@ -3499,12 +3499,7 @@ def test_video_metadata_supported_for_all_gemini_models(): } ] - for model in [ - "gemini-1.5-pro", - "gemini-2.5-flash", - "gemini-2.5-pro", - "gemini-3-pro-preview", - ]: + for model in ["gemini-1.5-pro", "gemini-2.5-flash", "gemini-2.5-pro", "gemini-3-pro-preview"]: contents = _gemini_convert_messages_with_history(messages=messages, model=model) file_part = None @@ -3514,25 +3509,19 @@ def test_video_metadata_supported_for_all_gemini_models(): break assert file_part is not None, f"{model}: file part should exist" - assert ( - "video_metadata" in file_part - ), f"{model}: video_metadata should be present" + assert "video_metadata" in file_part, f"{model}: video_metadata should be present" assert file_part["video_metadata"]["fps"] == 5, f"{model}: fps should be 5" # Per-part media_resolution is Gemini 3+ only; 2.x uses generation_config global for model in ["gemini-3-pro-preview"]: contents = _gemini_convert_messages_with_history(messages=messages, model=model) file_part = next(p for p in contents[0]["parts"] if "file_data" in p) - assert ( - "media_resolution" in file_part - ), f"{model}: media_resolution should be present" + assert "media_resolution" in file_part, f"{model}: media_resolution should be present" for model in ["gemini-1.5-pro", "gemini-2.5-flash", "gemini-2.5-pro"]: contents = _gemini_convert_messages_with_history(messages=messages, model=model) file_part = next(p for p in contents[0]["parts"] if "file_data" in p) - assert ( - "media_resolution" not in file_part - ), f"{model}: per-part media_resolution should not be set" + assert "media_resolution" not in file_part, f"{model}: per-part media_resolution should not be set" def test_chunk_parser_handles_prompt_feedback_block(): @@ -4165,9 +4154,8 @@ def test_vertex_ai_usage_metadata_with_document_tokens_in_prompt(): # DOCUMENT tokens should be included in text_tokens: 8 (TEXT) + 774 (DOCUMENT) = 782 assert result.prompt_tokens_details is not None - assert ( - result.prompt_tokens_details.text_tokens == 782 - ), "DOCUMENT modality tokens should be added to text_tokens (8 TEXT + 774 DOCUMENT = 782)" + assert result.prompt_tokens_details.text_tokens == 782, \ + "DOCUMENT modality tokens should be added to text_tokens (8 TEXT + 774 DOCUMENT = 782)" # Verify completion token details assert result.completion_tokens_details is not None @@ -4202,9 +4190,8 @@ def test_vertex_ai_usage_metadata_with_document_tokens_cached(): # DOCUMENT cached tokens map to cached_text_tokens, so: # text_tokens = (8 TEXT + 774 DOCUMENT) - 400 cached = 382 - assert ( - result.prompt_tokens_details.text_tokens == 382 - ), "text_tokens should be (8 + 774) - 400 cached = 382" + assert result.prompt_tokens_details.text_tokens == 382, \ + "text_tokens should be (8 + 774) - 400 cached = 382" assert result.prompt_tokens_details.cached_tokens == 400 diff --git a/tests/test_litellm/llms/vertex_ai/gemini_embeddings/test_batch_embed_content_transformation.py b/tests/test_litellm/llms/vertex_ai/gemini_embeddings/test_batch_embed_content_transformation.py index 4bd0b93027..bb4e6c67e9 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini_embeddings/test_batch_embed_content_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/gemini_embeddings/test_batch_embed_content_transformation.py @@ -20,6 +20,7 @@ from litellm.llms.vertex_ai.gemini_embeddings.batch_embed_content_transformation from litellm.types.llms.vertex_ai import VertexAIBatchEmbeddingsResponseObject from litellm.types.utils import EmbeddingResponse + IMAGE_DATA_URI = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII" GCS_URL = "gs://my-bucket/image.png" @@ -71,9 +72,7 @@ class TestBuildPartForInput: assert part["file_data"]["file_uri"] == GCS_URL def test_file_reference_resolved(self): - resolved = { - "files/abc": {"mime_type": "image/jpeg", "uri": "https://example.com/abc"} - } + resolved = {"files/abc": {"mime_type": "image/jpeg", "uri": "https://example.com/abc"}} part = _build_part_for_input("files/abc", resolved_files=resolved) assert part["file_data"] is not None assert part["file_data"]["mime_type"] == "image/jpeg" @@ -95,9 +94,7 @@ class TestTransformOpenaiInputGeminiContent: def test_multiple_texts(self): result = transform_openai_input_gemini_content( - input=["hello", "world"], - model="gemini-embedding-2-preview", - optional_params={}, + input=["hello", "world"], model="gemini-embedding-2-preview", optional_params={} ) assert len(result["requests"]) == 2 assert result["requests"][0]["content"]["parts"][0]["text"] == "hello" @@ -112,10 +109,7 @@ class TestTransformOpenaiInputGeminiContent: ) assert len(result["requests"]) == 2 # First request is text - assert ( - result["requests"][0]["content"]["parts"][0]["text"] - == "The food was delicious" - ) + assert result["requests"][0]["content"]["parts"][0]["text"] == "The food was delicious" # Second request is image assert result["requests"][1]["content"]["parts"][0]["inline_data"] is not None diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py index b47f8c6af4..6c549af2cc 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py @@ -1469,7 +1469,11 @@ def test_vertex_request_labels_from_litellm_params_extracts_requester_metadata() def test_vertex_request_labels_from_litellm_params_accepts_litellm_metadata(): - lp = {"litellm_metadata": {"requester_metadata": {"team": "platform", "count": 3}}} + lp = { + "litellm_metadata": { + "requester_metadata": {"team": "platform", "count": 3} + } + } assert vertex_request_labels_from_litellm_params(lp) == {"team": "platform"} diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_model_garden_openapi.py b/tests/test_litellm/llms/vertex_ai/test_vertex_model_garden_openapi.py index 80fe002176..91261b6325 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_model_garden_openapi.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_model_garden_openapi.py @@ -37,8 +37,5 @@ def test_create_vertex_url_openapi_vs_deployed_endpoint( def test_model_id_in_json_body_heuristic() -> None: - assert ( - _vertex_model_garden_model_id_in_json_body("xai/grok-4.1-fast-reasoning") - is True - ) + assert _vertex_model_garden_model_id_in_json_body("xai/grok-4.1-fast-reasoning") is True assert _vertex_model_garden_model_id_in_json_body("5464397967697903616") is False diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index 3a5ea4e88c..6e0dadcd4d 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -880,7 +880,7 @@ class TestMCPPublicRouteGuard: with patch( "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", ) as mock_auth: - auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) mock_auth.assert_not_called() assert isinstance(auth_result, UserAPIKeyAuth) @@ -997,7 +997,7 @@ class TestMCPOAuth2FallbackTargetGating: mock_mgr.get_mcp_server_by_name.return_value = ( TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2) ) - auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) assert isinstance(auth_result, UserAPIKeyAuth) async def test_fallback_blocked_when_any_target_in_header_is_not_oauth2(self): diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py index 37e593e2a0..078adf72d4 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py @@ -25,6 +25,7 @@ from litellm.proxy._experimental.mcp_server.db import ( ) from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper + SALT_KEY = "test-salt-key-for-byok-credential-tests-1234" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_is_tool_name_prefixed.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_is_tool_name_prefixed.py index 1084e64eab..8f09e2410c 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_is_tool_name_prefixed.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_is_tool_name_prefixed.py @@ -8,6 +8,7 @@ import pytest from litellm.proxy._experimental.mcp_server.utils import is_tool_name_prefixed + # --------------------------------------------------------------------------- # Legacy behaviour (no known_server_prefixes passed) # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py index 4cc764f3c6..a0bfbff422 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py @@ -497,39 +497,31 @@ async def test_per_user_oauth_missing_stored_token_returns_preemptive_401(): oauth_server.auth_type = MCPAuth.oauth2 oauth_server.needs_user_oauth_token = True - with ( - patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - new_callable=AsyncMock, - return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), - ), - patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", - ), - patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), - patch( - "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", - new_callable=AsyncMock, - return_value=False, - ), - patch( - "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", - new_callable=AsyncMock, - return_value=None, - ) as mock_get_stored_token, - patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", - return_value=oauth_server, - ), - patch.object( - session_manager, - "handle_request", - new_callable=AsyncMock, - ) as mock_handle_request, - ): + with patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new_callable=AsyncMock, + return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), + ), patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), patch( + "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", + new_callable=AsyncMock, + return_value=False, + ), patch( + "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", + new_callable=AsyncMock, + return_value=None, + ) as mock_get_stored_token, patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", + return_value=oauth_server, + ), patch.object( + session_manager, + "handle_request", + new_callable=AsyncMock, + ) as mock_handle_request: with pytest.raises(HTTPException) as exc_info: await handle_streamable_http_mcp(scope, receive, send) @@ -570,39 +562,31 @@ async def test_per_user_oauth_with_stored_token_skips_preemptive_401(): oauth_server.auth_type = MCPAuth.oauth2 oauth_server.needs_user_oauth_token = True - with ( - patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - new_callable=AsyncMock, - return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), - ), - patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", - ), - patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), - patch( - "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", - new_callable=AsyncMock, - return_value=False, - ), - patch( - "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", - new_callable=AsyncMock, - return_value={"Authorization": "Bearer cached-token"}, - ) as mock_get_stored_token, - patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", - return_value=oauth_server, - ), - patch.object( - session_manager, - "handle_request", - new_callable=AsyncMock, - ) as mock_handle_request, - ): + with patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new_callable=AsyncMock, + return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), + ), patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), patch( + "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", + new_callable=AsyncMock, + return_value=False, + ), patch( + "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", + new_callable=AsyncMock, + return_value={"Authorization": "Bearer cached-token"}, + ) as mock_get_stored_token, patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", + return_value=oauth_server, + ), patch.object( + session_manager, + "handle_request", + new_callable=AsyncMock, + ) as mock_handle_request: await handle_streamable_http_mcp(scope, receive, send) assert mock_get_stored_token.await_count == 1 diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py index 2888077a23..2558df8533 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py @@ -489,7 +489,9 @@ class TestGetToolsByNames: {"name": "send_email", "description": "send mail"}, ] - matched = filter_instance._get_tools_by_names(["send_email"], available_tools) + matched = filter_instance._get_tools_by_names( + ["send_email"], available_tools + ) assert len(matched) == 1 assert matched[0]["name"] == "send_email" @@ -501,7 +503,9 @@ class TestGetToolsByNames: client_name = "litellm_" + canonical available_tools = [{"name": client_name, "description": "scrape"}] - matched = filter_instance._get_tools_by_names([canonical], available_tools) + matched = filter_instance._get_tools_by_names( + [canonical], available_tools + ) assert len(matched) == 1 # Must return the incoming tool unchanged so the client-facing @@ -512,9 +516,13 @@ class TestGetToolsByNames: """Some clients use dash as alias separator; accept that too.""" filter_instance = self._make_filter() canonical = "weather_svc-get_weather" - available_tools = [{"name": "mcp-" + canonical, "description": "weather"}] + available_tools = [ + {"name": "mcp-" + canonical, "description": "weather"} + ] - matched = filter_instance._get_tools_by_names([canonical], available_tools) + matched = filter_instance._get_tools_by_names( + [canonical], available_tools + ) assert len(matched) == 1 assert matched[0]["name"] == "mcp-" + canonical @@ -544,7 +552,9 @@ class TestGetToolsByNames: {"name": "litellm_" + canonical, "description": "wrapped"}, ] - matched = filter_instance._get_tools_by_names([canonical], available_tools) + matched = filter_instance._get_tools_by_names( + [canonical], available_tools + ) assert len(matched) == 1 assert matched[0]["name"] == canonical @@ -557,7 +567,9 @@ class TestGetToolsByNames: separator-anchored suffixes of ``litellm_api-fs-read_file``. """ filter_instance = self._make_filter() - available_tools = [{"name": "litellm_api-fs-read_file", "description": "read"}] + available_tools = [ + {"name": "litellm_api-fs-read_file", "description": "read"} + ] matched = filter_instance._get_tools_by_names( ["fs-read_file", "api-fs-read_file"], available_tools @@ -578,7 +590,9 @@ class TestGetToolsByNames: {"name": "my_" + canonical, "description": "plain search"}, ] - matched = filter_instance._get_tools_by_names([canonical], available_tools) + matched = filter_instance._get_tools_by_names( + [canonical], available_tools + ) assert len(matched) == 1 assert matched[0]["name"] == "my_" + canonical diff --git a/tests/test_litellm/proxy/agent_endpoints/test_agent_header_isolation.py b/tests/test_litellm/proxy/agent_endpoints/test_agent_header_isolation.py index 06da009134..554f98d720 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_agent_header_isolation.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_agent_header_isolation.py @@ -16,6 +16,7 @@ import pytest from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/agent_endpoints/test_agent_headers.py b/tests/test_litellm/proxy/agent_endpoints/test_agent_headers.py index 7cbd12a4ea..93ba9dc922 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_agent_headers.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_agent_headers.py @@ -14,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest + # --------------------------------------------------------------------------- # Helper: build a minimal mock agent # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/auth/test_onboarding.py b/tests/test_litellm/proxy/auth/test_onboarding.py index d55a5472af..c81f4cb7d6 100644 --- a/tests/test_litellm/proxy/auth/test_onboarding.py +++ b/tests/test_litellm/proxy/auth/test_onboarding.py @@ -18,6 +18,7 @@ from fastapi import HTTPException import litellm from litellm.proxy._types import InvitationClaim + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/common_utils/test_cache_codec.py b/tests/test_litellm/proxy/common_utils/test_cache_codec.py index 8aa16ddbce..044d4c2d1a 100644 --- a/tests/test_litellm/proxy/common_utils/test_cache_codec.py +++ b/tests/test_litellm/proxy/common_utils/test_cache_codec.py @@ -54,9 +54,7 @@ class TestCacheCodecSerialize: def test_with_model_type_already_correct_instance_skips_revalidation(self): """Fast-path: value is already model_type — model_validate must NOT be called.""" m = _SampleModel(name="fast", count=7) - with patch.object( - _SampleModel, "model_validate", wraps=_SampleModel.model_validate - ) as mock_validate: + with patch.object(_SampleModel, "model_validate", wraps=_SampleModel.model_validate) as mock_validate: out = CacheCodec.serialize(m, model_type=_SampleModel) assert out == {"name": "fast", "count": 7} mock_validate.assert_not_called() @@ -64,9 +62,7 @@ class TestCacheCodecSerialize: def test_with_model_type_subclass_instance_skips_revalidation(self): """Subclass is isinstance of base → should also take the fast path.""" sub = _SampleSubModel(name="sub", count=2) - with patch.object( - _SampleModel, "model_validate", wraps=_SampleModel.model_validate - ) as mock_validate: + with patch.object(_SampleModel, "model_validate", wraps=_SampleModel.model_validate) as mock_validate: out = CacheCodec.serialize(sub, model_type=_SampleModel) assert out == {"name": "sub", "count": 2} mock_validate.assert_not_called() diff --git a/tests/test_litellm/proxy/conftest.py b/tests/test_litellm/proxy/conftest.py index d01c47d975..20236ebdf4 100644 --- a/tests/test_litellm/proxy/conftest.py +++ b/tests/test_litellm/proxy/conftest.py @@ -14,6 +14,7 @@ import pytest import yaml from fastapi.testclient import TestClient + _PROXY_MODULE_GLOBALS_TO_ISOLATE = ( "master_key", "prisma_client", diff --git a/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py b/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py index 3586c9a80e..434f7953c2 100644 --- a/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py +++ b/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py @@ -2,7 +2,6 @@ """ Test to verify the Google GenAI proxy API endpoints """ - import os import sys from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/content_filter/test_content_filter.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/content_filter/test_content_filter.py index 399852acd4..fb952d4b18 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/content_filter/test_content_filter.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/content_filter/test_content_filter.py @@ -314,13 +314,15 @@ class TestContentFilterGuardrail: # Create a temporary blocked words file with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write("""blocked_words: + f.write( + """blocked_words: - keyword: "test_keyword" action: "BLOCK" description: "Test keyword" - keyword: "another_word" action: "MASK" -""") +""" + ) temp_file = f.name try: diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/openai/test_moderations.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/openai/test_moderations.py index ee7440d119..bccfb4a1cb 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/openai/test_moderations.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/openai/test_moderations.py @@ -2,7 +2,6 @@ """ Test OpenAI Moderation Guardrail """ - import os import sys diff --git a/tests/test_litellm/proxy/guardrails/test_content_utils.py b/tests/test_litellm/proxy/guardrails/test_content_utils.py index d3428dfbaf..099fca78a6 100644 --- a/tests/test_litellm/proxy/guardrails/test_content_utils.py +++ b/tests/test_litellm/proxy/guardrails/test_content_utils.py @@ -8,6 +8,7 @@ from litellm.proxy.guardrails._content_utils import ( walk_user_text, ) + # ── iter_message_text ──────────────────────────────────────────────────────────── diff --git a/tests/test_litellm/proxy/guardrails/test_custom_code_security.py b/tests/test_litellm/proxy/guardrails/test_custom_code_security.py index defdba10a6..00cf3f317c 100644 --- a/tests/test_litellm/proxy/guardrails/test_custom_code_security.py +++ b/tests/test_litellm/proxy/guardrails/test_custom_code_security.py @@ -5,6 +5,7 @@ from litellm.proxy.guardrails.guardrail_hooks.custom_code.custom_code_guardrail CustomCodeGuardrail, ) + # str.mro() + generator gi_code + code.replace(co_names=...) + __setattr__ # to swap a function's bytecode and read http_get's real builtins dict. BYTECODE_REWRITE_PAYLOAD = ( diff --git a/tests/test_litellm/proxy/guardrails/test_llm_as_a_judge.py b/tests/test_litellm/proxy/guardrails/test_llm_as_a_judge.py index 9178a66611..c9fde4ffba 100644 --- a/tests/test_litellm/proxy/guardrails/test_llm_as_a_judge.py +++ b/tests/test_litellm/proxy/guardrails/test_llm_as_a_judge.py @@ -13,6 +13,7 @@ from litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge import ( initialize_guardrail, ) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -38,20 +39,8 @@ def _make_guardrail(**overrides) -> LLMAsAJudgeGuardrail: def _make_verdict_response(overall_score: float) -> dict: return { "verdicts": [ - { - "criterion_name": "Accuracy", - "score": overall_score, - "reasoning": "ok", - "passed": True, - "weight": 60, - }, - { - "criterion_name": "Safety", - "score": overall_score, - "reasoning": "ok", - "passed": True, - "weight": 40, - }, + {"criterion_name": "Accuracy", "score": overall_score, "reasoning": "ok", "passed": True, "weight": 60}, + {"criterion_name": "Safety", "score": overall_score, "reasoning": "ok", "passed": True, "weight": 40}, ], "overall_score": overall_score, } @@ -101,15 +90,7 @@ def test_build_judge_prompt_missing_name_and_weight(): def _make_litellm_params(**overrides): params = MagicMock() - for attr in ( - "guardrail_name", - "judge_model", - "criteria", - "on_failure", - "overall_threshold", - "mode", - "default_on", - ): + for attr in ("guardrail_name", "judge_model", "criteria", "on_failure", "overall_threshold", "mode", "default_on"): setattr(params, attr, None) for k, v in overrides.items(): setattr(params, k, v) @@ -117,19 +98,12 @@ def _make_litellm_params(**overrides): def _make_guardrail_dict(name="g", **litellm_params_overrides): - raw = { - "judge_model": "gpt-4o-mini", - "criteria": CRITERIA_100, - "on_failure": "block", - "overall_threshold": 80.0, - } + raw = {"judge_model": "gpt-4o-mini", "criteria": CRITERIA_100, "on_failure": "block", "overall_threshold": 80.0} raw.update(litellm_params_overrides) return {"guardrail_name": name, "litellm_params": raw} -@patch( - "litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge.litellm.logging_callback_manager" -) +@patch("litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge.litellm.logging_callback_manager") def test_initialize_guardrail_ok(mock_mgr): lp = _make_litellm_params() g = _make_guardrail_dict() @@ -186,18 +160,11 @@ async def test_apply_guardrail_empty_response_passthrough(): @patch("litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge.litellm.acompletion") async def test_apply_guardrail_passes_above_threshold(mock_completion): mock_completion.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock(content=json.dumps(_make_verdict_response(90.0))) - ) - ] + choices=[MagicMock(message=MagicMock(content=json.dumps(_make_verdict_response(90.0))))] ) guardrail = _make_guardrail(overall_threshold=80.0) inputs = {"texts": ["good response"]} - request_data: dict = { - "messages": [{"role": "user", "content": "hi"}], - "metadata": {}, - } + request_data: dict = {"messages": [{"role": "user", "content": "hi"}], "metadata": {}} result = await guardrail.apply_guardrail(inputs, request_data, "response") assert result is inputs assert request_data["metadata"]["eval_information"]["passed"] is True @@ -207,11 +174,7 @@ async def test_apply_guardrail_passes_above_threshold(mock_completion): @patch("litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge.litellm.acompletion") async def test_apply_guardrail_blocks_below_threshold(mock_completion): mock_completion.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock(content=json.dumps(_make_verdict_response(50.0))) - ) - ] + choices=[MagicMock(message=MagicMock(content=json.dumps(_make_verdict_response(50.0))))] ) guardrail = _make_guardrail(overall_threshold=80.0, on_failure="block") inputs = {"texts": ["bad response"]} @@ -225,11 +188,7 @@ async def test_apply_guardrail_blocks_below_threshold(mock_completion): @patch("litellm.proxy.guardrails.guardrail_hooks.llm_as_a_judge.litellm.acompletion") async def test_apply_guardrail_log_mode_does_not_block(mock_completion): mock_completion.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock(content=json.dumps(_make_verdict_response(50.0))) - ) - ] + choices=[MagicMock(message=MagicMock(content=json.dumps(_make_verdict_response(50.0))))] ) guardrail = _make_guardrail(overall_threshold=80.0, on_failure="log") inputs = {"texts": ["bad response"]} diff --git a/tests/test_litellm/proxy/guardrails/test_qostodian_nexus_guardrail.py b/tests/test_litellm/proxy/guardrails/test_qostodian_nexus_guardrail.py index f7b4eb762a..6daa3e1430 100644 --- a/tests/test_litellm/proxy/guardrails/test_qostodian_nexus_guardrail.py +++ b/tests/test_litellm/proxy/guardrails/test_qostodian_nexus_guardrail.py @@ -232,9 +232,9 @@ def test_qostodian_nexus_builtin_extra_headers(): ] for header in expected_headers: - assert ( - header in instance.extra_headers - ), f"Expected built-in header '{header}' to be in extra_headers" + assert header in instance.extra_headers, ( + f"Expected built-in header '{header}' to be in extra_headers" + ) def test_qostodian_nexus_extra_headers_merged(): diff --git a/tests/test_litellm/proxy/hooks/test_batch_file_validation.py b/tests/test_litellm/proxy/hooks/test_batch_file_validation.py index c26e90108f..7f1006543b 100644 --- a/tests/test_litellm/proxy/hooks/test_batch_file_validation.py +++ b/tests/test_litellm/proxy/hooks/test_batch_file_validation.py @@ -14,6 +14,7 @@ from fastapi import HTTPException from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + # --------------------------------------------------------------------------- # Token counter — covers all three batch payload shapes # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/test_activity_tenant_scoping.py b/tests/test_litellm/proxy/management_endpoints/test_activity_tenant_scoping.py index 4686c5e047..0855c20194 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_activity_tenant_scoping.py +++ b/tests/test_litellm/proxy/management_endpoints/test_activity_tenant_scoping.py @@ -14,6 +14,7 @@ import pytest from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + # --------------------------------------------------------------------------- # /team/daily/activity — per-team admin/permission requirement # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/test_budget_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_budget_endpoints.py index 8f5e1da995..b15b9d622e 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_budget_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_budget_endpoints.py @@ -11,6 +11,7 @@ import litellm.proxy.proxy_server as ps from litellm.proxy.proxy_server import app from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors + sys.path.insert( 0, os.path.abspath("../../../") ) # Adds the parent directory to the system path diff --git a/tests/test_litellm/proxy/management_endpoints/test_delete_verification_tokens_failed.py b/tests/test_litellm/proxy/management_endpoints/test_delete_verification_tokens_failed.py index 6720002c2b..63e584e49b 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_delete_verification_tokens_failed.py +++ b/tests/test_litellm/proxy/management_endpoints/test_delete_verification_tokens_failed.py @@ -26,6 +26,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_verification_tokens, ) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/test_project_org_authz.py b/tests/test_litellm/proxy/management_endpoints/test_project_org_authz.py index b3d5b7b4a0..bd982480d6 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_project_org_authz.py +++ b/tests/test_litellm/proxy/management_endpoints/test_project_org_authz.py @@ -14,6 +14,7 @@ from fastapi import HTTPException from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + # --------------------------------------------------------------------------- # /project/update — _check_user_permission_for_project # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_default_params.py b/tests/test_litellm/proxy/management_endpoints/test_team_default_params.py index b0aae5ce39..443089b5f0 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_default_params.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_default_params.py @@ -25,6 +25,7 @@ from litellm.proxy.management_endpoints.team_endpoints import ( ) from litellm.proxy.proxy_server import ProxyConfig + # --------------------------------------------------------------------------- # _update_config_fields: default_team_params loaded from DB on startup # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/test_workflow_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_workflow_management_endpoints.py index bf55fd560e..a337ff6d88 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_workflow_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_workflow_management_endpoints.py @@ -17,6 +17,7 @@ sys.path.insert(0, os.path.abspath("../../..")) from litellm.proxy.management_endpoints.workflow_management_endpoints import router + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/management_endpoints/usage_endpoints/test_ai_usage_chat.py b/tests/test_litellm/proxy/management_endpoints/usage_endpoints/test_ai_usage_chat.py index da899cd62f..e8a74e41da 100644 --- a/tests/test_litellm/proxy/management_endpoints/usage_endpoints/test_ai_usage_chat.py +++ b/tests/test_litellm/proxy/management_endpoints/usage_endpoints/test_ai_usage_chat.py @@ -17,6 +17,7 @@ from litellm.proxy.management_endpoints.usage_endpoints.ai_usage_chat import ( stream_usage_ai_chat, ) + SAMPLE_AGGREGATED_RESPONSE = { "results": [ { diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index 0e7e0ba321..c9be00afed 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -1092,9 +1092,9 @@ class TestVertexAIPassThroughHandler: assert result is not None assert result["result"] is not None - assert ( - result["kwargs"].get("custom_llm_provider") == "gemini" - ), "Google AI Studio embedContent URLs must set custom_llm_provider=gemini, not vertex_ai" + assert result["kwargs"].get("custom_llm_provider") == "gemini", ( + "Google AI Studio embedContent URLs must set custom_llm_provider=gemini, not vertex_ai" + ) assert result["kwargs"].get("model") == "gemini-embedding-2-preview" mock_completion_cost.assert_called_once() diff --git a/tests/test_litellm/proxy/test_fastapi_offline_routes.py b/tests/test_litellm/proxy/test_fastapi_offline_routes.py index 7788e54b0b..f3fc3d3ea2 100644 --- a/tests/test_litellm/proxy/test_fastapi_offline_routes.py +++ b/tests/test_litellm/proxy/test_fastapi_offline_routes.py @@ -1,7 +1,7 @@ """ Unit test for testing /routes endpoint with FastAPIOffline app initialization. -This test verifies that the /routes endpoint works correctly when the proxy +This test verifies that the /routes endpoint works correctly when the proxy server is initialized using FastAPIOffline instead of regular FastAPI. """ diff --git a/tests/test_litellm/proxy/test_filter_models_by_team_access_group.py b/tests/test_litellm/proxy/test_filter_models_by_team_access_group.py index 5d1767bfb0..2d8a9f30c1 100644 --- a/tests/test_litellm/proxy/test_filter_models_by_team_access_group.py +++ b/tests/test_litellm/proxy/test_filter_models_by_team_access_group.py @@ -233,6 +233,4 @@ async def test_filter_db_fallback_receives_resolved_model_names(): "gpt-4o", "gpt-5", }, f"DB query should receive resolved model names, got {queried_names}" - assert ( - "Group-A" not in queried_names - ), "Raw access group name should not be in DB query" + assert "Group-A" not in queried_names, "Raw access group name should not be in DB query" diff --git a/tests/test_litellm/proxy/test_model_level_guardrails.py b/tests/test_litellm/proxy/test_model_level_guardrails.py index 993f8b918d..3d74edd772 100644 --- a/tests/test_litellm/proxy/test_model_level_guardrails.py +++ b/tests/test_litellm/proxy/test_model_level_guardrails.py @@ -19,6 +19,7 @@ from litellm.proxy.utils import ( _merge_guardrails_with_existing, ) + # --------------------------------------------------------------------------- # Unit tests for _check_and_merge_model_level_guardrails # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/test_redis_auth_cache_flag.py b/tests/test_litellm/proxy/test_redis_auth_cache_flag.py index 3cb2f5e434..d0cb5ec546 100644 --- a/tests/test_litellm/proxy/test_redis_auth_cache_flag.py +++ b/tests/test_litellm/proxy/test_redis_auth_cache_flag.py @@ -16,6 +16,7 @@ import litellm.proxy.proxy_server as ps from litellm.caching.caching import RedisCache from litellm.caching.dual_cache import DualCache + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/test_team_org_move.py b/tests/test_litellm/proxy/test_team_org_move.py index ad5f9c3859..2dc961bec8 100644 --- a/tests/test_litellm/proxy/test_team_org_move.py +++ b/tests/test_litellm/proxy/test_team_org_move.py @@ -6,7 +6,6 @@ Covers the SSO/Entra scenario where: - Non-proxy-admins (team admins) must have all team members pre-added to the org, preserving the original security model (no privilege escalation via team move). """ - from unittest.mock import AsyncMock, MagicMock import pytest @@ -48,10 +47,10 @@ def _make_org(organization_id="org-1", members=None, models=None): def _make_team(team_id="team-1", member_ids=None, organization_id=None): - members = [Member(user_id=uid, role="user") for uid in (member_ids or [])] - members.append( - Member(user_id=SpecialProxyStrings.default_user_id.value, role="admin") - ) + members = [ + Member(user_id=uid, role="user") for uid in (member_ids or []) + ] + members.append(Member(user_id=SpecialProxyStrings.default_user_id.value, role="admin")) return LiteLLM_TeamTable( team_id=team_id, team_alias="test-team", @@ -121,18 +120,12 @@ class TestValidateTeamOrgChange: team = _make_team(member_ids=["u1"], organization_id="org-1") org = _make_org(organization_id="org-1") - assert ( - validate_team_org_change( - team=team, organization=org, llm_router=router, is_proxy_admin=False - ) - is True - ) - assert ( - validate_team_org_change( - team=team, organization=org, llm_router=router, is_proxy_admin=True - ) - is True - ) + assert validate_team_org_change( + team=team, organization=org, llm_router=router, is_proxy_admin=False + ) is True + assert validate_team_org_change( + team=team, organization=org, llm_router=router, is_proxy_admin=True + ) is True def test_default_user_excluded_from_membership_check(self): """default_user_id is never checked for org membership.""" @@ -156,7 +149,6 @@ class TestAutoAddTeamMembersToOrg: mock_add = AsyncMock() import litellm.proxy.management_endpoints.team_endpoints as te - original = te.add_member_to_organization te.add_member_to_organization = mock_add @@ -182,7 +174,6 @@ class TestAutoAddTeamMembersToOrg: mock_add = AsyncMock() import litellm.proxy.management_endpoints.team_endpoints as te - original = te.add_member_to_organization te.add_member_to_organization = mock_add @@ -206,7 +197,6 @@ class TestAutoAddTeamMembersToOrg: mock_add = AsyncMock() import litellm.proxy.management_endpoints.team_endpoints as te - original = te.add_member_to_organization te.add_member_to_organization = mock_add @@ -229,7 +219,6 @@ class TestAutoAddTeamMembersToOrg: mock_add = AsyncMock(side_effect=Exception("duplicate key")) import litellm.proxy.management_endpoints.team_endpoints as te - original = te.add_member_to_organization te.add_member_to_organization = mock_add diff --git a/tests/test_litellm/responses/test_responses_api_bridge_flag.py b/tests/test_litellm/responses/test_responses_api_bridge_flag.py index ca3a2a8a64..463af6562f 100644 --- a/tests/test_litellm/responses/test_responses_api_bridge_flag.py +++ b/tests/test_litellm/responses/test_responses_api_bridge_flag.py @@ -151,7 +151,9 @@ class TestUseResponsesApiBridgeFlag: output=[ {"type": "message", "content": [{"type": "text", "text": "Answer"}]} ], - usage=ResponseAPIUsage(input_tokens=10, output_tokens=5, total_tokens=15), + usage=ResponseAPIUsage( + input_tokens=10, output_tokens=5, total_tokens=15 + ), ) mock_call_aresponses.return_value = mock_response @@ -200,7 +202,9 @@ class TestUseResponsesApiBridgeFlag: "arguments": '{"queries": ["test query"]}', } ], - usage=ResponseAPIUsage(input_tokens=10, output_tokens=5, total_tokens=15), + usage=ResponseAPIUsage( + input_tokens=10, output_tokens=5, total_tokens=15 + ), ) second_response = ResponsesAPIResponse( id="resp_second", @@ -212,7 +216,9 @@ class TestUseResponsesApiBridgeFlag: "content": [{"type": "text", "text": "Final answer"}], } ], - usage=ResponseAPIUsage(input_tokens=20, output_tokens=10, total_tokens=30), + usage=ResponseAPIUsage( + input_tokens=20, output_tokens=10, total_tokens=30 + ), ) mock_bridge_handler.side_effect = [first_response, second_response] @@ -261,7 +267,9 @@ class TestUseResponsesApiBridgeFlag: "content": [{"type": "text", "text": "Native response"}], } ], - usage=ResponseAPIUsage(input_tokens=10, output_tokens=5, total_tokens=15), + usage=ResponseAPIUsage( + input_tokens=10, output_tokens=5, total_tokens=15 + ), ) result = await litellm.aresponses( diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py b/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py index 2703d7b28c..604155e122 100644 --- a/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py +++ b/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py @@ -454,7 +454,9 @@ def test_finalize_prunes_stale_adaptive_router_hooks_from_callbacks(): Router(model_list=model_list) # simulate hot-reload adaptive_hooks = [ - cb for cb in litellm.callbacks if isinstance(cb, AdaptiveRouterPostCallHook) + cb + for cb in litellm.callbacks + if isinstance(cb, AdaptiveRouterPostCallHook) ] assert len(adaptive_hooks) == 1, ( f"expected exactly one AdaptiveRouterPostCallHook after hot-reload, " diff --git a/tests/test_litellm/test_acompletion_session_reuse_e2e.py b/tests/test_litellm/test_acompletion_session_reuse_e2e.py index a6f2be09ce..79b947bb14 100644 --- a/tests/test_litellm/test_acompletion_session_reuse_e2e.py +++ b/tests/test_litellm/test_acompletion_session_reuse_e2e.py @@ -22,6 +22,7 @@ sys.path.insert(0, os.path.abspath("../../..")) import litellm + # ============================================================================ # HELPER FUNCTION # ============================================================================ diff --git a/tests/test_litellm/test_anthropic_skills_transformation.py b/tests/test_litellm/test_anthropic_skills_transformation.py index 1b41b05f6d..1b917f08ca 100644 --- a/tests/test_litellm/test_anthropic_skills_transformation.py +++ b/tests/test_litellm/test_anthropic_skills_transformation.py @@ -22,6 +22,7 @@ from litellm.types.llms.anthropic_skills import ( ) from litellm.types.router import GenericLiteLLMParams + FAKE_API_KEY = "sk-ant-test-key-1234" FAKE_API_BASE = "https://api.anthropic.com" diff --git a/tests/test_litellm/test_dashscope_image_generation.py b/tests/test_litellm/test_dashscope_image_generation.py index aa3c955f83..af95e2ca6b 100644 --- a/tests/test_litellm/test_dashscope_image_generation.py +++ b/tests/test_litellm/test_dashscope_image_generation.py @@ -18,6 +18,7 @@ from litellm.llms.dashscope.image_generation.transformation import ( from litellm.types.utils import ImageObject, ImageResponse from litellm.utils import get_llm_provider + # --------------------------------------------------------------------------- # 1. Provider detection # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/test_deepseek_model_metadata.py b/tests/test_litellm/test_deepseek_model_metadata.py index 2b86b61b01..4900af5d97 100644 --- a/tests/test_litellm/test_deepseek_model_metadata.py +++ b/tests/test_litellm/test_deepseek_model_metadata.py @@ -23,6 +23,7 @@ from litellm.utils import ( supports_response_schema, ) + # --------------------------------------------------------------------------- # Data-level tests – verify the JSON files are in sync # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/test_model_cost_aliases.py b/tests/test_litellm/test_model_cost_aliases.py index ee11ebb970..f9f92a85cb 100644 --- a/tests/test_litellm/test_model_cost_aliases.py +++ b/tests/test_litellm/test_model_cost_aliases.py @@ -10,6 +10,7 @@ from unittest.mock import patch from litellm import verbose_logger from litellm.litellm_core_utils.get_model_cost_map import _expand_model_aliases + # --------------------------------------------------------------------------- # Core expansion behaviour # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/test_nested_drop_params.py b/tests/test_litellm/test_nested_drop_params.py index fbcd530230..bb1305ffde 100644 --- a/tests/test_litellm/test_nested_drop_params.py +++ b/tests/test_litellm/test_nested_drop_params.py @@ -7,6 +7,7 @@ This tests the new JSONPath-like syntax for removing nested fields. import os import sys + # Add parent directory to path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) diff --git a/tests/test_litellm/test_openai_embedding_encoding_format_default.py b/tests/test_litellm/test_openai_embedding_encoding_format_default.py index 1cf71139d8..94e4e3c81e 100644 --- a/tests/test_litellm/test_openai_embedding_encoding_format_default.py +++ b/tests/test_litellm/test_openai_embedding_encoding_format_default.py @@ -51,7 +51,9 @@ def test_openai_embedding_encoding_format_default( @pytest.mark.parametrize("env_none", ["none", "NONE", " none "]) -def test_openai_embedding_encoding_format_env_none_omits_param(monkeypatch, env_none): +def test_openai_embedding_encoding_format_env_none_omits_param( + monkeypatch, env_none +): """LITELLM_DEFAULT_EMBEDDING_ENCODING_FORMAT=none omits encoding_format (provider default).""" monkeypatch.setenv("LITELLM_DEFAULT_EMBEDDING_ENCODING_FORMAT", env_none) diff --git a/tests/test_litellm/test_router_google_genai.py b/tests/test_litellm/test_router_google_genai.py index b8179e568a..81dd7bbdc4 100644 --- a/tests/test_litellm/test_router_google_genai.py +++ b/tests/test_litellm/test_router_google_genai.py @@ -2,7 +2,6 @@ """ Test to verify the new Google GenAI router methods """ - import asyncio import os import sys diff --git a/tests/test_litellm/test_streaming_connection_cleanup.py b/tests/test_litellm/test_streaming_connection_cleanup.py index 53bbe8abd3..5a81a3ffb1 100644 --- a/tests/test_litellm/test_streaming_connection_cleanup.py +++ b/tests/test_litellm/test_streaming_connection_cleanup.py @@ -19,6 +19,7 @@ from litellm.llms.custom_httpx.aiohttp_transport import ( LiteLLMAiohttpTransport, ) + # ── aiohttp transport layer tests ────────────────────────────── diff --git a/tests/test_litellm/test_utils.py b/tests/test_litellm/test_utils.py index f493b84f44..65305e1a81 100644 --- a/tests/test_litellm/test_utils.py +++ b/tests/test_litellm/test_utils.py @@ -1500,7 +1500,8 @@ class TestProxyFunctionCalling: assert result is True, "Resolvable model names work with fallback logic" # Documentation notes: - print(""" + print( + """ PROXY MODEL RESOLUTION BEHAVIOR: ✅ WORKS (with current fallback logic): @@ -1515,7 +1516,8 @@ class TestProxyFunctionCalling: 💡 SOLUTION: Use LiteLLM proxy server with proper model_list configuration that maps custom names to underlying models. - """) + """ + ) @pytest.mark.parametrize( "proxy_model_with_hints,expected_result", @@ -1877,7 +1879,8 @@ class TestProxyFunctionCalling: This test provides documentation on how the proxy server configuration would typically map custom model names to underlying models. """ - print(""" + print( + """ REAL-WORLD PROXY SERVER CONFIGURATION EXAMPLE: =============================================== @@ -1930,7 +1933,8 @@ class TestProxyFunctionCalling: - Consistent request/response format - Enhanced streaming support for function calls - """) + """ + ) # Verify that direct underlying models work as expected bedrock_models = [ @@ -2144,7 +2148,8 @@ class TestProxyFunctionCalling: This test provides documentation on how the proxy server configuration would typically map custom model names to underlying models. """ - print(""" + print( + """ REAL-WORLD PROXY SERVER CONFIGURATION EXAMPLE: =============================================== @@ -2197,7 +2202,8 @@ class TestProxyFunctionCalling: - Consistent request/response format - Enhanced streaming support for function calls - """) + """ + ) # Verify that direct underlying models work as expected bedrock_models = [ @@ -2411,7 +2417,8 @@ class TestProxyFunctionCalling: This test provides documentation on how the proxy server configuration would typically map custom model names to underlying models. """ - print(""" + print( + """ REAL-WORLD PROXY SERVER CONFIGURATION EXAMPLE: =============================================== @@ -2464,7 +2471,8 @@ class TestProxyFunctionCalling: - Consistent request/response format - Enhanced streaming support for function calls - """) + """ + ) # Verify that direct underlying models work as expected bedrock_models = [ diff --git a/tests/test_litellm/types/test_completion.py b/tests/test_litellm/types/test_completion.py index b753a8e2ab..f24b00df3f 100644 --- a/tests/test_litellm/types/test_completion.py +++ b/tests/test_litellm/types/test_completion.py @@ -1,7 +1,7 @@ """ Tests for litellm.types.completion module -This test suite validates the CompletionRequest model and its compatibility with +This test suite validates the CompletionRequest model and its compatibility with OpenAI ChatCompletion API message formats. Usage: diff --git a/tests/test_litellm/types/test_prometheus_label_value_sanitize.py b/tests/test_litellm/types/test_prometheus_label_value_sanitize.py index d8b90197c6..9ff7eb460e 100644 --- a/tests/test_litellm/types/test_prometheus_label_value_sanitize.py +++ b/tests/test_litellm/types/test_prometheus_label_value_sanitize.py @@ -22,7 +22,7 @@ from litellm.types.integrations.prometheus import ( # Escapes per Prometheus text format ('he said "hi"', 'he said \\"hi\\"'), (r"path\to\file", r"path\\to\\file"), - (r"quote\"slash\\", r"quote\\\"slash\\\\"), + (r'quote\"slash\\', r'quote\\\"slash\\\\'), # Non-string inputs get coerced to str first (123, "123"), (True, "True"), @@ -31,3 +31,4 @@ from litellm.types.integrations.prometheus import ( ) def test_sanitize_prometheus_label_value_expected_outputs(value, expected): assert _sanitize_prometheus_label_value(value) == expected + diff --git a/tests/test_passthrough_endpoints.py b/tests/test_passthrough_endpoints.py index bf9417912c..47ac7511aa 100644 --- a/tests/test_passthrough_endpoints.py +++ b/tests/test_passthrough_endpoints.py @@ -10,6 +10,7 @@ import json import os import dotenv + dotenv.load_dotenv() diff --git a/tests/vector_store_tests/rag/base_rag_tests.py b/tests/vector_store_tests/rag/base_rag_tests.py index c357965da5..2c5a2540a7 100644 --- a/tests/vector_store_tests/rag/base_rag_tests.py +++ b/tests/vector_store_tests/rag/base_rag_tests.py @@ -124,7 +124,9 @@ class BaseRAGTest(ABC): Test document {unique_id} for RAG ingestion and query. LiteLLM provides a unified interface for 100+ LLMs. This content should be retrievable via semantic search. - """.encode("utf-8") + """.encode( + "utf-8" + ) file_data = (filename, text_content, "text/plain") ingest_options = self.get_base_ingest_options() diff --git a/tests/vector_store_tests/rag/test_rag_vertex_ai.py b/tests/vector_store_tests/rag/test_rag_vertex_ai.py index 872383899d..c99840bb0f 100644 --- a/tests/vector_store_tests/rag/test_rag_vertex_ai.py +++ b/tests/vector_store_tests/rag/test_rag_vertex_ai.py @@ -151,7 +151,9 @@ class TestRAGVertexAI(BaseRAGTest): Test document {unique_id} for Vertex AI RAG corpus creation. This tests the automatic corpus creation feature. The corpus should be created and the file should be uploaded successfully. - """.encode("utf-8") + """.encode( + "utf-8" + ) file_data = (filename, text_content, "text/plain") # Get base options WITHOUT corpus_id to trigger creation @@ -213,7 +215,9 @@ class TestRAGVertexAI(BaseRAGTest): text_content = f""" Test document {unique_id} for existing Vertex AI RAG corpus. This tests file upload to a pre-existing corpus. - """.encode("utf-8") + """.encode( + "utf-8" + ) file_data = (filename, text_content, "text/plain") ingest_options = self.get_base_ingest_options() diff --git a/tests/vector_store_tests/test_milvus_vector_store.py b/tests/vector_store_tests/test_milvus_vector_store.py index a0faf9d4ec..6627f6006d 100644 --- a/tests/vector_store_tests/test_milvus_vector_store.py +++ b/tests/vector_store_tests/test_milvus_vector_store.py @@ -12,6 +12,7 @@ import litellm from litellm.vector_stores import asearch as vector_store_asearch from litellm.vector_stores import search as vector_store_search + # Mock response from actual Milvus API MOCK_MILVUS_SEARCH_RESPONSE = { "code": 0, diff --git a/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py index 94c92c7fb9..7cace33861 100644 --- a/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py +++ b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py @@ -7,6 +7,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch import litellm + # Mock response from actual Vertex AI Search API MOCK_VERTEX_SEARCH_RESPONSE = { "results": [ diff --git a/ui/litellm-dashboard/e2e_tests/fixtures/mock_llm_server/server.py b/ui/litellm-dashboard/e2e_tests/fixtures/mock_llm_server/server.py index 9ca335430f..8e92065c69 100644 --- a/ui/litellm-dashboard/e2e_tests/fixtures/mock_llm_server/server.py +++ b/ui/litellm-dashboard/e2e_tests/fixtures/mock_llm_server/server.py @@ -12,6 +12,7 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse + app = FastAPI(title="Mock LLM Server") app.add_middleware( CORSMiddleware,