diff --git a/enterprise/LICENSE.md b/enterprise/LICENSE.md deleted file mode 100644 index c8607439ad..0000000000 --- a/enterprise/LICENSE.md +++ /dev/null @@ -1,37 +0,0 @@ - -The BerriAI Enterprise license (the "Enterprise License") -Copyright (c) 2024 - present Berrie AI Inc. - -With regard to the BerriAI Software: - -This software and associated documentation files (the "Software") may only be -used in production, if you (and any entity that you represent) have agreed to, -and are in compliance with, the BerriAI Subscription Terms of Service, available -via [call](https://enterprise.litellm.ai/demo) or email (info@berri.ai) (the "Enterprise Terms"), or other -agreement governing the use of the Software, as agreed by you and BerriAI, -and otherwise have a valid BerriAI Enterprise license for the -correct number of user seats. Subject to the foregoing sentence, you are free to -modify this Software and publish patches to the Software. You agree that BerriAI -and/or its licensors (as applicable) retain all right, title and interest in and -to all such modifications and/or patches, and all such modifications and/or -patches may only be used, copied, modified, displayed, distributed, or otherwise -exploited with a valid BerriAI Enterprise license for the correct -number of user seats. Notwithstanding the foregoing, you may copy and modify -the Software for development and testing purposes, without requiring a -subscription. You agree that BerriAI and/or its licensors (as applicable) retain -all right, title and interest in and to all such modifications. You are not -granted any other rights beyond what is expressly stated herein. Subject to the -foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, -and/or sell the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -For all third party components incorporated into the BerriAI Software, those -components are licensed under the original license provided by the owner of the -applicable component. \ No newline at end of file diff --git a/enterprise/README.md b/enterprise/README.md deleted file mode 100644 index f5eb5078e8..0000000000 --- a/enterprise/README.md +++ /dev/null @@ -1,9 +0,0 @@ -## LiteLLM Enterprise - -Code in this folder is licensed under a commercial license. Please review the [LICENSE](./LICENSE.md) file within the /enterprise folder - -**These features are covered under the LiteLLM Enterprise contract** - -👉 **Using in an Enterprise / Need specific features ?** Meet with us [here](https://enterprise.litellm.ai/demo?month=2024-02) - -See all Enterprise Features here 👉 [Docs](https://docs.litellm.ai/docs/proxy/enterprise) diff --git a/enterprise/__init__.py b/enterprise/__init__.py deleted file mode 100644 index b6e690fd59..0000000000 --- a/enterprise/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import * diff --git a/enterprise/cloudformation_stack/litellm.yaml b/enterprise/cloudformation_stack/litellm.yaml deleted file mode 100644 index c30956b945..0000000000 --- a/enterprise/cloudformation_stack/litellm.yaml +++ /dev/null @@ -1,44 +0,0 @@ -Resources: - LiteLLMServer: - Type: AWS::EC2::Instance - Properties: - AvailabilityZone: us-east-1a - ImageId: ami-0f403e3180720dd7e - InstanceType: t2.micro - - LiteLLMServerAutoScalingGroup: - Type: AWS::AutoScaling::AutoScalingGroup - Properties: - AvailabilityZones: - - us-east-1a - LaunchConfigurationName: !Ref LiteLLMServerLaunchConfig - MinSize: 1 - MaxSize: 3 - DesiredCapacity: 1 - HealthCheckGracePeriod: 300 - - LiteLLMServerLaunchConfig: - Type: AWS::AutoScaling::LaunchConfiguration - Properties: - ImageId: ami-0f403e3180720dd7e # Replace with your desired AMI ID - InstanceType: t2.micro - - LiteLLMServerScalingPolicy: - Type: AWS::AutoScaling::ScalingPolicy - Properties: - AutoScalingGroupName: !Ref LiteLLMServerAutoScalingGroup - PolicyType: TargetTrackingScaling - TargetTrackingConfiguration: - PredefinedMetricSpecification: - PredefinedMetricType: ASGAverageCPUUtilization - TargetValue: 60.0 - - LiteLLMDB: - Type: AWS::RDS::DBInstance - Properties: - AllocatedStorage: 20 - Engine: postgres - MasterUsername: litellmAdmin - MasterUserPassword: litellmPassword - DBInstanceClass: db.t3.micro - AvailabilityZone: us-east-1a \ No newline at end of file diff --git a/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl deleted file mode 100644 index d9a8ef41e6..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.1.tar.gz b/enterprise/dist/litellm_enterprise-0.1.1.tar.gz deleted file mode 100644 index 98cf132b21..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.1.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.10-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.10-py3-none-any.whl deleted file mode 100644 index 473ff736e3..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.10-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.10.tar.gz b/enterprise/dist/litellm_enterprise-0.1.10.tar.gz deleted file mode 100644 index e28ee65c38..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.10.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.11-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.11-py3-none-any.whl deleted file mode 100644 index 3dece3053d..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.11-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.11.tar.gz b/enterprise/dist/litellm_enterprise-0.1.11.tar.gz deleted file mode 100644 index 02b62c3dda..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.11.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.12-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.12-py3-none-any.whl deleted file mode 100644 index 9f72a92014..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.12-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.12.tar.gz b/enterprise/dist/litellm_enterprise-0.1.12.tar.gz deleted file mode 100644 index cbaeff7d77..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.12.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.13-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.13-py3-none-any.whl deleted file mode 100644 index e9f350030b..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.13-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.13.tar.gz b/enterprise/dist/litellm_enterprise-0.1.13.tar.gz deleted file mode 100644 index bde63337ab..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.13.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.15-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.15-py3-none-any.whl deleted file mode 100644 index 99381c7f65..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.15-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.15.tar.gz b/enterprise/dist/litellm_enterprise-0.1.15.tar.gz deleted file mode 100644 index 794a6a1b87..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.15.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.17-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.17-py3-none-any.whl deleted file mode 100644 index 9c2856b465..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.17-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.17.tar.gz b/enterprise/dist/litellm_enterprise-0.1.17.tar.gz deleted file mode 100644 index 92d4a6ee92..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.17.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.19-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.19-py3-none-any.whl deleted file mode 100644 index 5b48b65e4d..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.19-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.19.tar.gz b/enterprise/dist/litellm_enterprise-0.1.19.tar.gz deleted file mode 100644 index 2f99960bde..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.19.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.2-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.2-py3-none-any.whl deleted file mode 100644 index 1f75e0f1b5..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.2-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.2.tar.gz b/enterprise/dist/litellm_enterprise-0.1.2.tar.gz deleted file mode 100644 index b6fa4dd5f7..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.2.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.21-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.21-py3-none-any.whl deleted file mode 100644 index 6452930c9f..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.21-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.21.tar.gz b/enterprise/dist/litellm_enterprise-0.1.21.tar.gz deleted file mode 100644 index ed6ebc3834..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.21.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.22-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.22-py3-none-any.whl deleted file mode 100644 index 6ad5b7041c..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.22-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.22.tar.gz b/enterprise/dist/litellm_enterprise-0.1.22.tar.gz deleted file mode 100644 index 9db2c14b12..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.22.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.23-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.23-py3-none-any.whl deleted file mode 100644 index c061e793bc..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.23-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.23.tar.gz b/enterprise/dist/litellm_enterprise-0.1.23.tar.gz deleted file mode 100644 index b84c2ba0f2..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.23.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.24-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.24-py3-none-any.whl deleted file mode 100644 index a26b0458c9..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.24-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.24.tar.gz b/enterprise/dist/litellm_enterprise-0.1.24.tar.gz deleted file mode 100644 index 4361910f4b..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.24.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.25-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.25-py3-none-any.whl deleted file mode 100644 index bcc559d21b..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.25-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.25.tar.gz b/enterprise/dist/litellm_enterprise-0.1.25.tar.gz deleted file mode 100644 index 4db1cf7ef5..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.25.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.26-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.26-py3-none-any.whl deleted file mode 100644 index e4cfac6553..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.26-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.26.tar.gz b/enterprise/dist/litellm_enterprise-0.1.26.tar.gz deleted file mode 100644 index c8e0081ff1..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.26.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.27-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.27-py3-none-any.whl deleted file mode 100644 index 0274d62e16..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.27-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.27.tar.gz b/enterprise/dist/litellm_enterprise-0.1.27.tar.gz deleted file mode 100644 index d802b5a89d..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.27.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.29-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.29-py3-none-any.whl deleted file mode 100644 index 0895ecbc42..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.29-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.29.tar.gz b/enterprise/dist/litellm_enterprise-0.1.29.tar.gz deleted file mode 100644 index 6781cf26cc..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.29.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.3-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.3-py3-none-any.whl deleted file mode 100644 index 7b5cb85656..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.3-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.3.tar.gz b/enterprise/dist/litellm_enterprise-0.1.3.tar.gz deleted file mode 100644 index d5ac9f26a4..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.3.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.30-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.30-py3-none-any.whl deleted file mode 100644 index 0165bb096c..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.30-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.30.tar.gz b/enterprise/dist/litellm_enterprise-0.1.30.tar.gz deleted file mode 100644 index 2bb7510e5d..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.30.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.31-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.31-py3-none-any.whl deleted file mode 100644 index 03cadbd902..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.31-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.31.tar.gz b/enterprise/dist/litellm_enterprise-0.1.31.tar.gz deleted file mode 100644 index 1ba1a717f6..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.31.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.32-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.32-py3-none-any.whl deleted file mode 100644 index 0c87c72c98..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.32-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.32.tar.gz b/enterprise/dist/litellm_enterprise-0.1.32.tar.gz deleted file mode 100644 index 4f0ac1a9b2..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.32.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.4-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.4-py3-none-any.whl deleted file mode 100644 index f862a55b18..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.4-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.4.tar.gz b/enterprise/dist/litellm_enterprise-0.1.4.tar.gz deleted file mode 100644 index bf1b3ec57c..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.4.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.5-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.5-py3-none-any.whl deleted file mode 100644 index 661638db3f..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.5-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.5.tar.gz b/enterprise/dist/litellm_enterprise-0.1.5.tar.gz deleted file mode 100644 index 2808574dda..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.5.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.6-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.6-py3-none-any.whl deleted file mode 100644 index c212c7e5a3..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.6-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.6.tar.gz b/enterprise/dist/litellm_enterprise-0.1.6.tar.gz deleted file mode 100644 index 698a9da209..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.6.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.7-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.7-py3-none-any.whl deleted file mode 100644 index 248e1ca294..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.7-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.7.tar.gz b/enterprise/dist/litellm_enterprise-0.1.7.tar.gz deleted file mode 100644 index 7c28d3a36a..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.7.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.8-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.8-py3-none-any.whl deleted file mode 100644 index b9470dca46..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.8-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.8.tar.gz b/enterprise/dist/litellm_enterprise-0.1.8.tar.gz deleted file mode 100644 index f233be2be8..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.8.tar.gz and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.9-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.9-py3-none-any.whl deleted file mode 100644 index eb4b9d1083..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.9-py3-none-any.whl and /dev/null differ diff --git a/enterprise/dist/litellm_enterprise-0.1.9.tar.gz b/enterprise/dist/litellm_enterprise-0.1.9.tar.gz deleted file mode 100644 index 748ed2150e..0000000000 Binary files a/enterprise/dist/litellm_enterprise-0.1.9.tar.gz and /dev/null differ diff --git a/enterprise/enterprise_hooks/__init__.py b/enterprise/enterprise_hooks/__init__.py deleted file mode 100644 index e93c8c9150..0000000000 --- a/enterprise/enterprise_hooks/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Dict, Literal, Type, Union - -from litellm_enterprise.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles -from litellm_enterprise.proxy.hooks.managed_vector_stores import ( - _PROXY_LiteLLMManagedVectorStores, -) - -from litellm.integrations.custom_logger import CustomLogger - -ENTERPRISE_PROXY_HOOKS: Dict[str, Type[CustomLogger]] = { - "managed_files": _PROXY_LiteLLMManagedFiles, - "managed_vector_stores": _PROXY_LiteLLMManagedVectorStores, -} - - -def get_enterprise_proxy_hook( - hook_name: Union[ - Literal[ - "managed_files", - "managed_vector_stores", - "max_parallel_requests", - ], - str, - ], -): - """ - Factory method to get a enterprise hook instance by name - """ - if hook_name not in ENTERPRISE_PROXY_HOOKS: - raise ValueError( - f"Unknown hook: {hook_name}. Available hooks: {list(ENTERPRISE_PROXY_HOOKS.keys())}" - ) - return ENTERPRISE_PROXY_HOOKS[hook_name] diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py deleted file mode 100644 index 28b49bfce2..0000000000 --- a/enterprise/enterprise_hooks/aporia_ai.py +++ /dev/null @@ -1,204 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Use AporiaAI for your LLM calls -# -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -import os -import sys - -from litellm.types.utils import CallTypesLiteral - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import json -import sys -from typing import Any, List, Literal, Optional - -from fastapi import HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.litellm_core_utils.logging_utils import ( - convert_litellm_response_object_to_str, -) -from litellm.llms.custom_httpx.http_handler import ( - get_async_httpx_client, - httpxSpecialProvider, -) -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata -from litellm.types.guardrails import GuardrailEventHooks - -GUARDRAIL_NAME = "aporia" - - -class AporiaGuardrail(CustomGuardrail): - def __init__( - self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs - ): - self.async_handler = get_async_httpx_client( - llm_provider=httpxSpecialProvider.GuardrailCallback - ) - self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"] - self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"] - super().__init__(**kwargs) - - #### CALL HOOKS - proxy only #### - def transform_messages(self, messages: List[dict]) -> List[dict]: - supported_openai_roles = ["system", "user", "assistant"] - default_role = "other" # for unsupported roles - e.g. tool - new_messages = [] - for m in messages: - if m.get("role", "") in supported_openai_roles: - new_messages.append(m) - else: - new_messages.append( - { - "role": default_role, - **{key: value for key, value in m.items() if key != "role"}, - } - ) - - return new_messages - - async def prepare_aporia_request( - self, new_messages: List[dict], response_string: Optional[str] = None - ) -> dict: - data: dict[str, Any] = {} - if new_messages is not None: - data["messages"] = new_messages - if response_string is not None: - data["response"] = response_string - - # Set validation target - if new_messages and response_string: - data["validation_target"] = "both" - elif new_messages: - data["validation_target"] = "prompt" - elif response_string: - data["validation_target"] = "response" - - verbose_proxy_logger.debug("Aporia AI request: %s", data) - return data - - async def make_aporia_api_request( - self, new_messages: List[dict], response_string: Optional[str] = None - ): - data = await self.prepare_aporia_request( - new_messages=new_messages, response_string=response_string - ) - - _json_data = json.dumps(data) - - """ - export APORIO_API_KEY= - curl https://gr-prd-trial.aporia.com/some-id \ - -X POST \ - -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - { - "role": "user", - "content": "This is a test prompt" - } - ], - } -' - """ - - response = await self.async_handler.post( - url=self.aporia_api_base + "/validate", - data=_json_data, - headers={ - "X-APORIA-API-KEY": self.aporia_api_key, - "Content-Type": "application/json", - }, - ) - verbose_proxy_logger.debug("Aporia AI response: %s", response.text) - if response.status_code == 200: - # check if the response was flagged - _json_response = response.json() - action: str = _json_response.get( - "action" - ) # possible values are modify, passthrough, block, rephrase - if action == "block": - raise HTTPException( - status_code=400, - detail={ - "error": "Violated guardrail policy", - "aporia_ai_response": _json_response, - }, - ) - - async def async_post_call_success_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - response, - ): - from litellm.proxy.common_utils.callback_utils import ( - add_guardrail_to_applied_guardrails_header, - ) - - """ - Use this for the post call moderation with Guardrails - """ - event_type: GuardrailEventHooks = GuardrailEventHooks.post_call - if self.should_run_guardrail(data=data, event_type=event_type) is not True: - return - - response_str: Optional[str] = convert_litellm_response_object_to_str(response) - if response_str is not None: - await self.make_aporia_api_request( - response_string=response_str, new_messages=data.get("messages", []) - ) - - add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=self.guardrail_name - ) - - pass - - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: CallTypesLiteral, - ): - from litellm.proxy.common_utils.callback_utils import ( - add_guardrail_to_applied_guardrails_header, - ) - - event_type: GuardrailEventHooks = GuardrailEventHooks.during_call - if self.should_run_guardrail(data=data, event_type=event_type) is not True: - return - - # old implementation - backwards compatibility - if ( - await should_proceed_based_on_metadata( - data=data, - guardrail_name=GUARDRAIL_NAME, - ) - is False - ): - return - - new_messages: Optional[List[dict]] = None - if "messages" in data and isinstance(data["messages"], list): - new_messages = self.transform_messages(messages=data["messages"]) - - if new_messages is not None: - await self.make_aporia_api_request(new_messages=new_messages) - add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=self.guardrail_name - ) - else: - verbose_proxy_logger.warning( - "Aporia AI: not running guardrail. No messages in data" - ) - pass diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py deleted file mode 100644 index 47421c9605..0000000000 --- a/enterprise/enterprise_hooks/banned_keywords.py +++ /dev/null @@ -1,115 +0,0 @@ -# +------------------------------+ -# -# Banned Keywords -# -# +------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan -## Reject a call / response if it contains certain keywords - - -from typing import Literal -import litellm -from litellm.caching.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails._content_utils import ( - is_text_content_call_type, - iter_message_text, -) -from litellm.integrations.custom_logger import CustomLogger -from litellm._logging import verbose_proxy_logger -from fastapi import HTTPException - - -class _ENTERPRISE_BannedKeywords(CustomLogger): - # Class variables or attributes - def __init__(self): - banned_keywords_list = litellm.banned_keywords_list - - if banned_keywords_list is None: - raise Exception( - "`banned_keywords_list` can either be a list or filepath. None set." - ) - - if isinstance(banned_keywords_list, list): - self.banned_keywords_list = banned_keywords_list - - if isinstance(banned_keywords_list, str): # assume it's a filepath - try: - with open(banned_keywords_list, "r") as file: - data = file.read() - self.banned_keywords_list = data.split("\n") - except FileNotFoundError: - raise Exception( - f"File not found. banned_keywords_list={banned_keywords_list}" - ) - except Exception as e: - raise Exception( - f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}" - ) - - def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): - if level == "INFO": - verbose_proxy_logger.info(print_statement) - elif level == "DEBUG": - verbose_proxy_logger.debug(print_statement) - - if litellm.set_verbose is True: - print(print_statement) # noqa - - def test_violation(self, test_str: str): - for word in self.banned_keywords_list: - if word in test_str.lower(): - raise HTTPException( - status_code=400, - detail={"error": f"Keyword banned. Keyword={word}"}, - ) - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: str, # "completion", "embeddings", "image_generation", "moderation" - ): - try: - """ - - check if user id part of call - - check if user id part of blocked list - """ - self.print_verbose("Inside Banned Keyword List Pre-Call Hook") - if is_text_content_call_type(call_type): - for text in iter_message_text(data): - self.test_violation(test_str=text) - - except HTTPException as e: - raise e - except Exception as e: - verbose_proxy_logger.exception( - "litellm.enterprise.enterprise_hooks.banned_keywords::async_pre_call_hook - Exception occurred - {}".format( - str(e) - ) - ) - - async def async_post_call_success_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - response, - ): - if not isinstance(response, litellm.ModelResponse): - return - - for choice in response.choices: - if not isinstance(choice, litellm.utils.Choices): - continue - message = getattr(choice, "message", None) - content = getattr(message, "content", None) - if isinstance(content, str): - self.test_violation(test_str=content) - - async def async_post_call_streaming_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - response: str, - ): - self.test_violation(test_str=response) diff --git a/enterprise/enterprise_hooks/blocked_user_list.py b/enterprise/enterprise_hooks/blocked_user_list.py deleted file mode 100644 index d34605b30a..0000000000 --- a/enterprise/enterprise_hooks/blocked_user_list.py +++ /dev/null @@ -1,124 +0,0 @@ -# +------------------------------+ -# -# Blocked User List -# -# +------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan -## This accepts a list of user id's for whom calls will be rejected - - -from typing import Optional, Literal -import litellm -from litellm.proxy.utils import PrismaClient -from litellm.caching.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable -from litellm.integrations.custom_logger import CustomLogger -from litellm._logging import verbose_proxy_logger -from fastapi import HTTPException - - -class _ENTERPRISE_BlockedUserList(CustomLogger): - # Class variables or attributes - def __init__(self, prisma_client: Optional[PrismaClient]): - self.prisma_client = prisma_client - - blocked_user_list = litellm.blocked_user_list - if blocked_user_list is None: - self.blocked_user_list = None - return - - if isinstance(blocked_user_list, list): - self.blocked_user_list = blocked_user_list - - if isinstance(blocked_user_list, str): # assume it's a filepath - try: - with open(blocked_user_list, "r") as file: - data = file.read() - self.blocked_user_list = data.split("\n") - except FileNotFoundError: - raise Exception( - f"File not found. blocked_user_list={blocked_user_list}" - ) - except Exception as e: - raise Exception( - f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}" - ) - - def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): - if level == "INFO": - verbose_proxy_logger.info(print_statement) - elif level == "DEBUG": - verbose_proxy_logger.debug(print_statement) - - if litellm.set_verbose is True: - print(print_statement) # noqa - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: str, - ): - try: - """ - - check if user id part of call - - check if user id part of blocked list - - if blocked list is none or user not in blocked list - - check if end-user in cache - - check if end-user in db - """ - self.print_verbose("Inside Blocked User List Pre-Call Hook") - if "user_id" in data or "user" in data: - user = data.get("user_id", data.get("user", "")) - if ( - self.blocked_user_list is not None - and user in self.blocked_user_list - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"User blocked from making LLM API Calls. User={user}" - }, - ) - - cache_key = f"litellm:end_user_id:{user}" - end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore - key=cache_key - ) - if end_user_cache_obj is None and self.prisma_client is not None: - # check db - end_user_obj = ( - await self.prisma_client.db.litellm_endusertable.find_unique( - where={"user_id": user} - ) - ) - if end_user_obj is None: # user not in db - assume not blocked - end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False) - cache.set_cache(key=cache_key, value=end_user_obj, ttl=60) - if end_user_obj is not None and end_user_obj.blocked is True: - raise HTTPException( - status_code=400, - detail={ - "error": f"User blocked from making LLM API Calls. User={user}" - }, - ) - elif ( - end_user_cache_obj is not None - and end_user_cache_obj.blocked is True - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"User blocked from making LLM API Calls. User={user}" - }, - ) - - except HTTPException as e: - raise e - except Exception as e: - verbose_proxy_logger.exception( - "litellm.enterprise.enterprise_hooks.blocked_user_list::async_pre_call_hook - Exception occurred - {}".format( - str(e) - ) - ) diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py deleted file mode 100644 index 5b2d71c5cc..0000000000 --- a/enterprise/enterprise_hooks/google_text_moderation.py +++ /dev/null @@ -1,135 +0,0 @@ -# +-----------------------------------------------+ -# -# Google Text Moderation -# https://cloud.google.com/natural-language/docs/moderating-text -# -# +-----------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -from fastapi import HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails._content_utils import iter_message_text -from litellm.types.utils import CallTypesLiteral - - -class _ENTERPRISE_GoogleTextModeration(CustomLogger): - user_api_key_cache = None - confidence_categories = [ - "toxic", - "insult", - "profanity", - "derogatory", - "sexual", - "death_harm_and_tragedy", - "violent", - "firearms_and_weapons", - "public_safety", - "health", - "religion_and_belief", - "illicit_drugs", - "war_and_conflict", - "politics", - "finance", - "legal", - ] # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores - - # Class variables or attributes - def __init__(self): - try: - from google.cloud import language_v1 # type: ignore - except Exception: - raise Exception( - "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`" - ) - - # Instantiates a client - self.client = language_v1.LanguageServiceClient() - self.moderate_text_request = language_v1.ModerateTextRequest - self.language_document = language_v1.types.Document # type: ignore - self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore - - default_confidence_threshold = ( - litellm.google_moderation_confidence_threshold or 0.8 - ) # by default require a high confidence (80%) to fail - - for category in self.confidence_categories: - if hasattr(litellm, f"{category}_confidence_threshold"): - setattr( - self, - f"{category}_confidence_threshold", - getattr(litellm, f"{category}_confidence_threshold"), - ) - else: - setattr( - self, - f"{category}_confidence_threshold", - default_confidence_threshold, - ) - set_confidence_value = getattr( - self, - f"{category}_confidence_threshold", - ) - verbose_proxy_logger.info( - f"Google Text Moderation: {category}_confidence_threshold: {set_confidence_value}" - ) - - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except Exception: - pass - - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: CallTypesLiteral, - ): - """ - - Calls Google's Text Moderation API - - Rejects request if it fails safety check - """ - # Covers multimodal list content + Responses-API input. - text = "".join(iter_message_text(data)) - if text: - document = self.language_document(content=text, type_=self.document_type) - - request = self.moderate_text_request( - document=document, - ) - - # Make the request - response = self.client.moderate_text(request=request) - for category in response.moderation_categories: - category_name = category.name - category_name = category_name.lower() - category_name = category_name.replace("&", "and") - category_name = category_name.replace(",", "") - category_name = category_name.replace( - " ", "_" - ) # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons' - if category.confidence > getattr( - self, f"{category_name}_confidence_threshold" - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"Violated content safety policy. Category={category}" - }, - ) - # Handle the response - return data - - -# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration() -# asyncio.run( -# google_text_moderation_obj.async_moderation_hook( -# data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]} -# ) -# ) diff --git a/enterprise/enterprise_hooks/openai_moderation.py b/enterprise/enterprise_hooks/openai_moderation.py deleted file mode 100644 index 2162370804..0000000000 --- a/enterprise/enterprise_hooks/openai_moderation.py +++ /dev/null @@ -1,58 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Use OpenAI /moderations for your LLM calls -# -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import sys - -from fastapi import HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails._content_utils import iter_message_text -from litellm.types.utils import CallTypesLiteral - - -class _ENTERPRISE_OpenAI_Moderation(CustomLogger): - def __init__(self): - self.model_name = ( - litellm.openai_moderations_model_name or "text-moderation-latest" - ) # pass the model_name you initialized on litellm.Router() - pass - - #### CALL HOOKS - proxy only #### - - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: CallTypesLiteral, - ): - # Covers multimodal list content + Responses-API input. - text = "".join(iter_message_text(data)) - - from litellm.proxy.proxy_server import llm_router - - if llm_router is None: - return - - moderation_response = await llm_router.amoderation( - model=self.model_name, input=text - ) - - verbose_proxy_logger.debug("Moderation response: %s", moderation_response) - if moderation_response and moderation_response.results[0].flagged is True: - raise HTTPException( - status_code=403, detail={"error": "Violated content safety policy"} - ) - pass diff --git a/enterprise/enterprise_ui/README.md b/enterprise/enterprise_ui/README.md deleted file mode 100644 index 88de893119..0000000000 --- a/enterprise/enterprise_ui/README.md +++ /dev/null @@ -1,6 +0,0 @@ -## Admin UI - -Customize the Admin UI to your companies branding / logo -![Group 204](https://github.com/BerriAI/litellm/assets/29436595/3b7dbfc2-6fcd-42af-996d-f734fb8f461b) - -## Docs to set up Custom Admin UI [here](https://docs.litellm.ai/docs/proxy/ui) diff --git a/enterprise/enterprise_ui/_enterprise_colors.json b/enterprise/enterprise_ui/_enterprise_colors.json deleted file mode 100644 index 4706eb1d78..0000000000 --- a/enterprise/enterprise_ui/_enterprise_colors.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "brand": { - "DEFAULT": "teal", - "faint": "teal", - "muted": "teal", - "subtle": "teal", - "emphasis": "teal", - "inverted": "teal" - } - } - \ No newline at end of file diff --git a/enterprise/litellm_enterprise/__init__.py b/enterprise/litellm_enterprise/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py b/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py deleted file mode 100644 index 8824f4c02d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import List, Optional - -import litellm -from litellm._logging import verbose_logger -from litellm.constants import X_LITELLM_DISABLE_CALLBACKS -from litellm.integrations.custom_logger import CustomLogger -from litellm.litellm_core_utils.llm_request_utils import ( - get_proxy_server_request_headers, -) -from litellm.proxy._types import CommonProxyErrors -from litellm.types.utils import StandardCallbackDynamicParams - - -class EnterpriseCallbackControls: - @staticmethod - def is_callback_disabled_dynamically( - callback: litellm.CALLBACK_TYPES, - litellm_params: dict, - standard_callback_dynamic_params: StandardCallbackDynamicParams - ) -> bool: - """ - Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params. - - Args: - callback: The callback to check (can be string, CustomLogger instance, or callable) - litellm_params: Parameters containing proxy server request info - - Returns: - bool: True if the callback should be disabled, False otherwise - """ - from litellm.litellm_core_utils.custom_logger_registry import ( - CustomLoggerRegistry, - ) - - try: - disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params) - verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}") - verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}") - if disabled_callbacks is not None: - ######################################################### - # premium user check - ######################################################### - if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling(): - return False - ######################################################### - if isinstance(callback, str): - if callback.lower() in disabled_callbacks: - verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}") - return True - elif isinstance(callback, CustomLogger): - # get the string name of the callback - callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__) - if callback_str is not None and callback_str.lower() in disabled_callbacks: - verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}") - return True - return False - except Exception as e: - verbose_logger.debug( - f"Error checking disabled callbacks header: {str(e)}" - ) - return False - @staticmethod - def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]: - """ - Get the disabled callbacks from the standard callback dynamic params. - """ - - ######################################################### - # check if disabled via headers - ######################################################### - request_headers = get_proxy_server_request_headers(litellm_params) - disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None) - if disabled_callbacks is not None: - disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")]) - return list(disabled_callbacks) - - - ######################################################### - # check if disabled via request body - ######################################################### - if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None: - return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) - - return None - - @staticmethod - def _should_allow_dynamic_callback_disabling(): - import litellm - from litellm.proxy.proxy_server import premium_user - - # Check if admin has disabled this feature - if litellm.allow_dynamic_callback_disabling is not True: - verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling") - return False - - if premium_user: - return True - verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}") - return False \ No newline at end of file diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py b/enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py deleted file mode 100644 index 14d34f5d1e..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py +++ /dev/null @@ -1,27 +0,0 @@ -# this is an example endpoint to receive data from litellm -from fastapi import FastAPI, HTTPException, Request - -app = FastAPI() - - -@app.post("/log-event") -async def log_event(request: Request): - try: - print("Received /log-event request") # noqa - # Assuming the incoming request has JSON data - data = await request.json() - print("Received request data:") # noqa - print(data) # noqa - - # Your additional logic can go here - # For now, just printing the received data - - return {"message": "Request received successfully"} - except Exception: - raise HTTPException(status_code=500, detail="Internal Server Error") - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/llama_guard.py b/enterprise/litellm_enterprise/enterprise_callbacks/llama_guard.py deleted file mode 100644 index 5e1aebdbdf..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/llama_guard.py +++ /dev/null @@ -1,130 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Llama Guard -# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main -# -# LLM for Content Moderation -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -import os -import sys -from collections.abc import Iterable - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import sys -from typing import Literal, Optional - -from fastapi import HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.types.utils import CallTypesLiteral, Choices, ModelResponse - - -class _ENTERPRISE_LlamaGuard(CustomLogger): - # Class variables or attributes - def __init__(self, model_name: Optional[str] = None): - _model = model_name or litellm.llamaguard_model_name - if _model is None: - raise ValueError("model_name not set for LlamaGuard") - self.model = _model - file_path = litellm.llamaguard_unsafe_content_categories - data = None - - if file_path is not None: - try: - with open(file_path, "r") as file: - data = file.read() - except FileNotFoundError: - raise Exception(f"File not found. file_path={file_path}") - except Exception as e: - raise Exception(f"An error occurred: {str(e)}, file_path={file_path}") - - self.unsafe_content_categories = data - - verbose_proxy_logger.debug( - f"self.unsafe_content_categories: {self.unsafe_content_categories}" - ) - - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except Exception: - pass - - def set_custom_prompt_template(self, messages: list): - if self.unsafe_content_categories is not None and self.model is not None: - role = "Agent" if len(messages) % 2 == 0 else "User" - litellm.register_prompt_template( - model=self.model, - initial_prompt_value=f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories. - - - {self.unsafe_content_categories} - - - - - """, - roles={ - "system": { - "pre_message": "User: ", # [OPTIONAL] - "post_message": "\n\n", # [OPTIONAL] - }, - "user": { - "pre_message": "User: ", # [OPTIONAL] - "post_message": "\n\n", # [OPTIONAL] - }, - "assistant": { - "pre_message": "Agent: ", # [OPTIONAL] - "post_message": "\n\n", # [OPTIONAL] - }, - }, - final_prompt_value=""" - - Provide your safety assessment for {role} in the above conversation: - - First line must read 'safe' or 'unsafe'. - - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""", - ) - - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: CallTypesLiteral, - ): - """ - - Calls the Llama Guard Endpoint - - Rejects request if it fails safety check - - The llama guard prompt template is applied automatically in factory.py - """ - if "messages" in data: - safety_check_messages = data["messages"][ - -1 - ] # get the last response - llama guard has a 4k token limit - response = await litellm.acompletion( - model=self.model, - messages=[safety_check_messages], - hf_model_name="meta-llama/LlamaGuard-7b", - ) - - if ( - isinstance(response, ModelResponse) - and isinstance(response.choices[0], Choices) - and response.choices[0].message.content is not None - and isinstance(response.choices[0].message.content, Iterable) - and "unsafe" in response.choices[0].message.content - ): - raise HTTPException( - status_code=400, detail={"error": "Violated content safety policy"} - ) - - return data diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/llm_guard.py b/enterprise/litellm_enterprise/enterprise_callbacks/llm_guard.py deleted file mode 100644 index ad8aabf77b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/llm_guard.py +++ /dev/null @@ -1,174 +0,0 @@ -# +------------------------+ -# -# LLM Guard -# https://llm-guard.com/ -# -# +------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan -## This provides an LLM Guard Integration for content moderation on the proxy - -from typing import Literal, Optional - -import aiohttp -from fastapi import HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.secret_managers.main import get_secret_str -from litellm.types.utils import CallTypesLiteral -from litellm.utils import get_formatted_prompt - - -class _ENTERPRISE_LLMGuard(CustomLogger): - # Class variables or attributes - def __init__( - self, - mock_testing: bool = False, - mock_redacted_text: Optional[dict] = None, - ): - self.mock_redacted_text = mock_redacted_text - self.llm_guard_mode = litellm.llm_guard_mode - if mock_testing is True: # for testing purposes only - return - self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None) - if self.llm_guard_api_base is None: - raise Exception("Missing `LLM_GUARD_API_BASE` from environment") - elif not self.llm_guard_api_base.endswith("/"): - self.llm_guard_api_base += "/" - - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except Exception: - pass - - async def moderation_check(self, text: str): - """ - [TODO] make this more performant for high-throughput scenario - """ - try: - async with aiohttp.ClientSession() as session: - if self.mock_redacted_text is not None: - redacted_text = self.mock_redacted_text - else: - # Make the first request to /analyze - analyze_url = f"{self.llm_guard_api_base}analyze/prompt" - verbose_proxy_logger.debug("Making request to: %s", analyze_url) - analyze_payload = {"prompt": text} - redacted_text = None - async with session.post( - analyze_url, json=analyze_payload - ) as response: - redacted_text = await response.json() - verbose_proxy_logger.debug( - f"LLM Guard: Received response - {redacted_text}" - ) - if redacted_text is not None: - if ( - redacted_text.get("is_valid", None) is not None - and redacted_text["is_valid"] is False - ): - raise HTTPException( - status_code=400, - detail={"error": "Violated content safety policy"}, - ) - else: - pass - else: - raise HTTPException( - status_code=500, - detail={ - "error": f"Invalid content moderation response: {redacted_text}" - }, - ) - except Exception as e: - verbose_proxy_logger.exception( - "litellm.enterprise.enterprise_hooks.llm_guard::moderation_check - Exception occurred - {}".format( - str(e) - ) - ) - raise e - - def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool: - if self.llm_guard_mode == "key-specific": - # check if llm guard enabled for specific keys only - self.print_verbose( - f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" - ) - if ( - user_api_key_dict.permissions.get("enable_llm_guard_check", False) - is True - ): - return True - elif self.llm_guard_mode == "all": - return True - elif self.llm_guard_mode == "request-specific": - self.print_verbose(f"received metadata: {data.get('metadata', {})}") - metadata = data.get("metadata", {}) - permissions = metadata.get("permissions", {}) - if ( - "enable_llm_guard_check" in permissions - and permissions["enable_llm_guard_check"] is True - ): - return True - return False - - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: CallTypesLiteral, - ): - """ - - Calls the LLM Guard Endpoint - - Rejects request if it fails safety check - - Use the sanitized prompt returned - - LLM Guard can handle things like PII Masking, etc. - """ - self.print_verbose( - f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" - ) - - _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data) - if _proceed is False: - return - - self.print_verbose("Makes LLM Guard Check") - try: - assert call_type in [ - "completion", - "embeddings", - "image_generation", - "moderation", - "audio_transcription", - ] - except Exception: - self.print_verbose( - f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']" - ) - return data - - formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore - self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}") - return await self.moderation_check(text=formatted_prompt) - - async def async_post_call_streaming_hook( - self, user_api_key_dict: UserAPIKeyAuth, response: str - ): - if response is not None: - await self.moderation_check(text=response) - - return response - - -# llm_guard = _ENTERPRISE_LLMGuard() - -# asyncio.run( -# llm_guard.async_moderation_hook( -# data={"messages": [{"role": "user", "content": "Hey how's it going?"}]} -# ) -# ) diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py b/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py deleted file mode 100644 index 12fdaeb6a8..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -PagerDuty Alerting Integration - -Handles two types of alerts: -- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert. -- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert. - -Note: This is a Free feature on the regular litellm docker image. - -However, this is under the enterprise license -""" - -import asyncio -import os -from datetime import datetime, timedelta, timezone -from typing import List, Optional, Union - -from litellm._logging import verbose_logger -from litellm.caching import DualCache -from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - get_async_httpx_client, - httpxSpecialProvider, -) -from litellm.proxy._types import UserAPIKeyAuth -from litellm.types.integrations.pagerduty import ( - AlertingConfig, - PagerDutyInternalEvent, - PagerDutyPayload, - PagerDutyRequestBody, -) -from litellm.types.utils import ( - CallTypesLiteral, - StandardLoggingPayload, - StandardLoggingPayloadErrorInformation, -) - -PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60 -PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60 -PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60 -PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600 - - -class PagerDutyAlerting(SlackAlerting): - """ - Tracks failed requests and hanging requests separately. - If threshold is crossed for either type, triggers a PagerDuty alert. - """ - - def __init__( - self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs - ): - super().__init__() - _api_key = os.getenv("PAGERDUTY_API_KEY") - if not _api_key: - raise ValueError("PAGERDUTY_API_KEY is not set") - - self.api_key: str = _api_key - alerting_args = alerting_args or {} - self.pagerduty_alerting_args: AlertingConfig = AlertingConfig( - failure_threshold=alerting_args.get( - "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD - ), - failure_threshold_window_seconds=alerting_args.get( - "failure_threshold_window_seconds", - PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS, - ), - hanging_threshold_seconds=alerting_args.get( - "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ), - hanging_threshold_window_seconds=alerting_args.get( - "hanging_threshold_window_seconds", - PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, - ), - ) - - # Separate storage for failures vs. hangs - self._failure_events: List[PagerDutyInternalEvent] = [] - self._hanging_events: List[PagerDutyInternalEvent] = [] - - # ------------------ MAIN LOGIC ------------------ # - - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - """ - Record a failure event. Only send an alert to PagerDuty if the - configured *failure* threshold is exceeded in the specified window. - """ - now = datetime.now(timezone.utc) - standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object" - ) - if not standard_logging_payload: - raise ValueError( - "standard_logging_object is required for PagerDutyAlerting" - ) - - # Extract error details - error_info: Optional[StandardLoggingPayloadErrorInformation] = ( - standard_logging_payload.get("error_information") or {} - ) - _meta = standard_logging_payload.get("metadata") or {} - - self._failure_events.append( - PagerDutyInternalEvent( - failure_event_type="failed_response", - timestamp=now, - error_class=error_info.get("error_class"), - error_code=error_info.get("error_code"), - error_llm_provider=error_info.get("llm_provider"), - user_api_key_hash=_meta.get("user_api_key_hash"), - user_api_key_alias=_meta.get("user_api_key_alias"), - user_api_key_spend=_meta.get("user_api_key_spend"), - user_api_key_max_budget=_meta.get("user_api_key_max_budget"), - user_api_key_budget_reset_at=_meta.get("user_api_key_budget_reset_at"), - user_api_key_org_id=_meta.get("user_api_key_org_id"), - user_api_key_org_alias=_meta.get("user_api_key_org_alias"), - user_api_key_team_id=_meta.get("user_api_key_team_id"), - user_api_key_project_id=_meta.get("user_api_key_project_id"), - user_api_key_project_alias=_meta.get("user_api_key_project_alias"), - user_api_key_user_id=_meta.get("user_api_key_user_id"), - user_api_key_team_alias=_meta.get("user_api_key_team_alias"), - user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), - user_api_key_user_email=_meta.get("user_api_key_user_email"), - user_api_key_request_route=_meta.get("user_api_key_request_route"), - user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"), - ) - ) - - # Prune + Possibly alert - window_seconds = self.pagerduty_alerting_args.get( - "failure_threshold_window_seconds", 60 - ) - threshold = self.pagerduty_alerting_args.get("failure_threshold", 1) - - # If threshold is crossed, send PD alert for failures - await self._send_alert_if_thresholds_crossed( - events=self._failure_events, - window_seconds=window_seconds, - threshold=threshold, - alert_prefix="High LLM API Failure Rate", - ) - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: CallTypesLiteral, - ) -> Optional[Union[Exception, str, dict]]: - """ - Example of detecting hanging requests by waiting a given threshold. - If the request didn't finish by then, we treat it as 'hanging'. - """ - verbose_logger.info("Inside Proxy Logging Pre-call hook!") - asyncio.create_task( - self.hanging_response_handler( - request_data=data, user_api_key_dict=user_api_key_dict - ) - ) - return None - - async def hanging_response_handler( - self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth - ): - """ - Checks if request completed by the time 'hanging_threshold_seconds' elapses. - If not, we classify it as a hanging request. - """ - verbose_logger.debug( - f"Inside Hanging Response Handler!..sleeping for {self.pagerduty_alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds" - ) - await asyncio.sleep( - self.pagerduty_alerting_args.get( - "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ) - ) - - if await self._request_is_completed(request_data=request_data): - return # It's not hanging if completed - - # Otherwise, record it as hanging - self._hanging_events.append( - PagerDutyInternalEvent( - failure_event_type="hanging_response", - timestamp=datetime.now(timezone.utc), - error_class="HangingRequest", - error_code="HangingRequest", - error_llm_provider="HangingRequest", - user_api_key_hash=user_api_key_dict.api_key, - user_api_key_alias=user_api_key_dict.key_alias, - user_api_key_spend=user_api_key_dict.spend, - user_api_key_max_budget=user_api_key_dict.max_budget, - user_api_key_budget_reset_at=( - user_api_key_dict.budget_reset_at.isoformat() - if user_api_key_dict.budget_reset_at - else None - ), - user_api_key_org_id=user_api_key_dict.org_id, - user_api_key_org_alias=user_api_key_dict.organization_alias, - user_api_key_team_id=user_api_key_dict.team_id, - user_api_key_project_id=user_api_key_dict.project_id, - user_api_key_project_alias=user_api_key_dict.project_alias, - user_api_key_user_id=user_api_key_dict.user_id, - user_api_key_team_alias=user_api_key_dict.team_alias, - user_api_key_end_user_id=user_api_key_dict.end_user_id, - user_api_key_user_email=user_api_key_dict.user_email, - user_api_key_request_route=user_api_key_dict.request_route, - user_api_key_auth_metadata=user_api_key_dict.metadata, - ) - ) - - # Prune + Possibly alert - window_seconds = self.pagerduty_alerting_args.get( - "hanging_threshold_window_seconds", - PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, - ) - threshold: int = self.pagerduty_alerting_args.get( - "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ) - - # If threshold is crossed, send PD alert for hangs - await self._send_alert_if_thresholds_crossed( - events=self._hanging_events, - window_seconds=window_seconds, - threshold=threshold, - alert_prefix="High Number of Hanging LLM Requests", - ) - - # ------------------ HELPERS ------------------ # - - async def _send_alert_if_thresholds_crossed( - self, - events: List[PagerDutyInternalEvent], - window_seconds: int, - threshold: int, - alert_prefix: str, - ): - """ - 1. Prune old events - 2. If threshold is reached, build alert, send to PagerDuty - 3. Clear those events - """ - cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds) - pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff] - - # Update the reference list - events.clear() - events.extend(pruned) - - # Check threshold - verbose_logger.debug( - f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}" - ) - if len(events) >= threshold: - # Build short summary of last N events - error_summaries = self._build_error_summaries(events, max_errors=5) - alert_message = ( - f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds." - ) - custom_details = {"recent_errors": error_summaries} - - await self.send_alert_to_pagerduty( - alert_message=alert_message, - custom_details=custom_details, - ) - - # Clear them after sending an alert, so we don't spam - events.clear() - - def _build_error_summaries( - self, events: List[PagerDutyInternalEvent], max_errors: int = 5 - ) -> List[PagerDutyInternalEvent]: - """ - Build short text summaries for the last `max_errors`. - Example: "ValueError (code: 500, provider: openai)" - """ - recent = events[-max_errors:] - summaries = [] - for fe in recent: - # If any of these is None, show "N/A" to avoid messing up the summary string - fe.pop("timestamp") - summaries.append(fe) - return summaries - - async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict): - """ - Send [critical] Alert to PagerDuty - - https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api - """ - try: - verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}") - async_client: AsyncHTTPHandler = get_async_httpx_client( - llm_provider=httpxSpecialProvider.LoggingCallback - ) - payload: PagerDutyRequestBody = PagerDutyRequestBody( - payload=PagerDutyPayload( - summary=alert_message, - severity="critical", - source="LiteLLM Alert", - component="LiteLLM", - custom_details=custom_details, - ), - routing_key=self.api_key, - event_action="trigger", - ) - - return await async_client.post( - url="https://events.pagerduty.com/v2/enqueue", - json=dict(payload), - headers={"Content-Type": "application/json"}, - ) - except Exception as e: - verbose_logger.exception(f"Error sending alert to PagerDuty: {e}") diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secret_detection.py b/enterprise/litellm_enterprise/enterprise_callbacks/secret_detection.py deleted file mode 100644 index f441ce71ab..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secret_detection.py +++ /dev/null @@ -1,523 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Use SecretDetection /moderations for your LLM calls -# -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import tempfile -from typing import Optional - -from litellm._logging import verbose_proxy_logger -from litellm.caching.caching import DualCache -from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails._content_utils import walk_user_text - -GUARDRAIL_NAME = "hide_secrets" - -_custom_plugins_path = "file://" + os.path.join( - os.path.dirname(os.path.abspath(__file__)), "secrets_plugins" -) -_default_detect_secrets_config = { - "plugins_used": [ - {"name": "SoftlayerDetector"}, - {"name": "StripeDetector"}, - {"name": "NpmDetector"}, - {"name": "IbmCosHmacDetector"}, - {"name": "DiscordBotTokenDetector"}, - {"name": "BasicAuthDetector"}, - {"name": "AzureStorageKeyDetector"}, - {"name": "ArtifactoryDetector"}, - {"name": "AWSKeyDetector"}, - {"name": "CloudantDetector"}, - {"name": "IbmCloudIamDetector"}, - {"name": "JwtTokenDetector"}, - {"name": "MailchimpDetector"}, - {"name": "SquareOAuthDetector"}, - {"name": "PrivateKeyDetector"}, - {"name": "TwilioKeyDetector"}, - { - "name": "AdafruitKeyDetector", - "path": _custom_plugins_path + "/adafruit.py", - }, - { - "name": "AdobeSecretDetector", - "path": _custom_plugins_path + "/adobe.py", - }, - { - "name": "AgeSecretKeyDetector", - "path": _custom_plugins_path + "/age_secret_key.py", - }, - { - "name": "AirtableApiKeyDetector", - "path": _custom_plugins_path + "/airtable_api_key.py", - }, - { - "name": "AlgoliaApiKeyDetector", - "path": _custom_plugins_path + "/algolia_api_key.py", - }, - { - "name": "AlibabaSecretDetector", - "path": _custom_plugins_path + "/alibaba.py", - }, - { - "name": "AsanaSecretDetector", - "path": _custom_plugins_path + "/asana.py", - }, - { - "name": "AtlassianApiTokenDetector", - "path": _custom_plugins_path + "/atlassian_api_token.py", - }, - { - "name": "AuthressAccessKeyDetector", - "path": _custom_plugins_path + "/authress_access_key.py", - }, - { - "name": "BittrexDetector", - "path": _custom_plugins_path + "/beamer_api_token.py", - }, - { - "name": "BitbucketDetector", - "path": _custom_plugins_path + "/bitbucket.py", - }, - { - "name": "BeamerApiTokenDetector", - "path": _custom_plugins_path + "/bittrex.py", - }, - { - "name": "ClojarsApiTokenDetector", - "path": _custom_plugins_path + "/clojars_api_token.py", - }, - { - "name": "CodecovAccessTokenDetector", - "path": _custom_plugins_path + "/codecov_access_token.py", - }, - { - "name": "CoinbaseAccessTokenDetector", - "path": _custom_plugins_path + "/coinbase_access_token.py", - }, - { - "name": "ConfluentDetector", - "path": _custom_plugins_path + "/confluent.py", - }, - { - "name": "ContentfulApiTokenDetector", - "path": _custom_plugins_path + "/contentful_api_token.py", - }, - { - "name": "DatabricksApiTokenDetector", - "path": _custom_plugins_path + "/databricks_api_token.py", - }, - { - "name": "DatadogAccessTokenDetector", - "path": _custom_plugins_path + "/datadog_access_token.py", - }, - { - "name": "DefinedNetworkingApiTokenDetector", - "path": _custom_plugins_path + "/defined_networking_api_token.py", - }, - { - "name": "DigitaloceanDetector", - "path": _custom_plugins_path + "/digitalocean.py", - }, - { - "name": "DopplerApiTokenDetector", - "path": _custom_plugins_path + "/doppler_api_token.py", - }, - { - "name": "DroneciAccessTokenDetector", - "path": _custom_plugins_path + "/droneci_access_token.py", - }, - { - "name": "DuffelApiTokenDetector", - "path": _custom_plugins_path + "/duffel_api_token.py", - }, - { - "name": "DynatraceApiTokenDetector", - "path": _custom_plugins_path + "/dynatrace_api_token.py", - }, - { - "name": "DiscordDetector", - "path": _custom_plugins_path + "/discord.py", - }, - { - "name": "DropboxDetector", - "path": _custom_plugins_path + "/dropbox.py", - }, - { - "name": "EasyPostDetector", - "path": _custom_plugins_path + "/easypost.py", - }, - { - "name": "EtsyAccessTokenDetector", - "path": _custom_plugins_path + "/etsy_access_token.py", - }, - { - "name": "FacebookAccessTokenDetector", - "path": _custom_plugins_path + "/facebook_access_token.py", - }, - { - "name": "FastlyApiKeyDetector", - "path": _custom_plugins_path + "/fastly_api_token.py", - }, - { - "name": "FinicityDetector", - "path": _custom_plugins_path + "/finicity.py", - }, - { - "name": "FinnhubAccessTokenDetector", - "path": _custom_plugins_path + "/finnhub_access_token.py", - }, - { - "name": "FlickrAccessTokenDetector", - "path": _custom_plugins_path + "/flickr_access_token.py", - }, - { - "name": "FlutterwaveDetector", - "path": _custom_plugins_path + "/flutterwave.py", - }, - { - "name": "FrameIoApiTokenDetector", - "path": _custom_plugins_path + "/frameio_api_token.py", - }, - { - "name": "FreshbooksAccessTokenDetector", - "path": _custom_plugins_path + "/freshbooks_access_token.py", - }, - { - "name": "GCPApiKeyDetector", - "path": _custom_plugins_path + "/gcp_api_key.py", - }, - { - "name": "GitHubTokenCustomDetector", - "path": _custom_plugins_path + "/github_token.py", - }, - { - "name": "GitLabDetector", - "path": _custom_plugins_path + "/gitlab.py", - }, - { - "name": "GitterAccessTokenDetector", - "path": _custom_plugins_path + "/gitter_access_token.py", - }, - { - "name": "GoCardlessApiTokenDetector", - "path": _custom_plugins_path + "/gocardless_api_token.py", - }, - { - "name": "GrafanaDetector", - "path": _custom_plugins_path + "/grafana.py", - }, - { - "name": "HashiCorpTFApiTokenDetector", - "path": _custom_plugins_path + "/hashicorp_tf_api_token.py", - }, - { - "name": "HerokuApiKeyDetector", - "path": _custom_plugins_path + "/heroku_api_key.py", - }, - { - "name": "HubSpotApiTokenDetector", - "path": _custom_plugins_path + "/hubspot_api_key.py", - }, - { - "name": "HuggingFaceDetector", - "path": _custom_plugins_path + "/huggingface.py", - }, - { - "name": "IntercomApiTokenDetector", - "path": _custom_plugins_path + "/intercom_api_key.py", - }, - { - "name": "JFrogDetector", - "path": _custom_plugins_path + "/jfrog.py", - }, - { - "name": "JWTBase64Detector", - "path": _custom_plugins_path + "/jwt.py", - }, - { - "name": "KrakenAccessTokenDetector", - "path": _custom_plugins_path + "/kraken_access_token.py", - }, - { - "name": "KucoinDetector", - "path": _custom_plugins_path + "/kucoin.py", - }, - { - "name": "LaunchdarklyAccessTokenDetector", - "path": _custom_plugins_path + "/launchdarkly_access_token.py", - }, - { - "name": "LinearDetector", - "path": _custom_plugins_path + "/linear.py", - }, - { - "name": "LinkedInDetector", - "path": _custom_plugins_path + "/linkedin.py", - }, - { - "name": "LobDetector", - "path": _custom_plugins_path + "/lob.py", - }, - { - "name": "MailgunDetector", - "path": _custom_plugins_path + "/mailgun.py", - }, - { - "name": "MapBoxApiTokenDetector", - "path": _custom_plugins_path + "/mapbox_api_token.py", - }, - { - "name": "MattermostAccessTokenDetector", - "path": _custom_plugins_path + "/mattermost_access_token.py", - }, - { - "name": "MessageBirdDetector", - "path": _custom_plugins_path + "/messagebird.py", - }, - { - "name": "MicrosoftTeamsWebhookDetector", - "path": _custom_plugins_path + "/microsoft_teams_webhook.py", - }, - { - "name": "NetlifyAccessTokenDetector", - "path": _custom_plugins_path + "/netlify_access_token.py", - }, - { - "name": "NewRelicDetector", - "path": _custom_plugins_path + "/new_relic.py", - }, - { - "name": "NYTimesAccessTokenDetector", - "path": _custom_plugins_path + "/nytimes_access_token.py", - }, - { - "name": "OktaAccessTokenDetector", - "path": _custom_plugins_path + "/okta_access_token.py", - }, - { - "name": "OpenAIApiKeyDetector", - "path": _custom_plugins_path + "/openai_api_key.py", - }, - { - "name": "PlanetScaleDetector", - "path": _custom_plugins_path + "/planetscale.py", - }, - { - "name": "PostmanApiTokenDetector", - "path": _custom_plugins_path + "/postman_api_token.py", - }, - { - "name": "PrefectApiTokenDetector", - "path": _custom_plugins_path + "/prefect_api_token.py", - }, - { - "name": "PulumiApiTokenDetector", - "path": _custom_plugins_path + "/pulumi_api_token.py", - }, - { - "name": "PyPiUploadTokenDetector", - "path": _custom_plugins_path + "/pypi_upload_token.py", - }, - { - "name": "RapidApiAccessTokenDetector", - "path": _custom_plugins_path + "/rapidapi_access_token.py", - }, - { - "name": "ReadmeApiTokenDetector", - "path": _custom_plugins_path + "/readme_api_token.py", - }, - { - "name": "RubygemsApiTokenDetector", - "path": _custom_plugins_path + "/rubygems_api_token.py", - }, - { - "name": "ScalingoApiTokenDetector", - "path": _custom_plugins_path + "/scalingo_api_token.py", - }, - { - "name": "SendbirdDetector", - "path": _custom_plugins_path + "/sendbird.py", - }, - { - "name": "SendGridApiTokenDetector", - "path": _custom_plugins_path + "/sendgrid_api_token.py", - }, - { - "name": "SendinBlueApiTokenDetector", - "path": _custom_plugins_path + "/sendinblue_api_token.py", - }, - { - "name": "SentryAccessTokenDetector", - "path": _custom_plugins_path + "/sentry_access_token.py", - }, - { - "name": "ShippoApiTokenDetector", - "path": _custom_plugins_path + "/shippo_api_token.py", - }, - { - "name": "ShopifyDetector", - "path": _custom_plugins_path + "/shopify.py", - }, - { - "name": "SlackDetector", - "path": _custom_plugins_path + "/slack.py", - }, - { - "name": "SnykApiTokenDetector", - "path": _custom_plugins_path + "/snyk_api_token.py", - }, - { - "name": "SquarespaceAccessTokenDetector", - "path": _custom_plugins_path + "/squarespace_access_token.py", - }, - { - "name": "SumoLogicDetector", - "path": _custom_plugins_path + "/sumologic.py", - }, - { - "name": "TelegramBotApiTokenDetector", - "path": _custom_plugins_path + "/telegram_bot_api_token.py", - }, - { - "name": "TravisCiAccessTokenDetector", - "path": _custom_plugins_path + "/travisci_access_token.py", - }, - { - "name": "TwitchApiTokenDetector", - "path": _custom_plugins_path + "/twitch_api_token.py", - }, - { - "name": "TwitterDetector", - "path": _custom_plugins_path + "/twitter.py", - }, - { - "name": "TypeformApiTokenDetector", - "path": _custom_plugins_path + "/typeform_api_token.py", - }, - { - "name": "VaultDetector", - "path": _custom_plugins_path + "/vault.py", - }, - { - "name": "YandexDetector", - "path": _custom_plugins_path + "/yandex.py", - }, - { - "name": "ZendeskSecretKeyDetector", - "path": _custom_plugins_path + "/zendesk_secret_key.py", - }, - {"name": "Base64HighEntropyString", "limit": 3.0}, - {"name": "HexHighEntropyString", "limit": 3.0}, - ] -} - - -class _ENTERPRISE_SecretDetection(CustomGuardrail): - def __init__(self, detect_secrets_config: Optional[dict] = None, **kwargs): - self.user_defined_detect_secrets_config = detect_secrets_config - super().__init__(**kwargs) - - def scan_message_for_secrets(self, message_content: str): - from detect_secrets import SecretsCollection - from detect_secrets.settings import transient_settings - - temp_file = tempfile.NamedTemporaryFile(delete=False) - temp_file.write(message_content.encode("utf-8")) - temp_file.close() - - secrets = SecretsCollection() - - detect_secrets_config = ( - self.user_defined_detect_secrets_config or _default_detect_secrets_config - ) - with transient_settings(detect_secrets_config): - secrets.scan_file(temp_file.name) - - os.remove(temp_file.name) - - detected_secrets = [] - for file in secrets.files: - for found_secret in secrets[file]: - if found_secret.secret_value is None: - continue - detected_secrets.append( - {"type": found_secret.type, "value": found_secret.secret_value} - ) - - return detected_secrets - - async def should_run_check(self, user_api_key_dict: UserAPIKeyAuth) -> bool: - if user_api_key_dict.permissions is not None: - if GUARDRAIL_NAME in user_api_key_dict.permissions: - if user_api_key_dict.permissions[GUARDRAIL_NAME] is False: - return False - - return True - - #### CALL HOOKS - proxy only #### - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: str, # "completion", "embeddings", "image_generation", "moderation" - ): - if await self.should_run_check(user_api_key_dict) is False: - return - - # Covers multimodal list content + Responses-API input. - def _redact_message_text(text: str) -> str: - detected_secrets = self.scan_message_for_secrets(text) - for secret in detected_secrets: - text = text.replace(secret["value"], "[REDACTED]") - if detected_secrets: - secret_types = [secret["type"] for secret in detected_secrets] - verbose_proxy_logger.warning( - f"Detected and redacted secrets in message: {secret_types}" - ) - return text - - walk_user_text(data, _redact_message_text) - - if "prompt" in data: - if isinstance(data["prompt"], str): - detected_secrets = self.scan_message_for_secrets(data["prompt"]) - for secret in detected_secrets: - data["prompt"] = data["prompt"].replace( - secret["value"], "[REDACTED]" - ) - if len(detected_secrets) > 0: - secret_types = [secret["type"] for secret in detected_secrets] - verbose_proxy_logger.warning( - f"Detected and redacted secrets in prompt: {secret_types}" - ) - elif isinstance(data["prompt"], list): - # Index back into the list — assigning to ``item`` would only - # rebind the loop variable and leave ``data["prompt"]`` - # carrying the unredacted secret. - for idx, item in enumerate(data["prompt"]): - if isinstance(item, str): - detected_secrets = self.scan_message_for_secrets(item) - for secret in detected_secrets: - item = item.replace(secret["value"], "[REDACTED]") - data["prompt"][idx] = item - if len(detected_secrets) > 0: - secret_types = [ - secret["type"] for secret in detected_secrets - ] - verbose_proxy_logger.warning( - f"Detected and redacted secrets in prompt: {secret_types}" - ) - - # ``data["input"]`` (Responses API and embeddings/moderation) is - # already covered by ``walk_user_text`` above. - return diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/__init__.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adafruit.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adafruit.py deleted file mode 100644 index abee3398f3..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adafruit.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Adafruit keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AdafruitKeyDetector(RegexBasedDetector): - """Scans for Adafruit keys.""" - - @property - def secret_type(self) -> str: - return "Adafruit API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:adafruit)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adobe.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adobe.py deleted file mode 100644 index 7a58ccdf90..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/adobe.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for Adobe keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AdobeSecretDetector(RegexBasedDetector): - """Scans for Adobe client keys.""" - - @property - def secret_type(self) -> str: - return "Adobe Client Keys" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Adobe Client ID (OAuth Web) - re.compile( - r"""(?i)(?:adobe)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Adobe Client Secret - re.compile(r"(?i)\b((p8e-)[a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)"), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/age_secret_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/age_secret_key.py deleted file mode 100644 index 2c0c179102..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/age_secret_key.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -This plugin searches for Age secret keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AgeSecretKeyDetector(RegexBasedDetector): - """Scans for Age secret keys.""" - - @property - def secret_type(self) -> str: - return "Age Secret Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile(r"""AGE-SECRET-KEY-1[QPZRY9X8GF2TVDW0S3JN54KHCE6MUA7L]{58}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/airtable_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/airtable_api_key.py deleted file mode 100644 index 8abf4f6e44..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/airtable_api_key.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Airtable API keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AirtableApiKeyDetector(RegexBasedDetector): - """Scans for Airtable API keys.""" - - @property - def secret_type(self) -> str: - return "Airtable API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:airtable)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{17})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/algolia_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/algolia_api_key.py deleted file mode 100644 index cd6c16a8c0..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/algolia_api_key.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -This plugin searches for Algolia API keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AlgoliaApiKeyDetector(RegexBasedDetector): - """Scans for Algolia API keys.""" - - @property - def secret_type(self) -> str: - return "Algolia API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile(r"""(?i)\b((LTAI)[a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/alibaba.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/alibaba.py deleted file mode 100644 index 5d071f1a9b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/alibaba.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for Alibaba secrets -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AlibabaSecretDetector(RegexBasedDetector): - """Scans for Alibaba AccessKey IDs and Secret Keys.""" - - @property - def secret_type(self) -> str: - return "Alibaba Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Alibaba AccessKey ID - re.compile(r"""(?i)\b((LTAI)[a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - # For Alibaba Secret Key - re.compile( - r"""(?i)(?:alibaba)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{30})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/asana.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/asana.py deleted file mode 100644 index fd96872c63..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/asana.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Asana secrets -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AsanaSecretDetector(RegexBasedDetector): - """Scans for Asana Client IDs and Client Secrets.""" - - @property - def secret_type(self) -> str: - return "Asana Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Asana Client ID - re.compile( - r"""(?i)(?:asana)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # For Asana Client Secret - re.compile( - r"""(?i)(?:asana)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/atlassian_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/atlassian_api_token.py deleted file mode 100644 index 42fd291ff4..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/atlassian_api_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Atlassian API tokens -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AtlassianApiTokenDetector(RegexBasedDetector): - """Scans for Atlassian API tokens.""" - - @property - def secret_type(self) -> str: - return "Atlassian API token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Atlassian API token - re.compile( - r"""(?i)(?:atlassian|confluence|jira)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/authress_access_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/authress_access_key.py deleted file mode 100644 index ff7466fc44..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/authress_access_key.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Authress Service Client Access Keys -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class AuthressAccessKeyDetector(RegexBasedDetector): - """Scans for Authress Service Client Access Keys.""" - - @property - def secret_type(self) -> str: - return "Authress Service Client Access Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Authress Service Client Access Key - re.compile( - r"""(?i)\b((?:sc|ext|scauth|authress)_[a-z0-9]{5,30}\.[a-z0-9]{4,6}\.acc[_-][a-z0-9-]{10,32}\.[a-z0-9+/_=-]{30,120})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/beamer_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/beamer_api_token.py deleted file mode 100644 index 5303e6262f..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/beamer_api_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Beamer API tokens -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class BeamerApiTokenDetector(RegexBasedDetector): - """Scans for Beamer API tokens.""" - - @property - def secret_type(self) -> str: - return "Beamer API token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Beamer API token - re.compile( - r"""(?i)(?:beamer)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(b_[a-z0-9=_\-]{44})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bitbucket.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bitbucket.py deleted file mode 100644 index aae28dcc7d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bitbucket.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Bitbucket Client ID and Client Secret -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class BitbucketDetector(RegexBasedDetector): - """Scans for Bitbucket Client ID and Client Secret.""" - - @property - def secret_type(self) -> str: - return "Bitbucket Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Bitbucket Client ID - re.compile( - r"""(?i)(?:bitbucket)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # For Bitbucket Client Secret - re.compile( - r"""(?i)(?:bitbucket)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bittrex.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bittrex.py deleted file mode 100644 index e8bd3347bb..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/bittrex.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Bittrex Access Key and Secret Key -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class BittrexDetector(RegexBasedDetector): - """Scans for Bittrex Access Key and Secret Key.""" - - @property - def secret_type(self) -> str: - return "Bittrex Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Bittrex Access Key - re.compile( - r"""(?i)(?:bittrex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # For Bittrex Secret Key - re.compile( - r"""(?i)(?:bittrex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/clojars_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/clojars_api_token.py deleted file mode 100644 index 6eb41ec4bb..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/clojars_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for Clojars API tokens -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ClojarsApiTokenDetector(RegexBasedDetector): - """Scans for Clojars API tokens.""" - - @property - def secret_type(self) -> str: - return "Clojars API token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Clojars API token - re.compile(r"(?i)(CLOJARS_)[a-z0-9]{60}"), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/codecov_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/codecov_access_token.py deleted file mode 100644 index 51001675f0..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/codecov_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Codecov Access Token -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class CodecovAccessTokenDetector(RegexBasedDetector): - """Scans for Codecov Access Token.""" - - @property - def secret_type(self) -> str: - return "Codecov Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Codecov Access Token - re.compile( - r"""(?i)(?:codecov)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/coinbase_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/coinbase_access_token.py deleted file mode 100644 index 0af631be99..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/coinbase_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Coinbase Access Token -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class CoinbaseAccessTokenDetector(RegexBasedDetector): - """Scans for Coinbase Access Token.""" - - @property - def secret_type(self) -> str: - return "Coinbase Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Coinbase Access Token - re.compile( - r"""(?i)(?:coinbase)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/confluent.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/confluent.py deleted file mode 100644 index aefbd42b94..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/confluent.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Confluent Access Token and Confluent Secret Key -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ConfluentDetector(RegexBasedDetector): - """Scans for Confluent Access Token and Confluent Secret Key.""" - - @property - def secret_type(self) -> str: - return "Confluent Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # For Confluent Access Token - re.compile( - r"""(?i)(?:confluent)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # For Confluent Secret Key - re.compile( - r"""(?i)(?:confluent)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/contentful_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/contentful_api_token.py deleted file mode 100644 index 33817dc4d8..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/contentful_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Contentful delivery API token. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ContentfulApiTokenDetector(RegexBasedDetector): - """Scans for Contentful delivery API token.""" - - @property - def secret_type(self) -> str: - return "Contentful API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:contentful)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{43})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/databricks_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/databricks_api_token.py deleted file mode 100644 index 9e47355b1c..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/databricks_api_token.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -This plugin searches for Databricks API token. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DatabricksApiTokenDetector(RegexBasedDetector): - """Scans for Databricks API token.""" - - @property - def secret_type(self) -> str: - return "Databricks API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile(r"""(?i)\b(dapi[a-h0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/datadog_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/datadog_access_token.py deleted file mode 100644 index bdb430d9bc..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/datadog_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Datadog Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DatadogAccessTokenDetector(RegexBasedDetector): - """Scans for Datadog Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Datadog Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:datadog)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/defined_networking_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/defined_networking_api_token.py deleted file mode 100644 index b23cdb4543..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/defined_networking_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Defined Networking API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DefinedNetworkingApiTokenDetector(RegexBasedDetector): - """Scans for Defined Networking API Tokens.""" - - @property - def secret_type(self) -> str: - return "Defined Networking API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:dnkey)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(dnkey-[a-z0-9=_\-]{26}-[a-z0-9=_\-]{52})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/digitalocean.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/digitalocean.py deleted file mode 100644 index 5ffc4f600e..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/digitalocean.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for DigitalOcean tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DigitaloceanDetector(RegexBasedDetector): - """Scans for various DigitalOcean Tokens.""" - - @property - def secret_type(self) -> str: - return "DigitalOcean Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # OAuth Access Token - re.compile(r"""(?i)\b(doo_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - # Personal Access Token - re.compile(r"""(?i)\b(dop_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - # OAuth Refresh Token - re.compile(r"""(?i)\b(dor_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/discord.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/discord.py deleted file mode 100644 index c51406b606..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/discord.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for Discord Client tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DiscordDetector(RegexBasedDetector): - """Scans for various Discord Client Tokens.""" - - @property - def secret_type(self) -> str: - return "Discord Client Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Discord API key - re.compile( - r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Discord client ID - re.compile( - r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{18})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Discord client secret - re.compile( - r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/doppler_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/doppler_api_token.py deleted file mode 100644 index 56c594fc1f..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/doppler_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for Doppler API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DopplerApiTokenDetector(RegexBasedDetector): - """Scans for Doppler API Tokens.""" - - @property - def secret_type(self) -> str: - return "Doppler API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Doppler API token - re.compile(r"""(?i)dp\.pt\.[a-z0-9]{43}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/droneci_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/droneci_access_token.py deleted file mode 100644 index 8afffb8026..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/droneci_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Droneci Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DroneciAccessTokenDetector(RegexBasedDetector): - """Scans for Droneci Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Droneci Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Droneci Access Token - re.compile( - r"""(?i)(?:droneci)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dropbox.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dropbox.py deleted file mode 100644 index b19815b26d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dropbox.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for Dropbox tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DropboxDetector(RegexBasedDetector): - """Scans for various Dropbox Tokens.""" - - @property - def secret_type(self) -> str: - return "Dropbox Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Dropbox API secret - re.compile( - r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{15})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Dropbox long-lived API token - re.compile( - r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{11}(AAAAAAAAAA)[a-z0-9\-_=]{43})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Dropbox short-lived API token - re.compile( - r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(sl\.[a-z0-9\-=_]{135})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/duffel_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/duffel_api_token.py deleted file mode 100644 index aab681598c..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/duffel_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for Duffel API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DuffelApiTokenDetector(RegexBasedDetector): - """Scans for Duffel API Tokens.""" - - @property - def secret_type(self) -> str: - return "Duffel API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Duffel API Token - re.compile(r"""(?i)duffel_(test|live)_[a-z0-9_\-=]{43}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dynatrace_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dynatrace_api_token.py deleted file mode 100644 index caf7dd7197..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/dynatrace_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for Dynatrace API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class DynatraceApiTokenDetector(RegexBasedDetector): - """Scans for Dynatrace API Tokens.""" - - @property - def secret_type(self) -> str: - return "Dynatrace API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Dynatrace API Token - re.compile(r"""(?i)dt0c01\.[a-z0-9]{24}\.[a-z0-9]{64}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/easypost.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/easypost.py deleted file mode 100644 index 73d27cb491..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/easypost.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for EasyPost tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class EasyPostDetector(RegexBasedDetector): - """Scans for various EasyPost Tokens.""" - - @property - def secret_type(self) -> str: - return "EasyPost Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # EasyPost API token - re.compile(r"""(?i)\bEZAK[a-z0-9]{54}"""), - # EasyPost test API token - re.compile(r"""(?i)\bEZTK[a-z0-9]{54}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/etsy_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/etsy_access_token.py deleted file mode 100644 index 1775a4b41d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/etsy_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Etsy Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class EtsyAccessTokenDetector(RegexBasedDetector): - """Scans for Etsy Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Etsy Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Etsy Access Token - re.compile( - r"""(?i)(?:etsy)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/facebook_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/facebook_access_token.py deleted file mode 100644 index edc7d080c6..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/facebook_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Facebook Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FacebookAccessTokenDetector(RegexBasedDetector): - """Scans for Facebook Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Facebook Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Facebook Access Token - re.compile( - r"""(?i)(?:facebook)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/fastly_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/fastly_api_token.py deleted file mode 100644 index 4d451cb746..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/fastly_api_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Fastly API keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FastlyApiKeyDetector(RegexBasedDetector): - """Scans for Fastly API keys.""" - - @property - def secret_type(self) -> str: - return "Fastly API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Fastly API key - re.compile( - r"""(?i)(?:fastly)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finicity.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finicity.py deleted file mode 100644 index 97414352fc..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finicity.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Finicity API tokens and Client Secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FinicityDetector(RegexBasedDetector): - """Scans for Finicity API tokens and Client Secrets.""" - - @property - def secret_type(self) -> str: - return "Finicity Credentials" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Finicity API token - re.compile( - r"""(?i)(?:finicity)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Finicity Client Secret - re.compile( - r"""(?i)(?:finicity)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finnhub_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finnhub_access_token.py deleted file mode 100644 index eeb09682b0..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/finnhub_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Finnhub Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FinnhubAccessTokenDetector(RegexBasedDetector): - """Scans for Finnhub Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Finnhub Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Finnhub Access Token - re.compile( - r"""(?i)(?:finnhub)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flickr_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flickr_access_token.py deleted file mode 100644 index 530628547b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flickr_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Flickr Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FlickrAccessTokenDetector(RegexBasedDetector): - """Scans for Flickr Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Flickr Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Flickr Access Token - re.compile( - r"""(?i)(?:flickr)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flutterwave.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flutterwave.py deleted file mode 100644 index fc46ba2222..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/flutterwave.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for Flutterwave API keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FlutterwaveDetector(RegexBasedDetector): - """Scans for Flutterwave API Keys.""" - - @property - def secret_type(self) -> str: - return "Flutterwave API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Flutterwave Encryption Key - re.compile(r"""(?i)FLWSECK_TEST-[a-h0-9]{12}"""), - # Flutterwave Public Key - re.compile(r"""(?i)FLWPUBK_TEST-[a-h0-9]{32}-X"""), - # Flutterwave Secret Key - re.compile(r"""(?i)FLWSECK_TEST-[a-h0-9]{32}-X"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/frameio_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/frameio_api_token.py deleted file mode 100644 index 9524e873d4..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/frameio_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for Frame.io API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FrameIoApiTokenDetector(RegexBasedDetector): - """Scans for Frame.io API Tokens.""" - - @property - def secret_type(self) -> str: - return "Frame.io API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Frame.io API token - re.compile(r"""(?i)fio-u-[a-z0-9\-_=]{64}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/freshbooks_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/freshbooks_access_token.py deleted file mode 100644 index b6b16e2b83..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/freshbooks_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Freshbooks Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class FreshbooksAccessTokenDetector(RegexBasedDetector): - """Scans for Freshbooks Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Freshbooks Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Freshbooks Access Token - re.compile( - r"""(?i)(?:freshbooks)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gcp_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gcp_api_key.py deleted file mode 100644 index 6055cc2622..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gcp_api_key.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for GCP API keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GCPApiKeyDetector(RegexBasedDetector): - """Scans for GCP API keys.""" - - @property - def secret_type(self) -> str: - return "GCP API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # GCP API Key - re.compile( - r"""(?i)\b(AIza[0-9A-Za-z\\-_]{35})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/github_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/github_token.py deleted file mode 100644 index acb5e3fc76..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/github_token.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for GitHub tokens -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GitHubTokenCustomDetector(RegexBasedDetector): - """Scans for GitHub tokens.""" - - @property - def secret_type(self) -> str: - return "GitHub Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # GitHub App/Personal Access/OAuth Access/Refresh Token - # ref. https://github.blog/2021-04-05-behind-githubs-new-authentication-token-formats/ - re.compile(r"(?:ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9_]{36}"), - # GitHub Fine-Grained Personal Access Token - re.compile(r"github_pat_[0-9a-zA-Z_]{82}"), - re.compile(r"gho_[0-9a-zA-Z]{36}"), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitlab.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitlab.py deleted file mode 100644 index 2277d8a2d3..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitlab.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for GitLab secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GitLabDetector(RegexBasedDetector): - """Scans for GitLab Secrets.""" - - @property - def secret_type(self) -> str: - return "GitLab Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # GitLab Personal Access Token - re.compile(r"""glpat-[0-9a-zA-Z\-\_]{20}"""), - # GitLab Pipeline Trigger Token - re.compile(r"""glptt-[0-9a-f]{40}"""), - # GitLab Runner Registration Token - re.compile(r"""GR1348941[0-9a-zA-Z\-\_]{20}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitter_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitter_access_token.py deleted file mode 100644 index 1febe70cb9..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gitter_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Gitter Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GitterAccessTokenDetector(RegexBasedDetector): - """Scans for Gitter Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Gitter Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Gitter Access Token - re.compile( - r"""(?i)(?:gitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gocardless_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gocardless_api_token.py deleted file mode 100644 index 240f6e4c58..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/gocardless_api_token.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -This plugin searches for GoCardless API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GoCardlessApiTokenDetector(RegexBasedDetector): - """Scans for GoCardless API Tokens.""" - - @property - def secret_type(self) -> str: - return "GoCardless API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # GoCardless API token - re.compile( - r"""(?:gocardless)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(live_[a-z0-9\-_=]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""", - re.IGNORECASE, - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/grafana.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/grafana.py deleted file mode 100644 index fd37f0f639..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/grafana.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for Grafana secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class GrafanaDetector(RegexBasedDetector): - """Scans for Grafana Secrets.""" - - @property - def secret_type(self) -> str: - return "Grafana Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Grafana API key or Grafana Cloud API key - re.compile( - r"""(?i)\b(eyJrIjoi[A-Za-z0-9]{70,400}={0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Grafana Cloud API token - re.compile( - r"""(?i)\b(glc_[A-Za-z0-9+/]{32,400}={0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Grafana Service Account token - re.compile( - r"""(?i)\b(glsa_[A-Za-z0-9]{32}_[A-Fa-f0-9]{8})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hashicorp_tf_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hashicorp_tf_api_token.py deleted file mode 100644 index 97013fd846..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hashicorp_tf_api_token.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -This plugin searches for HashiCorp Terraform user/org API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class HashiCorpTFApiTokenDetector(RegexBasedDetector): - """Scans for HashiCorp Terraform User/Org API Tokens.""" - - @property - def secret_type(self) -> str: - return "HashiCorp Terraform API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # HashiCorp Terraform user/org API token - re.compile(r"""(?i)[a-z0-9]{14}\.atlasv1\.[a-z0-9\-_=]{60,70}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/heroku_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/heroku_api_key.py deleted file mode 100644 index 53be8aa486..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/heroku_api_key.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Heroku API Keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class HerokuApiKeyDetector(RegexBasedDetector): - """Scans for Heroku API Keys.""" - - @property - def secret_type(self) -> str: - return "Heroku API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:heroku)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hubspot_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hubspot_api_key.py deleted file mode 100644 index 230ef659ba..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/hubspot_api_key.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for HubSpot API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class HubSpotApiTokenDetector(RegexBasedDetector): - """Scans for HubSpot API Tokens.""" - - @property - def secret_type(self) -> str: - return "HubSpot API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # HubSpot API Token - re.compile( - r"""(?i)(?:hubspot)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/huggingface.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/huggingface.py deleted file mode 100644 index be83a3a0d5..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/huggingface.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for Hugging Face Access and Organization API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class HuggingFaceDetector(RegexBasedDetector): - """Scans for Hugging Face Tokens.""" - - @property - def secret_type(self) -> str: - return "Hugging Face Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Hugging Face Access token - re.compile(r"""(?:^|[\\'"` >=:])(hf_[a-zA-Z]{34})(?:$|[\\'"` <])"""), - # Hugging Face Organization API token - re.compile( - r"""(?:^|[\\'"` >=:\(,)])(api_org_[a-zA-Z]{34})(?:$|[\\'"` <\),])""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/intercom_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/intercom_api_key.py deleted file mode 100644 index 24e16fc73a..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/intercom_api_key.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Intercom API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class IntercomApiTokenDetector(RegexBasedDetector): - """Scans for Intercom API Tokens.""" - - @property - def secret_type(self) -> str: - return "Intercom API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:intercom)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{60})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jfrog.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jfrog.py deleted file mode 100644 index 3eabbfe3a4..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jfrog.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for JFrog-related secrets like API Key and Identity Token. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class JFrogDetector(RegexBasedDetector): - """Scans for JFrog-related secrets.""" - - @property - def secret_type(self) -> str: - return "JFrog Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # JFrog API Key - re.compile( - r"""(?i)(?:jfrog|artifactory|bintray|xray)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{73})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # JFrog Identity Token - re.compile( - r"""(?i)(?:jfrog|artifactory|bintray|xray)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jwt.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jwt.py deleted file mode 100644 index 6658a09502..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/jwt.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Base64-encoded JSON Web Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class JWTBase64Detector(RegexBasedDetector): - """Scans for Base64-encoded JSON Web Tokens.""" - - @property - def secret_type(self) -> str: - return "Base64-encoded JSON Web Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Base64-encoded JSON Web Token - re.compile( - r"""\bZXlK(?:(?PaGJHY2lPaU)|(?PaGNIVWlPaU)|(?PaGNIWWlPaU)|(?PaGRXUWlPaU)|(?PaU5qUWlP)|(?PamNtbDBJanBi)|(?PamRIa2lPaU)|(?PbGNHc2lPbn)|(?PbGJtTWlPaU)|(?PcWEzVWlPaU)|(?PcWQyc2lPb)|(?PcGMzTWlPaU)|(?PcGRpSTZJ)|(?PcmFXUWlP)|(?PclpYbGZiM0J6SWpwY)|(?PcmRIa2lPaUp)|(?PdWIyNWpaU0k2)|(?Pd01tTWlP)|(?Pd01uTWlPaU)|(?Pd2NIUWlPaU)|(?PemRXSWlPaU)|(?PemRuUWlP)|(?PMFlXY2lPaU)|(?PMGVYQWlPaUp)|(?PMWNtd2l)|(?PMWMyVWlPaUp)|(?PMlpYSWlPaU)|(?PMlpYSnphVzl1SWpv)|(?PNElqb2)|(?PNE5XTWlP)|(?PNE5YUWlPaU)|(?PNE5YUWpVekkxTmlJNkl)|(?PNE5YVWlPaU)|(?PNmFYQWlPaU))[a-zA-Z0-9\/\\_+\-\r\n]{40,}={0,2}""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kraken_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kraken_access_token.py deleted file mode 100644 index cb7357cfd9..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kraken_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Kraken Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class KrakenAccessTokenDetector(RegexBasedDetector): - """Scans for Kraken Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Kraken Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Kraken Access Token - re.compile( - r"""(?i)(?:kraken)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9\/=_\+\-]{80,90})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kucoin.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kucoin.py deleted file mode 100644 index 02e990bd8b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/kucoin.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Kucoin Access Tokens and Secret Keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class KucoinDetector(RegexBasedDetector): - """Scans for Kucoin Access Tokens and Secret Keys.""" - - @property - def secret_type(self) -> str: - return "Kucoin Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Kucoin Access Token - re.compile( - r"""(?i)(?:kucoin)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Kucoin Secret Key - re.compile( - r"""(?i)(?:kucoin)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/launchdarkly_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/launchdarkly_access_token.py deleted file mode 100644 index 9779909847..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/launchdarkly_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Launchdarkly Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class LaunchdarklyAccessTokenDetector(RegexBasedDetector): - """Scans for Launchdarkly Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Launchdarkly Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:launchdarkly)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linear.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linear.py deleted file mode 100644 index 1224b5ec46..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linear.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -This plugin searches for Linear API Tokens and Linear Client Secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class LinearDetector(RegexBasedDetector): - """Scans for Linear secrets.""" - - @property - def secret_type(self) -> str: - return "Linear Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Linear API Token - re.compile(r"""(?i)lin_api_[a-z0-9]{40}"""), - # Linear Client Secret - re.compile( - r"""(?i)(?:linear)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linkedin.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linkedin.py deleted file mode 100644 index 53ff0c30aa..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/linkedin.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for LinkedIn Client IDs and LinkedIn Client secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class LinkedInDetector(RegexBasedDetector): - """Scans for LinkedIn secrets.""" - - @property - def secret_type(self) -> str: - return "LinkedIn Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # LinkedIn Client ID - re.compile( - r"""(?i)(?:linkedin|linked-in)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{14})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # LinkedIn Client secret - re.compile( - r"""(?i)(?:linkedin|linked-in)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/lob.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/lob.py deleted file mode 100644 index 623ac4f1f9..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/lob.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Lob API secrets and Lob Publishable API keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class LobDetector(RegexBasedDetector): - """Scans for Lob secrets.""" - - @property - def secret_type(self) -> str: - return "Lob Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Lob API Key - re.compile( - r"""(?i)(?:lob)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}((live|test)_[a-f0-9]{35})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Lob Publishable API Key - re.compile( - r"""(?i)(?:lob)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}((test|live)_pub_[a-f0-9]{31})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mailgun.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mailgun.py deleted file mode 100644 index c403d24546..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mailgun.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for Mailgun API secrets, public validation keys, and webhook signing keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class MailgunDetector(RegexBasedDetector): - """Scans for Mailgun secrets.""" - - @property - def secret_type(self) -> str: - return "Mailgun Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Mailgun Private API Token - re.compile( - r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(key-[a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Mailgun Public Validation Key - re.compile( - r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(pubkey-[a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Mailgun Webhook Signing Key - re.compile( - r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-h0-9]{32}-[a-h0-9]{8}-[a-h0-9]{8})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mapbox_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mapbox_api_token.py deleted file mode 100644 index 0326b7102a..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mapbox_api_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for MapBox API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class MapBoxApiTokenDetector(RegexBasedDetector): - """Scans for MapBox API tokens.""" - - @property - def secret_type(self) -> str: - return "MapBox API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # MapBox API Token - re.compile( - r"""(?i)(?:mapbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(pk\.[a-z0-9]{60}\.[a-z0-9]{22})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mattermost_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mattermost_access_token.py deleted file mode 100644 index d65b0e7554..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/mattermost_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Mattermost Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class MattermostAccessTokenDetector(RegexBasedDetector): - """Scans for Mattermost Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Mattermost Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Mattermost Access Token - re.compile( - r"""(?i)(?:mattermost)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{26})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/messagebird.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/messagebird.py deleted file mode 100644 index 6adc8317a8..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/messagebird.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for MessageBird API tokens and client IDs. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class MessageBirdDetector(RegexBasedDetector): - """Scans for MessageBird secrets.""" - - @property - def secret_type(self) -> str: - return "MessageBird Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # MessageBird API Token - re.compile( - r"""(?i)(?:messagebird|message-bird|message_bird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{25})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # MessageBird Client ID - re.compile( - r"""(?i)(?:messagebird|message-bird|message_bird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/microsoft_teams_webhook.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/microsoft_teams_webhook.py deleted file mode 100644 index 298fd81b0a..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/microsoft_teams_webhook.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Microsoft Teams Webhook URLs. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class MicrosoftTeamsWebhookDetector(RegexBasedDetector): - """Scans for Microsoft Teams Webhook URLs.""" - - @property - def secret_type(self) -> str: - return "Microsoft Teams Webhook" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Microsoft Teams Webhook - re.compile( - r"""https:\/\/[a-z0-9]+\.webhook\.office\.com\/webhookb2\/[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}@[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}\/IncomingWebhook\/[a-z0-9]{32}\/[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/netlify_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/netlify_access_token.py deleted file mode 100644 index cc7a575a42..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/netlify_access_token.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -This plugin searches for Netlify Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class NetlifyAccessTokenDetector(RegexBasedDetector): - """Scans for Netlify Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Netlify Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Netlify Access Token - re.compile( - r"""(?i)(?:netlify)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{40,46})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/new_relic.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/new_relic.py deleted file mode 100644 index cef640155c..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/new_relic.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for New Relic API tokens and keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class NewRelicDetector(RegexBasedDetector): - """Scans for New Relic API tokens and keys.""" - - @property - def secret_type(self) -> str: - return "New Relic API Secrets" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # New Relic ingest browser API token - re.compile( - r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(NRJS-[a-f0-9]{19})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # New Relic user API ID - re.compile( - r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # New Relic user API Key - re.compile( - r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(NRAK-[a-z0-9]{27})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/nytimes_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/nytimes_access_token.py deleted file mode 100644 index 567b885e5a..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/nytimes_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for New York Times Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class NYTimesAccessTokenDetector(RegexBasedDetector): - """Scans for New York Times Access Tokens.""" - - @property - def secret_type(self) -> str: - return "New York Times Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:nytimes|new-york-times,|newyorktimes)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/okta_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/okta_access_token.py deleted file mode 100644 index 97109767b0..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/okta_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Okta Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class OktaAccessTokenDetector(RegexBasedDetector): - """Scans for Okta Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Okta Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:okta)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{42})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/openai_api_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/openai_api_key.py deleted file mode 100644 index c5d20f7590..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/openai_api_key.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This plugin searches for OpenAI API Keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class OpenAIApiKeyDetector(RegexBasedDetector): - """Scans for OpenAI API Keys.""" - - @property - def secret_type(self) -> str: - return "Strict OpenAI API Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [re.compile(r"""(sk-[a-zA-Z0-9]{5,})""")] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/planetscale.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/planetscale.py deleted file mode 100644 index 23a53667e3..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/planetscale.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This plugin searches for PlanetScale API tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class PlanetScaleDetector(RegexBasedDetector): - """Scans for PlanetScale API Tokens.""" - - @property - def secret_type(self) -> str: - return "PlanetScale API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # the PlanetScale API token - re.compile( - r"""(?i)\b(pscale_tkn_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # the PlanetScale OAuth token - re.compile( - r"""(?i)\b(pscale_oauth_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # the PlanetScale password - re.compile( - r"""(?i)\b(pscale_pw_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/postman_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/postman_api_token.py deleted file mode 100644 index 9469e8191c..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/postman_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Postman API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class PostmanApiTokenDetector(RegexBasedDetector): - """Scans for Postman API Tokens.""" - - @property - def secret_type(self) -> str: - return "Postman API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)\b(PMAK-[a-f0-9]{24}-[a-f0-9]{34})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/prefect_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/prefect_api_token.py deleted file mode 100644 index 35cdb71cae..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/prefect_api_token.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This plugin searches for Prefect API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class PrefectApiTokenDetector(RegexBasedDetector): - """Scans for Prefect API Tokens.""" - - @property - def secret_type(self) -> str: - return "Prefect API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [re.compile(r"""(?i)\b(pnu_[a-z0-9]{36})(?:['|\"|\n|\r|\s|\x60|;]|$)""")] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pulumi_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pulumi_api_token.py deleted file mode 100644 index bae4ce211b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pulumi_api_token.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This plugin searches for Pulumi API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class PulumiApiTokenDetector(RegexBasedDetector): - """Scans for Pulumi API Tokens.""" - - @property - def secret_type(self) -> str: - return "Pulumi API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [re.compile(r"""(?i)\b(pul-[a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""")] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pypi_upload_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pypi_upload_token.py deleted file mode 100644 index d4cc913857..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/pypi_upload_token.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This plugin searches for PyPI Upload Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class PyPiUploadTokenDetector(RegexBasedDetector): - """Scans for PyPI Upload Tokens.""" - - @property - def secret_type(self) -> str: - return "PyPI Upload Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [re.compile(r"""pypi-AgEIcHlwaS5vcmc[A-Za-z0-9\-_]{50,1000}""")] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rapidapi_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rapidapi_access_token.py deleted file mode 100644 index 18b2346148..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rapidapi_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for RapidAPI Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class RapidApiAccessTokenDetector(RegexBasedDetector): - """Scans for RapidAPI Access Tokens.""" - - @property - def secret_type(self) -> str: - return "RapidAPI Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:rapidapi)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{50})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/readme_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/readme_api_token.py deleted file mode 100644 index 47bdffb120..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/readme_api_token.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -This plugin searches for Readme API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ReadmeApiTokenDetector(RegexBasedDetector): - """Scans for Readme API Tokens.""" - - @property - def secret_type(self) -> str: - return "Readme API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile(r"""(?i)\b(rdme_[a-z0-9]{70})(?:['|\"|\n|\r|\s|\x60|;]|$)""") - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rubygems_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rubygems_api_token.py deleted file mode 100644 index d49c58e73e..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/rubygems_api_token.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -This plugin searches for Rubygem API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class RubygemsApiTokenDetector(RegexBasedDetector): - """Scans for Rubygem API Tokens.""" - - @property - def secret_type(self) -> str: - return "Rubygem API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile(r"""(?i)\b(rubygems_[a-f0-9]{48})(?:['|\"|\n|\r|\s|\x60|;]|$)""") - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/scalingo_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/scalingo_api_token.py deleted file mode 100644 index 3f8a59ee41..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/scalingo_api_token.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This plugin searches for Scalingo API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ScalingoApiTokenDetector(RegexBasedDetector): - """Scans for Scalingo API Tokens.""" - - @property - def secret_type(self) -> str: - return "Scalingo API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [re.compile(r"""\btk-us-[a-zA-Z0-9-_]{48}\b""")] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendbird.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendbird.py deleted file mode 100644 index 4b270d71e5..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendbird.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This plugin searches for Sendbird Access IDs and Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SendbirdDetector(RegexBasedDetector): - """Scans for Sendbird Access IDs and Tokens.""" - - @property - def secret_type(self) -> str: - return "Sendbird Credential" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Sendbird Access ID - re.compile( - r"""(?i)(?:sendbird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Sendbird Access Token - re.compile( - r"""(?i)(?:sendbird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendgrid_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendgrid_api_token.py deleted file mode 100644 index bf974f4fd7..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendgrid_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for SendGrid API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SendGridApiTokenDetector(RegexBasedDetector): - """Scans for SendGrid API Tokens.""" - - @property - def secret_type(self) -> str: - return "SendGrid API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)\b(SG\.[a-z0-9=_\-\.]{66})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendinblue_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendinblue_api_token.py deleted file mode 100644 index a6ed8c15ee..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sendinblue_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for SendinBlue API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SendinBlueApiTokenDetector(RegexBasedDetector): - """Scans for SendinBlue API Tokens.""" - - @property - def secret_type(self) -> str: - return "SendinBlue API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)\b(xkeysib-[a-f0-9]{64}-[a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sentry_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sentry_access_token.py deleted file mode 100644 index 181fad2c7f..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sentry_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Sentry Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SentryAccessTokenDetector(RegexBasedDetector): - """Scans for Sentry Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Sentry Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:sentry)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shippo_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shippo_api_token.py deleted file mode 100644 index 4314c68768..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shippo_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Shippo API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ShippoApiTokenDetector(RegexBasedDetector): - """Scans for Shippo API Tokens.""" - - @property - def secret_type(self) -> str: - return "Shippo API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)\b(shippo_(live|test)_[a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shopify.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shopify.py deleted file mode 100644 index f5f97c4478..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/shopify.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -This plugin searches for Shopify Access Tokens, Custom Access Tokens, -Private App Access Tokens, and Shared Secrets. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ShopifyDetector(RegexBasedDetector): - """Scans for Shopify Access Tokens, Custom Access Tokens, Private App Access Tokens, - and Shared Secrets. - """ - - @property - def secret_type(self) -> str: - return "Shopify Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Shopify access token - re.compile(r"""shpat_[a-fA-F0-9]{32}"""), - # Shopify custom access token - re.compile(r"""shpca_[a-fA-F0-9]{32}"""), - # Shopify private app access token - re.compile(r"""shppa_[a-fA-F0-9]{32}"""), - # Shopify shared secret - re.compile(r"""shpss_[a-fA-F0-9]{32}"""), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/slack.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/slack.py deleted file mode 100644 index 4896fd76b2..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/slack.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -This plugin searches for Slack tokens and webhooks. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SlackDetector(RegexBasedDetector): - """Scans for Slack tokens and webhooks.""" - - @property - def secret_type(self) -> str: - return "Slack Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Slack App-level token - re.compile(r"""(?i)(xapp-\d-[A-Z0-9]+-\d+-[a-z0-9]+)"""), - # Slack Bot token - re.compile(r"""(xoxb-[0-9]{10,13}\-[0-9]{10,13}[a-zA-Z0-9-]*)"""), - # Slack Configuration access token and refresh token - re.compile(r"""(?i)(xoxe.xox[bp]-\d-[A-Z0-9]{163,166})"""), - re.compile(r"""(?i)(xoxe-\d-[A-Z0-9]{146})"""), - # Slack Legacy bot token and token - re.compile(r"""(xoxb-[0-9]{8,14}\-[a-zA-Z0-9]{18,26})"""), - re.compile(r"""(xox[os]-\d+-\d+-\d+-[a-fA-F\d]+)"""), - # Slack Legacy Workspace token - re.compile(r"""(xox[ar]-(?:\d-)?[0-9a-zA-Z]{8,48})"""), - # Slack User token and enterprise token - re.compile(r"""(xox[pe](?:-[0-9]{10,13}){3}-[a-zA-Z0-9-]{28,34})"""), - # Slack Webhook URL - re.compile( - r"""(https?:\/\/)?hooks.slack.com\/(services|workflows)\/[A-Za-z0-9+\/]{43,46}""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/snyk_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/snyk_api_token.py deleted file mode 100644 index 839bb57317..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/snyk_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Snyk API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SnykApiTokenDetector(RegexBasedDetector): - """Scans for Snyk API Tokens.""" - - @property - def secret_type(self) -> str: - return "Snyk API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:snyk)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/squarespace_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/squarespace_access_token.py deleted file mode 100644 index 0dc83ad91d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/squarespace_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Squarespace Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SquarespaceAccessTokenDetector(RegexBasedDetector): - """Scans for Squarespace Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Squarespace Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:squarespace)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sumologic.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sumologic.py deleted file mode 100644 index 7117629acc..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/sumologic.py +++ /dev/null @@ -1,22 +0,0 @@ -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class SumoLogicDetector(RegexBasedDetector): - """Scans for SumoLogic Access ID and Access Token.""" - - @property - def secret_type(self) -> str: - return "SumoLogic" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i:(?:sumo)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3})(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(su[a-zA-Z0-9]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - re.compile( - r"""(?i)(?:sumo)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/telegram_bot_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/telegram_bot_api_token.py deleted file mode 100644 index 30854fda1d..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/telegram_bot_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Telegram Bot API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class TelegramBotApiTokenDetector(RegexBasedDetector): - """Scans for Telegram Bot API Tokens.""" - - @property - def secret_type(self) -> str: - return "Telegram Bot API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:^|[^0-9])([0-9]{5,16}:A[a-zA-Z0-9_\-]{34})(?:$|[^a-zA-Z0-9_\-])""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/travisci_access_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/travisci_access_token.py deleted file mode 100644 index 90f9b48f46..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/travisci_access_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Travis CI Access Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class TravisCiAccessTokenDetector(RegexBasedDetector): - """Scans for Travis CI Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Travis CI Access Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:travis)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{22})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitch_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitch_api_token.py deleted file mode 100644 index 1e0e3ccf8f..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitch_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Twitch API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class TwitchApiTokenDetector(RegexBasedDetector): - """Scans for Twitch API Tokens.""" - - @property - def secret_type(self) -> str: - return "Twitch API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:twitch)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{30})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitter.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitter.py deleted file mode 100644 index 99ad170d1e..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/twitter.py +++ /dev/null @@ -1,36 +0,0 @@ -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class TwitterDetector(RegexBasedDetector): - """Scans for Twitter Access Secrets, Access Tokens, API Keys, API Secrets, and Bearer Tokens.""" - - @property - def secret_type(self) -> str: - return "Twitter Secret" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Twitter Access Secret - re.compile( - r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{45})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Twitter Access Token - re.compile( - r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{15,25}-[a-zA-Z0-9]{20,40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Twitter API Key - re.compile( - r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{25})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Twitter API Secret - re.compile( - r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{50})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Twitter Bearer Token - re.compile( - r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(A{22}[a-zA-Z0-9%]{80,100})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/typeform_api_token.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/typeform_api_token.py deleted file mode 100644 index 8d9dc0e875..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/typeform_api_token.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Typeform API Tokens. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class TypeformApiTokenDetector(RegexBasedDetector): - """Scans for Typeform API Tokens.""" - - @property - def secret_type(self) -> str: - return "Typeform API Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:typeform)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(tfp_[a-z0-9\-_\.=]{59})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/vault.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/vault.py deleted file mode 100644 index 5ca552cd9e..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/vault.py +++ /dev/null @@ -1,24 +0,0 @@ -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class VaultDetector(RegexBasedDetector): - """Scans for Vault Batch Tokens and Vault Service Tokens.""" - - @property - def secret_type(self) -> str: - return "Vault Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Vault Batch Token - re.compile( - r"""(?i)\b(hvb\.[a-z0-9_-]{138,212})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Vault Service Token - re.compile( - r"""(?i)\b(hvs\.[a-z0-9_-]{90,100})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/yandex.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/yandex.py deleted file mode 100644 index a58faec0d1..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/yandex.py +++ /dev/null @@ -1,28 +0,0 @@ -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class YandexDetector(RegexBasedDetector): - """Scans for Yandex Access Tokens, API Keys, and AWS Access Tokens.""" - - @property - def secret_type(self) -> str: - return "Yandex Token" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - # Yandex Access Token - re.compile( - r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(t1\.[A-Z0-9a-z_-]+[=]{0,2}\.[A-Z0-9a-z_-]{86}[=]{0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Yandex API Key - re.compile( - r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(AQVN[A-Za-z0-9_\-]{35,38})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - # Yandex AWS Access Token - re.compile( - r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(YC[a-zA-Z0-9_\-]{38})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ), - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/zendesk_secret_key.py b/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/zendesk_secret_key.py deleted file mode 100644 index 42c087c5b6..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/secrets_plugins/zendesk_secret_key.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -This plugin searches for Zendesk Secret Keys. -""" - -import re - -from detect_secrets.plugins.base import RegexBasedDetector - - -class ZendeskSecretKeyDetector(RegexBasedDetector): - """Scans for Zendesk Secret Keys.""" - - @property - def secret_type(self) -> str: - return "Zendesk Secret Key" - - @property - def denylist(self) -> list[re.Pattern]: - return [ - re.compile( - r"""(?i)(?:zendesk)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" - ) - ] diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py deleted file mode 100644 index 89c3b85468..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py +++ /dev/null @@ -1,926 +0,0 @@ -""" -Base class for sending emails to user after creating keys or invite links - -""" - -import html -import json -import os -from typing import List, Literal, Optional - -from litellm_enterprise.types.enterprise_callbacks.send_emails import ( - EmailEvent, - EmailParams, - SendKeyCreatedEmailEvent, - SendKeyRotatedEmailEvent, -) - -from litellm._logging import verbose_proxy_logger -from litellm.caching.caching import DualCache -from litellm.constants import ( - EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE, - EMAIL_BUDGET_ALERT_TTL, -) -from litellm.integrations.custom_logger import CustomLogger -from litellm.integrations.email_templates.email_footer import EMAIL_FOOTER -from litellm.integrations.email_templates.key_created_email import ( - KEY_CREATED_EMAIL_TEMPLATE, -) -from litellm.integrations.email_templates.key_rotated_email import ( - KEY_ROTATED_EMAIL_TEMPLATE, -) -from litellm.integrations.email_templates.templates import ( - MAX_BUDGET_ALERT_EMAIL_TEMPLATE, - SOFT_BUDGET_ALERT_EMAIL_TEMPLATE, - TEAM_SOFT_BUDGET_ALERT_EMAIL_TEMPLATE, -) -from litellm.integrations.email_templates.user_invitation_email import ( - USER_INVITATION_EMAIL_TEMPLATE, -) -from litellm.proxy._types import ( - CallInfo, - InvitationNew, - Litellm_EntityType, - UserAPIKeyAuth, - WebhookEvent, -) -from litellm.secret_managers.main import get_secret_bool -from litellm.types.integrations.slack_alerting import LITELLM_LOGO_URL - - -def _parse_email_list(raw) -> List[str]: - """Parse emails from a list or comma-separated string.""" - if isinstance(raw, list): - return [e.strip() for e in raw if isinstance(e, str) and e.strip()] - elif isinstance(raw, str): - return [e.strip() for e in raw.split(",") if e.strip()] - return [] - - -class BaseEmailLogger(CustomLogger): - DEFAULT_LITELLM_EMAIL = "notifications@alerts.litellm.ai" - DEFAULT_SUPPORT_EMAIL = "support@berri.ai" - DEFAULT_SUBJECT_TEMPLATES = { - EmailEvent.new_user_invitation: "LiteLLM: {event_message}", - EmailEvent.virtual_key_created: "LiteLLM: {event_message}", - EmailEvent.virtual_key_rotated: "LiteLLM: {event_message}", - } - - def __init__( - self, - internal_usage_cache: Optional[DualCache] = None, - **kwargs, - ): - """ - Initialize BaseEmailLogger - - Args: - internal_usage_cache: DualCache instance for preventing duplicate alerts - **kwargs: Additional arguments passed to CustomLogger - """ - super().__init__(**kwargs) - self.internal_usage_cache = internal_usage_cache or DualCache() - - async def send_user_invitation_email(self, event: WebhookEvent): - """ - Send email to user after inviting them to the team - """ - email_params = await self._get_email_params( - email_event=EmailEvent.new_user_invitation, - user_id=event.user_id, - user_email=getattr(event, "user_email", None), - event_message=event.event_message, - ) - - verbose_proxy_logger.debug( - f"send_user_invitation_email_event: {json.dumps(event, indent=4, default=str)}" - ) - - email_html_content = USER_INVITATION_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=email_params.recipient_email, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - email_footer=email_params.signature, - ) - - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=[email_params.recipient_email], - subject=email_params.subject, - html_body=email_html_content, - ) - - pass - - async def send_key_created_email( - self, send_key_created_email_event: SendKeyCreatedEmailEvent - ): - """ - Send email to user after creating key for the user - """ - email_params = await self._get_email_params( - user_id=send_key_created_email_event.user_id, - user_email=send_key_created_email_event.user_email, - email_event=EmailEvent.virtual_key_created, - event_message=send_key_created_email_event.event_message, - ) - - verbose_proxy_logger.debug( - f"send_key_created_email_event: {json.dumps(send_key_created_email_event, indent=4, default=str)}" - ) - - # Check if API key should be included in email - include_api_key = get_secret_bool( - secret_name="EMAIL_INCLUDE_API_KEY", default_value=True - ) - if include_api_key is None: - include_api_key = True # Default to True if not set - key_token_display = ( - send_key_created_email_event.virtual_key - if include_api_key - else "[Key hidden for security - retrieve from dashboard]" - ) - - email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=email_params.recipient_email, - key_budget=self._format_key_budget(send_key_created_email_event.max_budget), - key_token=key_token_display, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - email_footer=email_params.signature, - ) - - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=[email_params.recipient_email], - subject=email_params.subject, - html_body=email_html_content, - ) - pass - - async def send_key_rotated_email( - self, send_key_rotated_email_event: SendKeyRotatedEmailEvent - ): - """ - Send email to user after rotating key for the user - """ - email_params = await self._get_email_params( - user_id=send_key_rotated_email_event.user_id, - user_email=send_key_rotated_email_event.user_email, - email_event=EmailEvent.virtual_key_rotated, - event_message=send_key_rotated_email_event.event_message, - ) - - verbose_proxy_logger.debug( - f"send_key_rotated_email_event: {json.dumps(send_key_rotated_email_event, indent=4, default=str)}" - ) - - # Check if API key should be included in email - include_api_key = get_secret_bool( - secret_name="EMAIL_INCLUDE_API_KEY", default_value=True - ) - if include_api_key is None: - include_api_key = True # Default to True if not set - key_token_display = ( - send_key_rotated_email_event.virtual_key - if include_api_key - else "[Key hidden for security - retrieve from dashboard]" - ) - - email_html_content = KEY_ROTATED_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=email_params.recipient_email, - key_budget=self._format_key_budget(send_key_rotated_email_event.max_budget), - key_token=key_token_display, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - email_footer=email_params.signature, - ) - - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=[email_params.recipient_email], - subject=email_params.subject, - html_body=email_html_content, - ) - pass - - async def send_soft_budget_alert_email(self, event: WebhookEvent): - """ - Send email to user when soft budget is crossed - """ - email_params = await self._get_email_params( - email_event=EmailEvent.soft_budget_crossed, # Reuse existing event type for subject template - user_id=event.user_id, - user_email=event.user_email, - event_message=event.event_message, - ) - - verbose_proxy_logger.debug( - f"send_soft_budget_alert_email_event: {json.dumps(event.model_dump(exclude_none=True), indent=4, default=str)}" - ) - - # Format budget values - soft_budget_str = ( - f"${event.soft_budget}" if event.soft_budget is not None else "N/A" - ) - spend_str = f"${event.spend}" if event.spend is not None else "$0.00" - max_budget_info = "" - if event.max_budget is not None: - max_budget_info = f"Maximum Budget: ${event.max_budget}
" - - email_html_content = SOFT_BUDGET_ALERT_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=email_params.recipient_email, - soft_budget=soft_budget_str, - spend=spend_str, - max_budget_info=max_budget_info, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - ) - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=[email_params.recipient_email], - subject=email_params.subject, - html_body=email_html_content, - ) - pass - - async def send_team_soft_budget_alert_email(self, event: WebhookEvent): - """ - Send email to team members when team soft budget is crossed - Supports multiple recipients via alert_emails field from team metadata - """ - # Collect all recipient emails - recipient_emails: List[str] = [] - - # Add additional alert emails from team metadata.soft_budget_alert_emails - if hasattr(event, "alert_emails") and event.alert_emails: - for email in event.alert_emails: - if email and email not in recipient_emails: # Avoid duplicates - recipient_emails.append(email) - - # If no recipients found, skip sending - if not recipient_emails: - verbose_proxy_logger.warning( - f"No recipient emails found for team soft budget alert. event={event.model_dump(exclude_none=True)}" - ) - return - - # Validate that we have at least one valid email address - first_recipient_email = recipient_emails[0] - if not first_recipient_email or not first_recipient_email.strip(): - verbose_proxy_logger.warning( - f"Invalid recipient email found for team soft budget alert. event={event.model_dump(exclude_none=True)}" - ) - return - - verbose_proxy_logger.debug( - f"send_team_soft_budget_alert_email_event: {json.dumps(event.model_dump(exclude_none=True), indent=4, default=str)}" - ) - - # Get email params using the first recipient email (for template formatting) - # For team alerts with alert_emails, we don't need user_id lookup since we already have email addresses - # Pass user_id=None to prevent _get_email_params from trying to look up email from a potentially None user_id - email_params = await self._get_email_params( - email_event=EmailEvent.soft_budget_crossed, - user_id=None, # Team alerts don't require user_id when alert_emails are provided - user_email=first_recipient_email, - event_message=event.event_message, - ) - - # Format budget values - soft_budget_str = ( - f"${event.soft_budget}" if event.soft_budget is not None else "N/A" - ) - spend_str = f"${event.spend}" if event.spend is not None else "$0.00" - max_budget_info = "" - if event.max_budget is not None: - max_budget_info = f"Maximum Budget: ${event.max_budget}
" - - # Use team alias or generic greeting - team_alias = event.team_alias or "Team" - - email_html_content = TEAM_SOFT_BUDGET_ALERT_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - team_alias=team_alias, - soft_budget=soft_budget_str, - spend=spend_str, - max_budget_info=max_budget_info, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - ) - - # Send email to all recipients - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=recipient_emails, - subject=email_params.subject, - html_body=email_html_content, - ) - pass - - async def send_max_budget_alert_email( - self, - event: WebhookEvent, - threshold_pct: Optional[int] = None, - recipient_emails: Optional[List[str]] = None, - ): - """ - Send email to user when max budget alert threshold is reached. - - Args: - event: The webhook event with spend/budget info - threshold_pct: Override percentage for multi-threshold alerts (e.g. 50, 75, 100). - When None, uses EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE (old behavior). - recipient_emails: Override recipient list for multi-threshold alerts. - When None, resolves single owner email via _get_email_params (old behavior). - """ - verbose_proxy_logger.debug( - f"send_max_budget_alert_email_event: {json.dumps(event.model_dump(exclude_none=True), indent=4, default=str)}" - ) - - # Format budget values - spend_str = f"${event.spend}" if event.spend is not None else "$0.00" - max_budget_str = ( - f"${event.max_budget}" if event.max_budget is not None else "N/A" - ) - - # Calculate percentage and alert threshold - percentage = threshold_pct if threshold_pct is not None else int( - EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100 - ) - threshold_fraction = percentage / 100.0 - alert_threshold_str = ( - f"${event.max_budget * threshold_fraction:.2f}" - if event.max_budget is not None - else "N/A" - ) - - if recipient_emails: - # Multi-threshold path: batch send with generic key-based greeting - email_params = await self._get_email_params( - email_event=EmailEvent.max_budget_alert, - user_id=event.user_id, - user_email=event.user_email or recipient_emails[0], - event_message=event.event_message, - ) - greeting = html.escape( - event.user_email or event.key_alias or event.token or "" - ) - email_html_content = MAX_BUDGET_ALERT_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=greeting, - percentage=percentage, - spend=spend_str, - max_budget=max_budget_str, - alert_threshold=alert_threshold_str, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - ) - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=recipient_emails, - subject=email_params.subject, - html_body=email_html_content, - ) - else: - # Old path: single recipient resolved from user_id/user_email - email_params = await self._get_email_params( - email_event=EmailEvent.max_budget_alert, - user_id=event.user_id, - user_email=event.user_email, - event_message=event.event_message, - ) - email_html_content = MAX_BUDGET_ALERT_EMAIL_TEMPLATE.format( - email_logo_url=email_params.logo_url, - recipient_email=email_params.recipient_email, - percentage=percentage, - spend=spend_str, - max_budget=max_budget_str, - alert_threshold=alert_threshold_str, - base_url=email_params.base_url, - email_support_contact=email_params.support_contact, - ) - await self.send_email( - from_email=self.DEFAULT_LITELLM_EMAIL, - to_email=[email_params.recipient_email], - subject=email_params.subject, - html_body=email_html_content, - ) - - async def budget_alerts( - self, - type: Literal[ - "token_budget", - "soft_budget", - "max_budget_alert", - "user_budget", - "team_budget", - "organization_budget", - "proxy_budget", - "projected_limit_exceeded", - ], - user_info: CallInfo, - ): - """ - Send a budget alert via email - - Args: - type: The type of budget alert to send - user_info: The user info to send the alert for - """ - ## PREVENTITIVE ALERTING ## - # - Alert once within 24hr period - # - Cache this information - # - Don't re-alert, if alert already sent - _cache: DualCache = self.internal_usage_cache - - # For soft_budget alerts, check if we've already sent an alert - if type == "soft_budget": - # For team soft budget alerts, we only need team soft_budget to be set - # For other entity types, we need either max_budget or soft_budget - if user_info.event_group == Litellm_EntityType.TEAM: - if user_info.soft_budget is None: - return - # For team soft budget alerts, require alert_emails to be configured - # Team soft budget alerts are sent via metadata.soft_budget_alerting_emails - if user_info.alert_emails is None or len(user_info.alert_emails) == 0: - verbose_proxy_logger.debug( - "Skipping team soft budget email alert: no alert_emails configured", - ) - return - else: - # For non-team alerts, require either max_budget or soft_budget - if user_info.max_budget is None and user_info.soft_budget is None: - return - if ( - user_info.soft_budget is not None - and user_info.spend >= user_info.soft_budget - ): - # Generate cache key based on event type and identifier - # Use appropriate ID based on event_group to ensure unique cache keys per entity type - if user_info.event_group == Litellm_EntityType.TEAM: - _id = user_info.team_id or "default_id" - elif user_info.event_group == Litellm_EntityType.ORGANIZATION: - _id = user_info.organization_id or "default_id" - elif user_info.event_group == Litellm_EntityType.USER: - _id = user_info.user_id or "default_id" - else: - # For KEY and other types, use token or user_id - _id = user_info.token or user_info.user_id or "default_id" - _cache_key = f"email_budget_alerts:soft_budget_crossed:{_id}" - - # Check if we've already sent this alert - result = await _cache.async_get_cache(key=_cache_key) - if result is None: - # Create WebhookEvent for soft budget alert - event_message = f"Soft Budget Crossed - Total Soft Budget: ${user_info.soft_budget}" - webhook_event = WebhookEvent( - event="soft_budget_crossed", - event_message=event_message, - spend=user_info.spend, - max_budget=user_info.max_budget, - soft_budget=user_info.soft_budget, - token=user_info.token, - customer_id=user_info.customer_id, - user_id=user_info.user_id, - team_id=user_info.team_id, - team_alias=user_info.team_alias, - organization_id=user_info.organization_id, - user_email=user_info.user_email, - key_alias=user_info.key_alias, - projected_exceeded_date=user_info.projected_exceeded_date, - projected_spend=user_info.projected_spend, - event_group=user_info.event_group, - alert_emails=user_info.alert_emails, - ) - - try: - # Use team-specific function for team alerts, otherwise use standard function - if user_info.event_group == Litellm_EntityType.TEAM: - await self.send_team_soft_budget_alert_email(webhook_event) - else: - await self.send_soft_budget_alert_email(webhook_event) - - # Cache the alert to prevent duplicate sends - await _cache.async_set_cache( - key=_cache_key, - value="SENT", - ttl=EMAIL_BUDGET_ALERT_TTL, - ) - except Exception as e: - verbose_proxy_logger.error( - f"Error sending soft budget alert email: {e}", - exc_info=True, - ) - return - - # For max_budget_alert, check if we've already sent an alert - if type == "max_budget_alert": - if user_info.max_budget is not None and user_info.spend is not None: - if user_info.max_budget_alert_emails: - # New path: multi-threshold alerts - await self._handle_multi_threshold_max_budget_alert( - user_info=user_info, _cache=_cache - ) - return - - alert_threshold = ( - user_info.max_budget * EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE - ) - - # Only alert if we've crossed the threshold but haven't exceeded max_budget yet - if ( - user_info.spend >= alert_threshold - and user_info.spend < user_info.max_budget - ): - # Generate cache key based on event type and identifier - _id = user_info.token or user_info.user_id or "default_id" - _cache_key = f"email_budget_alerts:max_budget_alert:{_id}" - - # Check if we've already sent this alert - result = await _cache.async_get_cache(key=_cache_key) - if result is None: - # Calculate percentage - percentage = int( - EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE * 100 - ) - - # Create WebhookEvent for max budget alert - event_message = f"Max Budget Alert - {percentage}% of Maximum Budget Reached" - webhook_event = WebhookEvent( - event="max_budget_alert", - event_message=event_message, - spend=user_info.spend, - max_budget=user_info.max_budget, - soft_budget=user_info.soft_budget, - token=user_info.token, - customer_id=user_info.customer_id, - user_id=user_info.user_id, - team_id=user_info.team_id, - team_alias=user_info.team_alias, - organization_id=user_info.organization_id, - user_email=user_info.user_email, - key_alias=user_info.key_alias, - projected_exceeded_date=user_info.projected_exceeded_date, - projected_spend=user_info.projected_spend, - event_group=user_info.event_group, - ) - - try: - await self.send_max_budget_alert_email(webhook_event) - - # Cache the alert to prevent duplicate sends - await _cache.async_set_cache( - key=_cache_key, - value="SENT", - ttl=EMAIL_BUDGET_ALERT_TTL, - ) - except Exception as e: - verbose_proxy_logger.error( - f"Error sending max budget alert email: {e}", - exc_info=True, - ) - return - - async def _handle_multi_threshold_max_budget_alert( - self, - user_info: CallInfo, - _cache: DualCache, - ): - """ - Loop over configured thresholds in max_budget_alert_emails, - check cache per threshold, and send to configured recipients. - """ - if not user_info.max_budget_alert_emails or user_info.max_budget is None: - return - - for threshold_str, raw_emails in user_info.max_budget_alert_emails.items(): - try: - threshold_pct = int(threshold_str) - except (ValueError, TypeError): - continue - - threshold_amount = user_info.max_budget * (threshold_pct / 100.0) - if user_info.spend < threshold_amount: - continue - - _id = user_info.token or user_info.user_id or "default_id" - _cache_key = ( - f"email_budget_alerts:max_budget_alert:{threshold_pct}:{_id}" - ) - - result = await _cache.async_get_cache(key=_cache_key) - if result is not None: - continue - - # Parse emails + auto-include owner - emails = _parse_email_list(raw_emails) - if user_info.user_email: - emails.append(user_info.user_email) - if not emails: - verbose_proxy_logger.warning( - "No recipients for %d%% threshold on key %s, skipping alert", - threshold_pct, - _id, - ) - continue - recipient_emails = list(set(emails)) - - event_message = f"Max Budget Alert - {threshold_pct}% of Maximum Budget Reached" - webhook_event = WebhookEvent( - event="max_budget_alert", - event_message=event_message, - spend=user_info.spend, - max_budget=user_info.max_budget, - soft_budget=user_info.soft_budget, - token=user_info.token, - customer_id=user_info.customer_id, - user_id=user_info.user_id, - team_id=user_info.team_id, - team_alias=user_info.team_alias, - organization_id=user_info.organization_id, - user_email=user_info.user_email, - key_alias=user_info.key_alias, - projected_exceeded_date=user_info.projected_exceeded_date, - projected_spend=user_info.projected_spend, - event_group=user_info.event_group, - ) - - try: - await self.send_max_budget_alert_email( - webhook_event, - threshold_pct=threshold_pct, - recipient_emails=recipient_emails, - ) - await _cache.async_set_cache( - key=_cache_key, - value="SENT", - ttl=EMAIL_BUDGET_ALERT_TTL, - ) - except Exception as e: - verbose_proxy_logger.error( - f"Error sending multi-threshold max budget alert email for {threshold_pct}%: {e}", - exc_info=True, - ) - - async def _get_email_params( - self, - email_event: EmailEvent, - user_id: Optional[str] = None, - user_email: Optional[str] = None, - event_message: Optional[str] = None, - ) -> EmailParams: - """ - Get common email parameters used across different email sending methods - - Args: - email_event: Type of email event - user_id: Optional user ID to look up email - user_email: Optional direct email address - event_message: Optional message to include in email subject - - Returns: - EmailParams object containing logo_url, support_contact, base_url, recipient_email, subject, and signature - """ - # Get email parameters with premium check for custom values - custom_logo = os.getenv("EMAIL_LOGO_URL", None) - custom_support = os.getenv("EMAIL_SUPPORT_CONTACT", None) - custom_signature = os.getenv("EMAIL_SIGNATURE", None) - custom_subject_invitation = os.getenv("EMAIL_SUBJECT_INVITATION", None) - custom_subject_key_created = os.getenv("EMAIL_SUBJECT_KEY_CREATED", None) - - # Track which custom values were not applied - unused_custom_fields = [] - - # Function to safely get custom value or default - def get_custom_or_default( - custom_value: Optional[str], default_value: str, field_name: str - ) -> str: - if ( - custom_value is not None - ): # Only check premium if trying to use custom value - from litellm.proxy.proxy_server import premium_user - - if premium_user is not True: - unused_custom_fields.append(field_name) - return default_value - return custom_value - return default_value - - # Get parameters, falling back to defaults if custom values aren't allowed - logo_url = get_custom_or_default(custom_logo, LITELLM_LOGO_URL, "logo URL") - support_contact = get_custom_or_default( - custom_support, self.DEFAULT_SUPPORT_EMAIL, "support contact" - ) - base_url = os.getenv( - "PROXY_BASE_URL", "http://0.0.0.0:4000" - ) # Not a premium feature - signature = get_custom_or_default( - custom_signature, EMAIL_FOOTER, "email signature" - ) - - # Get custom subject template based on email event type - if email_event == EmailEvent.new_user_invitation: - subject_template = get_custom_or_default( - custom_subject_invitation, - self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.new_user_invitation], - "invitation subject template", - ) - elif email_event == EmailEvent.virtual_key_created: - subject_template = get_custom_or_default( - custom_subject_key_created, - self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_created], - "key created subject template", - ) - elif email_event == EmailEvent.virtual_key_rotated: - custom_subject_key_rotated = os.getenv("EMAIL_SUBJECT_KEY_ROTATED", None) - subject_template = get_custom_or_default( - custom_subject_key_rotated, - self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_rotated], - "key rotated subject template", - ) - else: - subject_template = "LiteLLM: {event_message}" - - subject = ( - subject_template.format(event_message=event_message) - if event_message - else "LiteLLM Notification" - ) - - recipient_email: Optional[str] = ( - user_email or await self._lookup_user_email_from_db(user_id=user_id) - ) - if recipient_email is None: - raise ValueError( - f"User email not found for user_id: {user_id}. User email is required to send email." - ) - - # if user invited event then send invitation link - if email_event == EmailEvent.new_user_invitation: - base_url = await self._get_invitation_link( - user_id=user_id, base_url=base_url - ) - - # If any custom fields were not applied, log a warning - if unused_custom_fields: - fields_str = ", ".join(unused_custom_fields) - warning_msg = ( - f"Email sent with default values instead of custom values for: {fields_str}. " - "This is an Enterprise feature. To use custom email fields, please upgrade to LiteLLM Enterprise. " - "Schedule a meeting here: https://enterprise.litellm.ai/demo" - ) - verbose_proxy_logger.warning(f"{warning_msg}") - - return EmailParams( - logo_url=logo_url, - support_contact=support_contact, - base_url=base_url, - recipient_email=recipient_email, - subject=subject, - signature=signature, - ) - - def _format_key_budget(self, max_budget: Optional[float]) -> str: - """ - Format the key budget to be displayed in the email - """ - if max_budget is None: - return "No budget" - return f"${max_budget}" - - async def _lookup_user_email_from_db(self, user_id: Optional[str]) -> Optional[str]: - """ - Lookup user email from user_id - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - verbose_proxy_logger.debug( - f"Prisma client not found. Unable to lookup user email for user_id: {user_id}" - ) - return None - - user_row = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_id} - ) - - if user_row is not None: - return user_row.user_email - return None - - async def _get_invitation_link(self, user_id: Optional[str], base_url: str) -> str: - """ - Get invitation link for the user - """ - # Early validation - if not user_id: - verbose_proxy_logger.debug("No user_id provided for invitation link") - return base_url - - if not await self._is_prisma_client_available(): - return base_url - - # Wait for any concurrent invitation creation to complete - await self._wait_for_invitation_creation() - - # Get or create invitation - invitation = await self._get_or_create_invitation(user_id) - if not invitation: - verbose_proxy_logger.warning( - f"Failed to get/create invitation for user_id: {user_id}" - ) - return base_url - - return self._construct_invitation_link(invitation.id, base_url) - - async def _is_prisma_client_available(self) -> bool: - """Check if Prisma client is available""" - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - verbose_proxy_logger.debug( - "Prisma client not found. Unable to lookup invitation" - ) - return False - return True - - async def _wait_for_invitation_creation(self) -> None: - """ - Wait for any concurrent invitation creation to complete. - - The UI calls /invitation/new to generate the invitation link. - We wait to ensure any pending invitation creation is completed. - """ - import asyncio - - await asyncio.sleep(10) - - async def _get_or_create_invitation(self, user_id: str): - """ - Get existing invitation or create a new one for the user - - Returns: - Invitation object with id attribute, or None if failed - """ - from litellm.proxy.management_helpers.user_invitation import ( - create_invitation_for_user, - ) - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - verbose_proxy_logger.error( - "Prisma client is None in _get_or_create_invitation" - ) - return None - - try: - # Try to get existing invitation - existing_invitations = ( - await prisma_client.db.litellm_invitationlink.find_many( - where={"user_id": user_id}, - order={"created_at": "desc"}, - ) - ) - - if existing_invitations and len(existing_invitations) > 0: - verbose_proxy_logger.debug( - f"Found existing invitation for user_id: {user_id}" - ) - return existing_invitations[0] - - # Create new invitation if none exists - verbose_proxy_logger.debug( - f"Creating new invitation for user_id: {user_id}" - ) - return await create_invitation_for_user( - data=InvitationNew(user_id=user_id), - user_api_key_dict=UserAPIKeyAuth(user_id=user_id), - ) - - except Exception as e: - verbose_proxy_logger.error( - f"Error getting/creating invitation for user_id {user_id}: {e}" - ) - return None - - def _construct_invitation_link(self, invitation_id: str, base_url: str) -> str: - """ - Construct invitation link for the user - - # http://localhost:4000/ui?invitation_id=7a096b3a-37c6-440f-9dd1-ba22e8043f6b - """ - return f"{base_url}/ui?invitation_id={invitation_id}" - - async def send_email( - self, - from_email: str, - to_email: List[str], - subject: str, - html_body: str, - ): - pass diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/endpoints.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/endpoints.py deleted file mode 100644 index 61681c27ee..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/endpoints.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -Endpoints for managing email alerts on litellm -""" - -import json -from typing import Dict - -from fastapi import APIRouter, Depends, HTTPException -from litellm_enterprise.types.enterprise_callbacks.send_emails import ( - DefaultEmailSettings, - EmailEvent, - EmailEventSettings, - EmailEventSettingsResponse, - EmailEventSettingsUpdateRequest, -) - -from litellm._logging import verbose_proxy_logger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - -router = APIRouter() - - -async def _get_email_settings(prisma_client) -> Dict[str, bool]: - """Helper function to get email settings from general_settings in db""" - try: - # Get general settings from db - general_settings_entry = await prisma_client.db.litellm_config.find_unique( - where={"param_name": "general_settings"} - ) - - # Initialize with default email settings - settings_dict = DefaultEmailSettings.get_defaults() - - if ( - general_settings_entry is not None - and general_settings_entry.param_value is not None - ): - # Get general settings value - if isinstance(general_settings_entry.param_value, str): - general_settings = json.loads(general_settings_entry.param_value) - else: - general_settings = general_settings_entry.param_value - - # Extract email_settings from general settings if it exists - if general_settings and "email_settings" in general_settings: - email_settings = general_settings["email_settings"] - # Update settings_dict with values from general_settings - for event_name, enabled in email_settings.items(): - settings_dict[event_name] = enabled - - return settings_dict - except Exception as e: - verbose_proxy_logger.error( - f"Error getting email settings from general_settings: {str(e)}" - ) - # Return default settings in case of error - return DefaultEmailSettings.get_defaults() - - -async def _save_email_settings(prisma_client, settings: Dict[str, bool]): - """Helper function to save email settings to general_settings in db""" - try: - verbose_proxy_logger.debug( - f"Saving email settings to general_settings: {settings}" - ) - - # Get current general settings - general_settings_entry = await prisma_client.db.litellm_config.find_unique( - where={"param_name": "general_settings"} - ) - - # Initialize general settings dict - if ( - general_settings_entry is not None - and general_settings_entry.param_value is not None - ): - if isinstance(general_settings_entry.param_value, str): - general_settings = json.loads(general_settings_entry.param_value) - else: - general_settings = dict(general_settings_entry.param_value) - else: - general_settings = {} - - # Update email_settings in general_settings - general_settings["email_settings"] = settings - - # Convert to JSON for storage - json_settings = json.dumps(general_settings, default=str) - - # Save updated general settings - await prisma_client.db.litellm_config.upsert( - where={"param_name": "general_settings"}, - data={ - "create": { - "param_name": "general_settings", - "param_value": json_settings, - }, - "update": {"param_value": json_settings}, - }, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving email settings to general_settings: {str(e)}", - ) - - -@router.get( - "/email/event_settings", - response_model=EmailEventSettingsResponse, - tags=["email management"], - dependencies=[Depends(user_api_key_auth)], -) -async def get_email_event_settings( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Get all email event settings - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - # Get existing settings - settings_dict = await _get_email_settings(prisma_client) - - # Create a response with all events (enabled or disabled) - response_settings = [] - for event in EmailEvent: - enabled = settings_dict.get(event.value, False) - response_settings.append(EmailEventSettings(event=event, enabled=enabled)) - - return EmailEventSettingsResponse(settings=response_settings) - except Exception as e: - verbose_proxy_logger.exception(f"Error getting email settings: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.patch( - "/email/event_settings", - tags=["email management"], - dependencies=[Depends(user_api_key_auth)], -) -async def update_event_settings( - request: EmailEventSettingsUpdateRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Update the settings for email events - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - # Get existing settings - settings_dict = await _get_email_settings(prisma_client) - - # Update with new settings - for setting in request.settings: - settings_dict[setting.event.value] = setting.enabled - - # Save updated settings - await _save_email_settings(prisma_client, settings_dict) - - return {"message": "Email event settings updated successfully"} - except Exception as e: - verbose_proxy_logger.exception(f"Error updating email settings: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/email/event_settings/reset", - tags=["email management"], - dependencies=[Depends(user_api_key_auth)], -) -async def reset_event_settings( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Reset all email event settings to default (new user invitations on, virtual key creation off) - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - # Reset to default settings using the Pydantic model - default_settings = DefaultEmailSettings.get_defaults() - - # Save default settings - await _save_email_settings(prisma_client, default_settings) - - return {"message": "Email event settings reset to defaults"} - except Exception as e: - verbose_proxy_logger.exception(f"Error resetting email settings: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/resend_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/resend_email.py deleted file mode 100644 index 3fad5601f5..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/resend_email.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -This is the litellm x resend email integration - -https://resend.com/docs/api-reference/emails/send-email -""" - -import os -from typing import List - -from litellm._logging import verbose_logger -from litellm.llms.custom_httpx.http_handler import ( - get_async_httpx_client, - httpxSpecialProvider, -) - -from .base_email import BaseEmailLogger - -RESEND_API_ENDPOINT = "https://api.resend.com/emails" - - -class ResendEmailLogger(BaseEmailLogger): - """ - Send emails using Resend's API. - - Required env vars: - - RESEND_API_KEY - - Optional env vars: - - RESEND_FROM_EMAIL: Override the default sender address. Must be on a - domain verified in your Resend account. When unset, falls back to the - `from_email` argument passed by the caller (which defaults to - `notifications@alerts.litellm.ai` and only works on LiteLLM Cloud). - """ - - def __init__(self, internal_usage_cache=None, **kwargs): - super().__init__(internal_usage_cache=internal_usage_cache, **kwargs) - self.async_httpx_client = get_async_httpx_client( - llm_provider=httpxSpecialProvider.LoggingCallback - ) - self.resend_api_key = os.getenv("RESEND_API_KEY") - self.resend_from_email = os.getenv("RESEND_FROM_EMAIL") - - async def send_email( - self, - from_email: str, - to_email: List[str], - subject: str, - html_body: str, - ): - sender_email = self.resend_from_email or from_email - verbose_logger.debug( - f"Sending email from {sender_email} to {to_email} with subject {subject}" - ) - response = await self.async_httpx_client.post( - url=RESEND_API_ENDPOINT, - json={ - "from": sender_email, - "to": to_email, - "subject": subject, - "html": html_body, - }, - headers={"Authorization": f"Bearer {self.resend_api_key}"}, - ) - verbose_logger.debug( - f"Email sent with status code {response.status_code}. Got response: {response.json()}" - ) - return diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py deleted file mode 100644 index 8fc2d66d53..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/sendgrid_email.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -LiteLLM x SendGrid email integration. - -Docs: https://docs.sendgrid.com/api-reference/mail-send/mail-send -""" - -import os -from typing import List - -from litellm._logging import verbose_logger -from litellm.llms.custom_httpx.http_handler import ( - get_async_httpx_client, - httpxSpecialProvider, -) - -from .base_email import BaseEmailLogger - - -SENDGRID_API_ENDPOINT = "https://api.sendgrid.com/v3/mail/send" - - -class SendGridEmailLogger(BaseEmailLogger): - """ - Send emails using SendGrid's Mail Send API. - - Required env vars: - - SENDGRID_API_KEY - """ - - def __init__(self, internal_usage_cache=None, **kwargs): - super().__init__(internal_usage_cache=internal_usage_cache, **kwargs) - self.async_httpx_client = get_async_httpx_client( - llm_provider=httpxSpecialProvider.LoggingCallback - ) - self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY") - self.sendgrid_sender_email = os.getenv("SENDGRID_SENDER_EMAIL") - verbose_logger.debug("SendGrid Email Logger initialized.") - - async def send_email( - self, - from_email: str, - to_email: List[str], - subject: str, - html_body: str, - ): - """ - Send an email via SendGrid. - """ - if not self.sendgrid_api_key: - raise ValueError("SENDGRID_API_KEY is not set") - - sender_email = self.sendgrid_sender_email or from_email - verbose_logger.debug( - f"Sending email via SendGrid from {sender_email} to {to_email} with subject {subject}" - ) - - payload = { - "from": {"email": sender_email}, - "personalizations": [ - { - "to": [{"email": email} for email in to_email], - "subject": subject, - } - ], - "content": [ - { - "type": "text/html", - "value": html_body, - } - ], - } - - response = await self.async_httpx_client.post( - url=SENDGRID_API_ENDPOINT, - json=payload, - headers={"Authorization": f"Bearer {self.sendgrid_api_key}"}, - ) - - verbose_logger.debug( - f"SendGrid response status={response.status_code}, body={response.text}" - ) - return \ No newline at end of file diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py deleted file mode 100644 index 8efdaf231b..0000000000 --- a/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/smtp_email.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -This is the litellm SMTP email integration -""" -import asyncio -from typing import List - -from litellm._logging import verbose_logger - -from .base_email import BaseEmailLogger - - -class SMTPEmailLogger(BaseEmailLogger): - """ - This is the litellm SMTP email integration - - Required SMTP environment variables: - - SMTP_HOST - - SMTP_PORT - - SMTP_USERNAME - - SMTP_PASSWORD - - SMTP_SENDER_EMAIL - """ - - def __init__(self, internal_usage_cache=None, **kwargs): - super().__init__(internal_usage_cache=internal_usage_cache, **kwargs) - verbose_logger.debug("SMTP Email Logger initialized....") - - async def send_email( - self, - from_email: str, - to_email: List[str], - subject: str, - html_body: str, - ): - from litellm.proxy.utils import send_email as send_smtp_email - - verbose_logger.debug( - f"Sending email from {from_email} to {to_email} with subject {subject}" - ) - for receiver_email in to_email: - asyncio.create_task( - send_smtp_email( - receiver_email=receiver_email, - subject=subject, - html=html_body, - ) - ) - return diff --git a/enterprise/litellm_enterprise/integrations/custom_guardrail.py b/enterprise/litellm_enterprise/integrations/custom_guardrail.py deleted file mode 100644 index f07752d5c1..0000000000 --- a/enterprise/litellm_enterprise/integrations/custom_guardrail.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import List, Optional, Union - -from litellm.types.guardrails import GuardrailEventHooks, Mode - - -class EnterpriseCustomGuardrailHelper: - @staticmethod - def _should_run_if_mode_by_tag( - data: dict, - event_hook: Optional[ - Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode] - ], - event_type: Optional[GuardrailEventHooks] = None, - ) -> Optional[bool]: - """ - Returns True if the guardrail should be run for this request and event_type. - - Logic: - - If a request tag matches a Mode tag key, only run if event_type matches - the tag's value (the mode for that tag). - - If no request tag matches, fall back to default mode(s). - """ - from litellm.litellm_core_utils.litellm_logging import ( - StandardLoggingPayloadSetup, - ) - from litellm.proxy._types import CommonProxyErrors - from litellm.proxy.proxy_server import premium_user - - if not premium_user: - raise Exception( - f"Setting tag based guardrail modes is only available in litellm-enterprise. {CommonProxyErrors.not_premium_user.value}." - ) - - if event_hook is None or not isinstance(event_hook, Mode): - return None - - proxy_server_request = data.get("proxy_server_request", {}) - - request_tags = StandardLoggingPayloadSetup._get_request_tags( - litellm_params=data, - proxy_server_request=proxy_server_request, - ) - - # Check if any request tag matches a Mode tag key - matched_mode = None - if request_tags: - for tag in request_tags: - if tag in event_hook.tags: - matched_mode = event_hook.tags[tag] - break - - if matched_mode is not None: - # Tag matched: only run if event_type matches the tag's mode value(s) - if event_type is not None: - if isinstance(matched_mode, list): - return event_type.value in matched_mode - return event_type.value == matched_mode - return True - - # No tag matched: fall back to default mode(s) - if event_hook.default is not None: - if event_type is not None: - default_list = ( - event_hook.default - if isinstance(event_hook.default, list) - else [event_hook.default] - ) - return event_type.value in default_list - return False - - return False diff --git a/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py b/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py deleted file mode 100644 index 44ba0063ff..0000000000 --- a/enterprise/litellm_enterprise/litellm_core_utils/litellm_logging.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -Enterprise specific logging utils -""" -from litellm.litellm_core_utils.litellm_logging import StandardLoggingMetadata - - -class StandardLoggingPayloadSetup: - @staticmethod - def apply_enterprise_specific_metadata( - standard_logging_metadata: StandardLoggingMetadata, - proxy_server_request: dict, - ) -> StandardLoggingMetadata: - """ - Adds enterprise-only metadata to the standard logging metadata. - """ - - _request_headers = proxy_server_request.get("headers", {}) - - if _request_headers: - custom_headers = { - k: v - for k, v in _request_headers.items() - if k.startswith("x-") and v is not None and isinstance(v, str) - } - - standard_logging_metadata["requester_custom_headers"] = custom_headers - - return standard_logging_metadata diff --git a/enterprise/litellm_enterprise/proxy/__init__.py b/enterprise/litellm_enterprise/proxy/__init__.py deleted file mode 100644 index 52b74882bc..0000000000 --- a/enterprise/litellm_enterprise/proxy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Package marker for enterprise proxy components. diff --git a/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py b/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py deleted file mode 100644 index 18ac29b978..0000000000 --- a/enterprise/litellm_enterprise/proxy/audit_logging_endpoints.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -AUDIT LOGGING - -All /audit logging endpoints. Attempting to write these as CRUD endpoints. - -GET - /audit/{id} - Get audit log by id -GET - /audit - Get all audit logs -""" - -from typing import Any, Dict, List, Optional - -#### AUDIT LOGGING #### -from fastapi import APIRouter, Depends, HTTPException, Query -from litellm_enterprise.types.proxy.audit_logging_endpoints import ( - AuditLogResponse, - PaginatedAuditLogResponse, -) - -from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - -router = APIRouter() - - -def _build_json_field_or_condition(json_key: str, value: str) -> Dict[str, Any]: - """ - Build an OR condition that matches a value inside a JSON column at the - given key, checking both before_value and updated_values. - - Uses Prisma's JSON path filtering (PostgreSQL only). - - Example result (team_id="t1"): - {"OR": [ - {"before_value": {"path": ["team_id"], "string_contains": "t1"}}, - {"updated_values": {"path": ["team_id"], "string_contains": "t1"}}, - ]} - """ - return { - "OR": [ - {"before_value": {"path": [json_key], "string_contains": value}}, - {"updated_values": {"path": [json_key], "string_contains": value}}, - ] - } - - -@router.get( - "/audit", - tags=["Audit Logging"], - dependencies=[Depends(user_api_key_auth)], - response_model=PaginatedAuditLogResponse, -) -async def get_audit_logs( - page: int = Query(1, ge=1), - page_size: int = Query(10, ge=1, le=100), - # Filter parameters - changed_by: Optional[str] = Query( - None, description="Filter by user or system that performed the action" - ), - changed_by_api_key: Optional[str] = Query( - None, description="Filter by API key hash that performed the action" - ), - action: Optional[str] = Query( - None, description="Filter by action type (create, update, delete)" - ), - table_name: Optional[str] = Query( - None, description="Filter by table name that was modified" - ), - object_id: Optional[str] = Query( - None, description="Filter by ID of the object that was modified" - ), - start_date: Optional[str] = Query(None, description="Filter logs after this date"), - end_date: Optional[str] = Query(None, description="Filter logs before this date"), - object_team_id: Optional[str] = Query( - None, - description="Filter by team_id present in before_value or updated_values JSON (PostgreSQL only)", - ), - object_key_hash: Optional[str] = Query( - None, - description="Filter by token (key hash) present in before_value or updated_values JSON (PostgreSQL only)", - ), - # Sorting parameters - sort_by: Optional[str] = Query( - None, - description="Column to sort by (e.g. 'updated_at', 'action', 'table_name')", - ), - sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"), -): - """ - Get all audit logs with filtering and pagination. - - Returns a paginated response of audit logs matching the specified filters. - - Note: object_team_id and object_key_hash use Prisma JSON path filtering, - which requires PostgreSQL. - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"message": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Build filter conditions - where_conditions: Dict[str, Any] = {} - if changed_by: - where_conditions["changed_by"] = changed_by - if changed_by_api_key: - where_conditions["changed_by_api_key"] = changed_by_api_key - if action: - where_conditions["action"] = action - if table_name: - where_conditions["table_name"] = table_name - if object_id: - where_conditions["object_id"] = object_id - if start_date or end_date: - date_filter: Dict[str, Any] = {} - if start_date: - date_filter["gte"] = start_date - if end_date: - date_filter["lte"] = end_date - where_conditions["updated_at"] = date_filter - - # JSON field filters (PostgreSQL only) — each filter is AND'd with the - # others, but checks both before_value and updated_values internally (OR). - if object_team_id: - where_conditions["AND"] = where_conditions.get("AND", []) + [ - _build_json_field_or_condition("team_id", object_team_id) - ] - if object_key_hash: - where_conditions["AND"] = where_conditions.get("AND", []) + [ - _build_json_field_or_condition("token", object_key_hash) - ] - - # Build sort conditions - order_by: Dict[str, Any] = {} - if sort_by and isinstance(sort_by, str): - order_by[sort_by] = sort_order - else: - order_by["updated_at"] = sort_order # Default sort by updated_at - - # Get paginated results - audit_logs = await prisma_client.db.litellm_auditlog.find_many( - where=where_conditions, - order=order_by, - skip=(page - 1) * page_size, - take=page_size, - ) - - # Get total count for pagination - total_count = await prisma_client.db.litellm_auditlog.count(where=where_conditions) - total_pages = -(-total_count // page_size) # Ceiling division - - # Return paginated response - return PaginatedAuditLogResponse( - audit_logs=[ - AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs - ] - if audit_logs - else [], - total=total_count, - page=page, - page_size=page_size, - total_pages=total_pages, - ) - - -@router.get( - "/audit/{id}", - tags=["Audit Logging"], - dependencies=[Depends(user_api_key_auth)], - response_model=AuditLogResponse, - responses={ - 404: {"description": "Audit log not found"}, - 500: {"description": "Database connection error"}, - }, -) -async def get_audit_log_by_id( - id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth) -): - """ - Get detailed information about a specific audit log entry by its ID. - - Args: - id (str): The unique identifier of the audit log entry - - Returns: - AuditLogResponse: Detailed information about the audit log entry - - Raises: - HTTPException: If the audit log is not found or if there's a database connection error - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"message": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Get the audit log by ID - audit_log = await prisma_client.db.litellm_auditlog.find_unique(where={"id": id}) - - if audit_log is None: - raise HTTPException( - status_code=404, detail={"message": f"Audit log with ID {id} not found"} - ) - - # Convert to response model - return AuditLogResponse(**audit_log.model_dump()) diff --git a/enterprise/litellm_enterprise/proxy/auth/__init__.py b/enterprise/litellm_enterprise/proxy/auth/__init__.py deleted file mode 100644 index f67826ca7f..0000000000 --- a/enterprise/litellm_enterprise/proxy/auth/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Enterprise Authentication Module for LiteLLM Proxy - -This module contains enterprise-specific authentication functionality, -including custom SSO handlers and advanced authentication features. -""" - -from .custom_sso_handler import EnterpriseCustomSSOHandler - -__all__ = ["EnterpriseCustomSSOHandler"] \ No newline at end of file diff --git a/enterprise/litellm_enterprise/proxy/auth/custom_sso_handler.py b/enterprise/litellm_enterprise/proxy/auth/custom_sso_handler.py deleted file mode 100644 index e8f104c262..0000000000 --- a/enterprise/litellm_enterprise/proxy/auth/custom_sso_handler.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Enterprise Custom SSO Handler for LiteLLM Proxy - -This module contains enterprise-specific custom SSO authentication functionality -that allows users to implement their own SSO handling logic by providing custom -handlers that process incoming request headers and return OpenID objects. - -Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy -has already authenticated the user) and you need to extract user information from -custom headers or other request attributes. -""" - -from typing import cast - -from fastapi import Request -from fastapi.responses import RedirectResponse - - -class EnterpriseCustomSSOHandler: - """ - Enterprise Custom SSO Handler for LiteLLM Proxy - - This class provides methods for handling custom SSO authentication flows - where users can implement their own authentication logic by processing - request headers and returning user information in OpenID format. - """ - - @staticmethod - async def handle_custom_ui_sso_sign_in( - request: Request, - ) -> RedirectResponse: - """ - Allow a user to execute their custom code to parse incoming request headers and return a OpenID object - - Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy has already authenticated the user) - - Args: - request: The FastAPI request object containing headers and other request data - - Returns: - RedirectResponse: Redirect response that sends the user to the LiteLLM UI with authentication token - - Raises: - ValueError: If custom_ui_sso_sign_in_handler is not configured - - Example: - This method is typically called when a user has already been authenticated by an - external OAuth proxy and the proxy has added custom headers containing user information. - The custom handler extracts this information and converts it to an OpenID object. - """ - from fastapi_sso.sso.base import OpenID - - from litellm.integrations.custom_sso_handler import CustomSSOLoginHandler - from litellm.proxy.proxy_server import ( - CommonProxyErrors, - general_settings, - premium_user, - user_custom_ui_sso_sign_in_handler, - ) - from litellm.proxy.auth.trusted_proxy_utils import ( - require_trusted_proxy_request, - ) - - if premium_user is not True: - raise ValueError(CommonProxyErrors.not_premium_user.value) - - if user_custom_ui_sso_sign_in_handler is None: - raise ValueError( - "custom_ui_sso_sign_in_handler is not configured. Please set it in general_settings." - ) - - require_trusted_proxy_request( - request=request, - general_settings=general_settings, - feature_name="Custom UI SSO", - ) - - custom_sso_login_handler = cast( - CustomSSOLoginHandler, user_custom_ui_sso_sign_in_handler - ) - openid_response: OpenID = ( - await custom_sso_login_handler.handle_custom_ui_sso_sign_in( - request=request, - ) - ) - - # Import here to avoid circular imports - from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler - - return await SSOAuthenticationHandler.get_redirect_response_from_openid( - result=openid_response, - request=request, - received_response=None, - generic_client_id=None, - ui_access_mode=None, - ) diff --git a/enterprise/litellm_enterprise/proxy/auth/route_checks.py b/enterprise/litellm_enterprise/proxy/auth/route_checks.py deleted file mode 100644 index fc57292a8d..0000000000 --- a/enterprise/litellm_enterprise/proxy/auth/route_checks.py +++ /dev/null @@ -1,71 +0,0 @@ -import os - -from fastapi import HTTPException, status - - -class EnterpriseRouteChecks: - @staticmethod - def is_llm_api_route_disabled() -> bool: - """ - Check if llm api route is disabled - """ - from litellm.proxy._types import CommonProxyErrors - from litellm.proxy.proxy_server import premium_user - from litellm.secret_managers.main import get_secret_bool - - ## Check if DISABLE_LLM_API_ENDPOINTS is set - if "DISABLE_LLM_API_ENDPOINTS" in os.environ: - if not premium_user: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"🚨🚨🚨 DISABLING LLM API ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}", - ) - - return get_secret_bool("DISABLE_LLM_API_ENDPOINTS") is True - - @staticmethod - def is_management_routes_disabled() -> bool: - """ - Check if management route is disabled - """ - from litellm.proxy._types import CommonProxyErrors - from litellm.proxy.proxy_server import premium_user - from litellm.secret_managers.main import get_secret_bool - - if "DISABLE_ADMIN_ENDPOINTS" in os.environ: - if not premium_user: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"🚨🚨🚨 DISABLING ADMIN ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}", - ) - - return get_secret_bool("DISABLE_ADMIN_ENDPOINTS") is True - - # Routes that should remain accessible even when LLM API endpoints are disabled. - # These are read-only model listing routes needed by the Admin UI. - LLM_API_EXEMPT_ROUTES = ["/models", "/v1/models"] - - @staticmethod - def should_call_route(route: str): - """ - Check if management route is disabled and raise exception - """ - from litellm.proxy.auth.route_checks import RouteChecks - - if ( - RouteChecks.is_management_route(route=route) - and EnterpriseRouteChecks.is_management_routes_disabled() - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Management routes are disabled for this instance.", - ) - elif ( - RouteChecks.is_llm_api_route(route=route) - and route not in EnterpriseRouteChecks.LLM_API_EXEMPT_ROUTES - and EnterpriseRouteChecks.is_llm_api_route_disabled() - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="LLM API routes are disabled for this instance.", - ) diff --git a/enterprise/litellm_enterprise/proxy/auth/user_api_key_auth.py b/enterprise/litellm_enterprise/proxy/auth/user_api_key_auth.py deleted file mode 100644 index dc9fdeb78e..0000000000 --- a/enterprise/litellm_enterprise/proxy/auth/user_api_key_auth.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Any, Optional - -from fastapi import Request - -from litellm._logging import verbose_proxy_logger -from litellm.proxy._types import ProxyException, UserAPIKeyAuth - - -async def enterprise_custom_auth( - request: Request, api_key: str, user_custom_auth: Optional[Any] -) -> Optional[UserAPIKeyAuth]: - from litellm_enterprise.proxy.proxy_server import custom_auth_settings - - if user_custom_auth is None: - return None - - if custom_auth_settings is None: - return await user_custom_auth(request, api_key) - - if custom_auth_settings["mode"] == "on": - return await user_custom_auth(request, api_key) - elif custom_auth_settings["mode"] == "off": - return None - elif custom_auth_settings["mode"] == "auto": - try: - return await user_custom_auth(request, api_key) - except ProxyException as e: - raise e - except Exception as e: - verbose_proxy_logger.debug( - f"Error in custom auth, checking litellm auth: {e}" - ) - return None - else: - raise ValueError(f"Invalid mode: {custom_auth_settings['mode']}") diff --git a/enterprise/litellm_enterprise/proxy/common_utils/__init__.py b/enterprise/litellm_enterprise/proxy/common_utils/__init__.py deleted file mode 100644 index fe8384c892..0000000000 --- a/enterprise/litellm_enterprise/proxy/common_utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Package marker for enterprise proxy common utilities. diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py deleted file mode 100644 index ee7745d0ad..0000000000 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Polls LiteLLM_ManagedObjectTable to check if the batch job is complete, and if the cost has been tracked. -""" - -from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, List, Optional, Tuple - -from litellm._logging import verbose_proxy_logger -from litellm._uuid import uuid -from litellm.constants import ( - MANAGED_OBJECT_STALENESS_CUTOFF_DAYS, - MAX_OBJECTS_PER_POLL_CYCLE, -) - -if TYPE_CHECKING: - from litellm.proxy.utils import PrismaClient, ProxyLogging - from litellm.router import Router - - -CHECK_BATCH_COST_USER_AGENT = "LiteLLM Proxy/CheckBatchCost" - - -class CheckBatchCost: - def __init__( - self, - proxy_logging_obj: "ProxyLogging", - prisma_client: "PrismaClient", - llm_router: "Router", - ): - from litellm.proxy.utils import PrismaClient, ProxyLogging - from litellm.router import Router - - self.proxy_logging_obj: ProxyLogging = proxy_logging_obj - self.prisma_client: PrismaClient = prisma_client - self.llm_router: Router = llm_router - # Cached after the first poll cycle. Once we know the column is absent we skip - # the guaranteed-failing primary query on every subsequent cycle. - self._has_batch_processed_column: bool = True - - async def _get_user_info(self, batch_id, user_id) -> dict: - """ - Look up user email and key alias by user_id for enriching the S3 callback metadata. - Returns a dict with user_api_key_user_email and user_api_key_alias (both may be None). - """ - try: - user_row = await self.prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_id} - ) - if user_row is None: - return {} - return { - "user_api_key_user_email": getattr(user_row, "user_email", None), - "user_api_key_alias": getattr(user_row, "user_alias", None), - } - except Exception as e: - verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}") - return {} - - async def _cleanup_stale_managed_objects(self) -> None: - """ - Mark managed objects older than MANAGED_OBJECT_STALENESS_CUTOFF_DAYS days - in non-terminal states as 'stale_expired'. These will never complete and - should not be polled. - """ - cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS) - result = await self.prisma_client.db.litellm_managedobjecttable.update_many( - where={ - "file_purpose": "batch", - "status": {"not_in": ["completed", "complete", "failed", "expired", "cancelled", "stale_expired"]}, - "created_at": {"lt": cutoff}, - }, - data={"status": "stale_expired"}, - ) - if result > 0: - verbose_proxy_logger.warning( - f"CheckBatchCost: marked {result} stale managed objects " - f"(older than {MANAGED_OBJECT_STALENESS_CUTOFF_DAYS} days) as stale_expired" - ) - - async def _fallback_find_jobs(self) -> list: - """Query batch jobs without the batch_processed filter (for older schemas).""" - return await self.prisma_client.db.litellm_managedobjecttable.find_many( - where={ - "file_purpose": "batch", - "status": { - "not_in": [ - "failed", - "expired", - "cancelled", - "complete", - "completed", - "stale_expired", - ] - }, - }, - take=MAX_OBJECTS_PER_POLL_CYCLE, - order={"created_at": "asc"}, - ) - - async def check_batch_cost(self): - """ - Check if the batch JOB has been tracked. - - get all status="validating" and file_purpose="batch" jobs - - check if batch is now complete - - if not, return False - - if so, return True - """ - from litellm.batches.batch_utils import ( - _get_file_content_as_dictionary, - calculate_batch_cost_and_usage, - ) - from litellm.files.main import afile_content - from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging - from litellm.proxy.openai_files_endpoints.common_utils import ( - _is_base64_encoded_unified_file_id, - get_batch_id_from_unified_batch_id, - get_model_id_from_unified_batch_id, - ) - - try: - from litellm.integrations.prometheus import PrometheusLogger - prom_logger = PrometheusLogger.get_instance() - except Exception as e: - verbose_proxy_logger.error(f"CheckBatchCost: could not get Prometheus logger: {e}") - prom_logger = None - - processed_models: List[Tuple[Optional[str], Optional[str]]] = [] - - try: - await self._cleanup_stale_managed_objects() - except Exception as cleanup_err: - verbose_proxy_logger.warning( - f"CheckBatchCost: stale cleanup failed (poll will continue): {cleanup_err}" - ) - - # Look for all batches that have not yet been processed by CheckBatchCost. - # self._has_batch_processed_column is cached after the first probe so that - # older schemas don't pay a guaranteed-failing primary query + warning on - # every subsequent poll cycle. - if self._has_batch_processed_column: - try: - # Include "complete"/"completed" batches: the retrieve_batch - # endpoint may transition a batch to "complete" before - # CheckBatchCost runs. The batch_processed=False filter - # already prevents reprocessing finished batches. - jobs = await self.prisma_client.db.litellm_managedobjecttable.find_many( - where={ - "file_purpose": "batch", - "batch_processed": False, - "status": { - "not_in": [ - "failed", - "expired", - "cancelled", - "stale_expired", - ] - }, - }, - take=MAX_OBJECTS_PER_POLL_CYCLE, - order={"created_at": "asc"}, - ) - except Exception as query_err: - if "batch_processed" not in str(query_err).lower() and "unknown column" not in str(query_err).lower() and "does not exist" not in str(query_err).lower(): - raise - # Permanent schema gap — cache the result so future cycles skip straight to fallback - self._has_batch_processed_column = False - verbose_proxy_logger.warning( - "CheckBatchCost: batch_processed column not found, querying without it" - ) - jobs = await self._fallback_find_jobs() - else: - jobs = await self._fallback_find_jobs() - for job in jobs: - # get the model from the job - unified_object_id = job.unified_object_id - decoded_unified_object_id = _is_base64_encoded_unified_file_id( - unified_object_id - ) - if not decoded_unified_object_id: - verbose_proxy_logger.info( - f"Skipping job {unified_object_id} because it is not a valid unified object id" - ) - if prom_logger: - prom_logger.record_check_batch_cost_error("invalid_unified_id") - continue - else: - unified_object_id = decoded_unified_object_id - - model_id = get_model_id_from_unified_batch_id(unified_object_id) - batch_id = get_batch_id_from_unified_batch_id(unified_object_id) - - if model_id is None: - verbose_proxy_logger.info( - f"Skipping job {unified_object_id} because it is not a valid model id" - ) - if prom_logger: - prom_logger.record_check_batch_cost_error("invalid_model_id") - continue - - verbose_proxy_logger.info( - f"Querying model ID: {model_id} for cost and usage of batch ID: {batch_id}" - ) - - try: - response = await self.llm_router.aretrieve_batch( - model=model_id, - batch_id=batch_id, - litellm_metadata={ - "user_api_key_user_id": job.created_by or "default-user-id", - "batch_ignore_default_logging": True, - }, - ) - except Exception as e: - verbose_proxy_logger.info( - f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}" - ) - if prom_logger: - prom_logger.record_check_batch_cost_error("provider_retrieval_error") - continue - - ## RETRIEVE THE BATCH JOB OUTPUT FILE - if ( - response.status == "completed" - and response.output_file_id is not None - ): - verbose_proxy_logger.info( - f"Batch ID: {batch_id} is complete, tracking cost and usage" - ) - - # aretrieve_batch is called with the raw provider batch ID, so response.id - # is the raw provider value (e.g. "batch_20260223-0518.234"). We need the - # unified base64 ID in the S3 log so downstream consumers can correlate it - # back to the batch they submitted via the proxy. - # - # CheckBatchCost builds its own LiteLLMLogging object (logging_obj below) and - # calls async_success_handler(result=response) directly. That handler calls - # _build_standard_logging_payload(response, ...) which reads response.id at - # that point — so setting response.id here is sufficient. - # - # The HTTP endpoint does this substitution via the managed files hook - # (async_post_call_success_hook). CheckBatchCost bypasses that hook entirely, - # so we do it explicitly here. - response.id = job.unified_object_id - - # This background job runs as default_user_id, so going through the HTTP endpoint - # would trigger check_managed_file_id_access and get 403. Instead, extract the raw - # provider file ID and call afile_content directly with deployment credentials. - raw_output_file_id = response.output_file_id - decoded = _is_base64_encoded_unified_file_id(raw_output_file_id) - if decoded: - try: - raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0] - except (IndexError, AttributeError): - pass - - credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {} - _file_content = await afile_content( - file_id=raw_output_file_id, - **credentials, - ) - - # Access content - handle both direct attribute and method call - if hasattr(_file_content, 'content'): - content_bytes = _file_content.content # type: ignore[union-attr] - elif hasattr(_file_content, 'read'): - content_bytes = await _file_content.read() # type: ignore[misc] - else: - content_bytes = _file_content # type: ignore[assignment] - - file_content_as_dict = _get_file_content_as_dictionary( - content_bytes # type: ignore[arg-type] - ) - - # Record output file size - if prom_logger and content_bytes: - try: - prom_logger.record_managed_file_size( - size_bytes=len(content_bytes), # type: ignore - purpose="batch", - file_type="output", - model=model_id, - ) - except Exception: - pass - - deployment_info = self.llm_router.get_deployment(model_id=model_id) - if deployment_info is None: - verbose_proxy_logger.info( - f"Skipping job {unified_object_id} because it is not a valid deployment info" - ) - if prom_logger: - prom_logger.record_check_batch_cost_error("deployment_not_found") - continue - custom_llm_provider = deployment_info.litellm_params.custom_llm_provider - litellm_model_name = deployment_info.litellm_params.model - - model_name, llm_provider, _, _ = get_llm_provider( - model=litellm_model_name, - custom_llm_provider=custom_llm_provider, - ) - - # CheckBatchCost bypasses async_post_call_success_hook, so convert raw - # output/error file IDs to managed base64 IDs before the DB write here. - managed_files_hook = self.proxy_logging_obj.get_proxy_hook("managed_files") - if managed_files_hook is not None: - from litellm.proxy._types import UserAPIKeyAuth - _minimal_auth = UserAPIKeyAuth( - user_id=job.created_by or "default-user-id", - team_id=getattr(job, "team_id", None), - ) - for _file_attr in ["output_file_id", "error_file_id"]: - _raw_file_id = getattr(response, _file_attr, None) - if _raw_file_id and not _is_base64_encoded_unified_file_id(_raw_file_id): - try: - _unified_file_id = managed_files_hook.get_unified_output_file_id( - output_file_id=_raw_file_id, - model_id=model_id, - model_name=str(model_name) if model_name else deployment_info.model_name or None, - ) - await managed_files_hook.store_unified_file_id( - file_id=_unified_file_id, - file_object=None, - litellm_parent_otel_span=None, - model_mappings={model_id: _raw_file_id}, - user_api_key_dict=_minimal_auth, - ) - setattr(response, _file_attr, _unified_file_id) - verbose_proxy_logger.info( - f"CheckBatchCost: converted {_file_attr} " - f"{_raw_file_id!r} -> managed ID for batch {batch_id}" - ) - except Exception as _e: - verbose_proxy_logger.warning( - f"CheckBatchCost: failed to create managed file ID for " - f"{_file_attr}={_raw_file_id!r}: {_e}" - ) - - # Pass deployment model_info so custom batch pricing - # (input_cost_per_token_batches etc.) is used for cost calc - deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {} - batch_cost, batch_usage, batch_models = ( - await calculate_batch_cost_and_usage( - file_content_dictionary=file_content_as_dict, - custom_llm_provider=llm_provider, # type: ignore - model_name=model_name, - model_info=deployment_model_info, # type: ignore[arg-type] - ) - ) - logging_obj = LiteLLMLogging( - model=batch_models[0], - messages=[{"role": "user", "content": ""}], - stream=False, - call_type="aretrieve_batch", - start_time=datetime.now(), - litellm_call_id=str(uuid.uuid4()), - function_id=str(uuid.uuid4()), - ) - - creator_user_id = job.created_by - user_info = await self._get_user_info(batch_id, job.created_by) - - logging_obj.update_environment_variables( - litellm_params={ - # set the user-agent header so that S3 callback consumers can easily identify CheckBatchCost callbacks - "proxy_server_request": { - "headers": { - "user-agent": CHECK_BATCH_COST_USER_AGENT, - } - }, - "metadata": { - "user_api_key_user_id": creator_user_id, - **user_info, - }, - }, - optional_params={}, - ) - - await logging_obj.async_success_handler( - result=response, - batch_cost=batch_cost, - batch_usage=batch_usage, - batch_models=batch_models, - ) - - # Record batch duration (completed_at - created_at) - if prom_logger and response.completed_at and response.created_at: - duration_seconds = float(response.completed_at - response.created_at) - if duration_seconds >= 0: - prom_logger.record_managed_batch_duration( - duration_seconds=duration_seconds, - model=model_name, - api_provider=str(llm_provider) if llm_provider else None, - ) - - # Track this job for the final metrics summary - processed_models.append((model_name, str(llm_provider) if llm_provider else None)) - - # mark the job as complete - try: - update_data: dict = { - "status": "complete", - "file_object": response.model_dump_json(), - } - if self._has_batch_processed_column: - update_data["batch_processed"] = True - await self.prisma_client.db.litellm_managedobjecttable.update( - where={"id": job.id}, - data=update_data, - ) - except Exception as db_err: - verbose_proxy_logger.error( - f"CheckBatchCost: failed to mark job {job.id} complete in DB: {db_err}" - ) - - # Record polling run metrics (always, even if nothing was processed) - if prom_logger: - prom_logger.record_check_batch_cost_run( - jobs_polled=len(jobs), - processed_models=processed_models if processed_models else None, - ) diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py deleted file mode 100644 index dc0168683c..0000000000 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_responses_cost.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Polls LiteLLM_ManagedObjectTable to check if the response is complete. -Cost tracking is handled automatically by litellm.aget_responses(). -""" - -from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.constants import ( - MANAGED_OBJECT_STALENESS_CUTOFF_DAYS, - MAX_OBJECTS_PER_POLL_CYCLE, - STALE_OBJECT_CLEANUP_BATCH_SIZE, -) - -if TYPE_CHECKING: - from litellm.proxy.utils import PrismaClient, ProxyLogging - from litellm.router import Router - - -class CheckResponsesCost: - def __init__( - self, - proxy_logging_obj: "ProxyLogging", - prisma_client: "PrismaClient", - llm_router: "Router", - ): - from litellm.proxy.utils import PrismaClient, ProxyLogging - from litellm.router import Router - - self.proxy_logging_obj: ProxyLogging = proxy_logging_obj - self.prisma_client: PrismaClient = prisma_client - self.llm_router: Router = llm_router - - async def _expire_stale_rows( - self, cutoff: datetime, batch_size: int - ) -> int: - """Execute the bounded UPDATE that marks stale rows as 'stale_expired'. - - Isolated so it can be swapped / mocked in tests without touching the - orchestration logic in ``_cleanup_stale_managed_objects``. - - Uses PostgreSQL syntax (``$1::timestamptz``, ``LIMIT``, double-quoted - identifiers) which is the only dialect the proxy supports — every - ``schema.prisma`` in the repo sets ``provider = "postgresql"``. - Same pattern as ``spend_log_cleanup.py``. - """ - return await self.prisma_client.db.execute_raw( - """ - UPDATE "LiteLLM_ManagedObjectTable" - SET "status" = 'stale_expired' - WHERE "id" IN ( - SELECT "id" FROM "LiteLLM_ManagedObjectTable" - WHERE "file_purpose" = 'response' - AND "status" NOT IN ('completed', 'complete', 'failed', 'expired', 'cancelled', 'stale_expired') - AND "created_at" < $1::timestamptz - ORDER BY "created_at" ASC - LIMIT $2 - ) - """, - cutoff, - batch_size, - ) - - async def _cleanup_stale_managed_objects(self) -> None: - """ - Mark managed objects older than MANAGED_OBJECT_STALENESS_CUTOFF_DAYS days - in non-terminal states as 'stale_expired'. These will never complete and - should not be polled. - - Runs as a single DB query with a subquery LIMIT so no rows are loaded - into Python memory. Processes at most STALE_OBJECT_CLEANUP_BATCH_SIZE - rows per invocation to avoid overwhelming the DB when there is a large - backlog. - """ - cutoff = datetime.now(timezone.utc) - timedelta(days=MANAGED_OBJECT_STALENESS_CUTOFF_DAYS) - result = await self._expire_stale_rows(cutoff, STALE_OBJECT_CLEANUP_BATCH_SIZE) - if result > 0: - verbose_proxy_logger.warning( - f"CheckResponsesCost: marked {result} stale managed objects " - f"(older than {MANAGED_OBJECT_STALENESS_CUTOFF_DAYS} days) as stale_expired" - ) - - async def check_responses_cost(self): - """ - Check if background responses are complete and track their cost. - - Get all status="queued" or "in_progress" and file_purpose="response" jobs - - Query the provider to check if response is complete - - Cost is automatically tracked by litellm.aget_responses() - - Mark completed/failed/cancelled responses as complete in the database - """ - try: - await self._cleanup_stale_managed_objects() - except Exception as cleanup_err: - verbose_proxy_logger.warning( - f"CheckResponsesCost: stale cleanup failed (poll will continue): {cleanup_err}" - ) - - jobs = await self.prisma_client.db.litellm_managedobjecttable.find_many( - where={ - "status": {"in": ["queued", "in_progress"]}, - "file_purpose": "response", - }, - take=MAX_OBJECTS_PER_POLL_CYCLE, - order={"created_at": "asc"}, - ) - - verbose_proxy_logger.debug(f"Found {len(jobs)} response jobs to check") - completed_jobs = [] - - for job in jobs: - unified_object_id = job.unified_object_id - - try: - from litellm.proxy.hooks.responses_id_security import ( - ResponsesIDSecurity, - ) - - # Get the stored response object to extract model information - stored_response = job.file_object - model_name = stored_response.get("model", None) - - # Decrypt the response ID - responses_id_security, _, _ = ResponsesIDSecurity()._decrypt_response_id(unified_object_id) - - # Prepare metadata with model information for cost tracking - litellm_metadata = { - "user_api_key_user_id": job.created_by or "default-user-id", - } - - # Add model information if available - if model_name: - litellm_metadata["model"] = model_name - litellm_metadata["model_group"] = model_name # Use same value for model_group - - response = await litellm.aget_responses( - response_id=responses_id_security, - litellm_metadata=litellm_metadata, - ) - - verbose_proxy_logger.debug( - f"Response {unified_object_id} status: {response.status}, model: {model_name}" - ) - - except Exception as e: - verbose_proxy_logger.info( - f"Skipping job {unified_object_id} due to error: {e}" - ) - continue - - # Check if response is in a terminal state - if response.status == "completed": - verbose_proxy_logger.info( - f"Response {unified_object_id} is complete. Cost automatically tracked by aget_responses." - ) - completed_jobs.append(job) - - elif response.status in ["failed", "cancelled"]: - verbose_proxy_logger.info( - f"Response {unified_object_id} has status {response.status}, marking as complete" - ) - completed_jobs.append(job) - - # Mark completed jobs in the database - if len(completed_jobs) > 0: - await self.prisma_client.db.litellm_managedobjecttable.update_many( - where={"id": {"in": [job.id for job in completed_jobs]}}, - data={"status": "completed"}, - ) - verbose_proxy_logger.info( - f"Marked {len(completed_jobs)} response jobs as completed" - ) - diff --git a/enterprise/litellm_enterprise/proxy/enterprise_routes.py b/enterprise/litellm_enterprise/proxy/enterprise_routes.py deleted file mode 100644 index ec37c04980..0000000000 --- a/enterprise/litellm_enterprise/proxy/enterprise_routes.py +++ /dev/null @@ -1,29 +0,0 @@ -from fastapi import APIRouter -from fastapi.responses import Response -from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import ( - router as email_events_router, -) - -from . import ui_crud_endpoints # side-effect: registers extra UI settings -from .audit_logging_endpoints import router as audit_logging_router -from .management_endpoints import management_endpoints_router -from .utils import _should_block_robots - -__all__ = ["router", "ui_crud_endpoints"] - -router = APIRouter() -router.include_router(email_events_router) -router.include_router(audit_logging_router) -router.include_router(management_endpoints_router) - - -@router.get("/robots.txt") -async def get_robots(): - """ - Block all web crawlers from indexing the proxy server endpoints - This is useful for ensuring that the API endpoints aren't indexed by search engines - """ - if _should_block_robots(): - return Response(content="User-agent: *\nDisallow: /", media_type="text/plain") - else: - return Response(status_code=404) diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py deleted file mode 100644 index ae5905f9cd..0000000000 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ /dev/null @@ -1,1664 +0,0 @@ -# What is this? -## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id - -import asyncio -import base64 -import json -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast - -from fastapi import HTTPException - -import litellm -from litellm import Router, verbose_logger -from litellm._uuid import uuid -from litellm.caching.caching import DualCache -from litellm.integrations.custom_logger import CustomLogger -from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data -from litellm.llms.base_llm.files.transformation import BaseFileEndpoints -from litellm.llms.base_llm.managed_resources.isolation import ( - build_list_page, - build_owner_filter, - can_access_resource, -) -from litellm.proxy._types import ( - CallTypes, - LiteLLM_ManagedFileTable, - LiteLLM_ManagedObjectTable, - UserAPIKeyAuth, -) -from litellm.proxy.openai_files_endpoints.common_utils import ( - _is_base64_encoded_unified_file_id, - get_batch_id_from_unified_batch_id, - get_content_type_from_file_object, - get_model_id_from_unified_batch_id, - get_models_from_unified_file_id, - normalize_mime_type_for_provider, -) -from litellm.types.llms.openai import ( # pyright: ignore[reportAttributeAccessIssue] - AllMessageValues, - AsyncCursorPage, - ChatCompletionFileObject, - CreateFileRequest, - FileObject, - OpenAIFileObject, - OpenAIFilesPurpose, - ResponsesAPIResponse, -) -from litellm.types.utils import ( - CallTypesLiteral, - LiteLLMBatch, - LiteLLMFineTuningJob, - LLMResponseTypes, - SpecialEnums, -) - -if TYPE_CHECKING: - from litellm.types.llms.openai import HttpxBinaryResponseContent - - -if TYPE_CHECKING: - from opentelemetry.trace import Span as _Span - - from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache - from litellm.proxy.utils import PrismaClient as _PrismaClient - - Span = Union[_Span, Any] - InternalUsageCache = _InternalUsageCache - PrismaClient = _PrismaClient -else: - Span = Any - InternalUsageCache = Any - PrismaClient = Any - - -class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): - # Class variables or attributes - def __init__( - self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient - ): - self.internal_usage_cache = internal_usage_cache - self.prisma_client = prisma_client - - @staticmethod - def _get_prometheus_logger(): - """Find PrometheusLogger from litellm.callbacks, if registered.""" - from litellm.integrations.prometheus import PrometheusLogger - - return PrometheusLogger.get_instance() - - async def store_unified_file_id( - self, - file_id: str, - file_object: Optional[OpenAIFileObject], - litellm_parent_otel_span: Optional[Span], - model_mappings: Dict[str, str], - user_api_key_dict: UserAPIKeyAuth, - ) -> None: - verbose_logger.info( - f"Storing LiteLLM Managed File object with id={file_id} in cache" - ) - if file_object is not None: - litellm_managed_file_object = LiteLLM_ManagedFileTable( - unified_file_id=file_id, - file_object=file_object, - model_mappings=model_mappings, - flat_model_file_ids=list(model_mappings.values()), - created_by=user_api_key_dict.user_id, - team_id=user_api_key_dict.team_id, - updated_by=user_api_key_dict.user_id, - ) - await self.internal_usage_cache.async_set_cache( - key=file_id, - value=litellm_managed_file_object.model_dump(), - litellm_parent_otel_span=litellm_parent_otel_span, - ) - - ## STORE MODEL MAPPINGS IN DB - - db_data = { - "unified_file_id": file_id, - "model_mappings": json.dumps(model_mappings), - "flat_model_file_ids": list(model_mappings.values()), - "created_by": user_api_key_dict.user_id, - "team_id": user_api_key_dict.team_id, - "updated_by": user_api_key_dict.user_id, - } - - if file_object is not None: - db_data["file_object"] = file_object.model_dump_json() - # Extract storage metadata from hidden params if present - hidden_params = getattr(file_object, "_hidden_params", {}) or {} - if "storage_backend" in hidden_params: - db_data["storage_backend"] = hidden_params["storage_backend"] - if "storage_url" in hidden_params: - db_data["storage_url"] = hidden_params["storage_url"] - - verbose_logger.debug( - f"Storage metadata: storage_backend={db_data.get('storage_backend')}, " - f"storage_url={db_data.get('storage_url')}" - ) - - result = await self.prisma_client.db.litellm_managedfiletable.create( - data=db_data - ) - verbose_logger.debug( - f"LiteLLM Managed File object with id={file_id} stored in db: {result}" - ) - - async def store_unified_object_id( - self, - unified_object_id: str, - file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, "ResponsesAPIResponse"], - litellm_parent_otel_span: Optional[Span], - model_object_id: str, - file_purpose: Literal["batch", "fine-tune", "response"], - user_api_key_dict: UserAPIKeyAuth, - ) -> None: - verbose_logger.info( - f"Storing LiteLLM Managed {file_purpose} object with id={unified_object_id} in cache" - ) - litellm_managed_object = LiteLLM_ManagedObjectTable( - unified_object_id=unified_object_id, - model_object_id=model_object_id, - file_purpose=file_purpose, - file_object=file_object, - ) - await self.internal_usage_cache.async_set_cache( - key=unified_object_id, - value=litellm_managed_object.model_dump(), - litellm_parent_otel_span=litellm_parent_otel_span, - ) - - await self.prisma_client.db.litellm_managedobjecttable.upsert( - where={"unified_object_id": unified_object_id}, - data={ - "create": { - "unified_object_id": unified_object_id, - "file_object": file_object.model_dump_json(), - "model_object_id": model_object_id, - "file_purpose": file_purpose, - "created_by": user_api_key_dict.user_id, - "team_id": user_api_key_dict.team_id, - "updated_by": user_api_key_dict.user_id, - "status": file_object.status, - }, - "update": { - "file_object": file_object.model_dump_json(), - "status": file_object.status, - "updated_by": user_api_key_dict.user_id, - }, # FIX: Update status and file_object on every operation to keep state in sync - }, - ) - - async def get_unified_file_id( - self, file_id: str, litellm_parent_otel_span: Optional[Span] = None - ) -> Optional[LiteLLM_ManagedFileTable]: - ## CHECK CACHE - result = cast( - Optional[dict], - await self.internal_usage_cache.async_get_cache( - key=file_id, - litellm_parent_otel_span=litellm_parent_otel_span, - ), - ) - - if result: - return LiteLLM_ManagedFileTable(**result) - - ## CHECK DB - db_object = await self.prisma_client.db.litellm_managedfiletable.find_first( - where={"unified_file_id": file_id} - ) - - if db_object: - return LiteLLM_ManagedFileTable(**db_object.model_dump()) - return None - - async def delete_unified_file_id( - self, file_id: str, litellm_parent_otel_span: Optional[Span] = None - ) -> OpenAIFileObject: - ## get old value - initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first( - where={"unified_file_id": file_id} - ) - if initial_value is None: - raise Exception(f"LiteLLM Managed File object with id={file_id} not found") - ## delete old value - await self.internal_usage_cache.async_set_cache( - key=file_id, - value=None, - litellm_parent_otel_span=litellm_parent_otel_span, - ) - await self.prisma_client.db.litellm_managedfiletable.delete( - where={"unified_file_id": file_id} - ) - return initial_value.file_object - - async def can_user_call_unified_file_id( - self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth - ) -> bool: - managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first( - where={"unified_file_id": unified_file_id} - ) - - if managed_file: - return can_access_resource( - user_api_key_dict=user_api_key_dict, - created_by=managed_file.created_by, - resource_team_id=managed_file.team_id, - ) - raise HTTPException( - status_code=404, - detail=f"File not found: {unified_file_id}", - ) - - async def can_user_call_unified_object_id( - self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth - ) -> bool: - managed_object = ( - await self.prisma_client.db.litellm_managedobjecttable.find_first( - where={"unified_object_id": unified_object_id} - ) - ) - - if managed_object: - return can_access_resource( - user_api_key_dict=user_api_key_dict, - created_by=managed_object.created_by, - resource_team_id=managed_object.team_id, - ) - raise HTTPException( - status_code=404, - detail=f"Object not found: {unified_object_id}", - ) - - async def list_user_batches( - self, - user_api_key_dict: UserAPIKeyAuth, - limit: Optional[int] = None, - after: Optional[str] = None, - provider: Optional[str] = None, - target_model_names: Optional[str] = None, - llm_router: Optional[Router] = None, - ) -> Dict[str, Any]: - # Provider filtering is not supported for managed batches - # This is because the encoded object ids stored in the managed objects table do not contain the provider information - # To support provider filtering, we would need to store the provider information in the encoded object ids - if provider: - raise Exception( - "Filtering by 'provider' is not supported when using managed batches." - ) - - # Model name filtering is not supported for managed batches - # This is because the encoded object ids stored in the managed objects table do not contain the model name - # A hash of the model name + litellm_params for the model name is encoded as the model id. This is not sufficient to reliably map the target model names to the model ids. - if target_model_names: - raise Exception( - "Filtering by 'target_model_names' is not supported when using managed batches." - ) - - owner_filter = build_owner_filter(user_api_key_dict) - if owner_filter is None: - return build_list_page([]) - - where_clause: Dict[str, Any] = {"file_purpose": "batch", **owner_filter} - - if after: - where_clause["id"] = {"gt": after} - - fetch_limit = limit or 20 - if target_model_names: - # Oversample so post-fetch model-name filtering still has enough rows. - fetch_limit = max(fetch_limit * 3, 100) - - batches = await self.prisma_client.db.litellm_managedobjecttable.find_many( - where=where_clause, - take=fetch_limit, - order={"created_at": "desc"}, - ) - - batch_objects: List[LiteLLMBatch] = [] - for batch in batches: - try: - # Stop once we have enough after filtering - if len(batch_objects) >= (limit or 20): - break - - batch_data = ( - json.loads(batch.file_object) - if isinstance(batch.file_object, str) - else batch.file_object - ) - batch_obj = LiteLLMBatch(**batch_data) - batch_obj.id = batch.unified_object_id - batch_objects.append(batch_obj) - - except Exception as e: - verbose_logger.warning( - f"Failed to parse batch object {batch.unified_object_id}: {e}" - ) - continue - - return build_list_page( - batch_objects, has_more=len(batch_objects) == (limit or 20) - ) - - async def get_user_created_file_ids( - self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str] - ) -> List[OpenAIFileObject]: - """ - Get all file ids the caller is allowed to see for a list of model - object ids. Service-account keys (no user_id) are scoped to their - team via ``team_id``; admins see all matches. - - Returns: - - List of OpenAIFileObject's - """ - owner_filter = build_owner_filter(user_api_key_dict) - if owner_filter is None: - return [] - - file_ids = await self.prisma_client.db.litellm_managedfiletable.find_many( - where={ - **owner_filter, - "flat_model_file_ids": {"hasSome": model_object_ids}, - } - ) - return [OpenAIFileObject(**file_object.file_object) for file_object in file_ids] - - async def check_managed_file_id_access( - self, data: Dict, user_api_key_dict: UserAPIKeyAuth - ) -> bool: - retrieve_file_id = cast(Optional[str], data.get("file_id")) - potential_file_id = ( - _is_base64_encoded_unified_file_id(retrieve_file_id) - if retrieve_file_id - else False - ) - if potential_file_id and retrieve_file_id: - if await self.can_user_call_unified_file_id( - retrieve_file_id, user_api_key_dict - ): - return True - else: - raise HTTPException( - status_code=403, - detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}", - ) - return False - - async def check_file_ids_access( - self, file_ids: List[str], user_api_key_dict: UserAPIKeyAuth - ) -> None: - """ - Check if the user has access to a list of file IDs. - Only checks managed (unified) file IDs. - - Args: - file_ids: List of file IDs to check access for - user_api_key_dict: User API key authentication details - - Raises: - HTTPException: If user doesn't have access to any of the files - """ - for file_id in file_ids: - is_unified_file_id = _is_base64_encoded_unified_file_id(file_id) - if is_unified_file_id: - if not await self.can_user_call_unified_file_id( - file_id, user_api_key_dict - ): - raise HTTPException( - status_code=403, - detail=f"User {user_api_key_dict.user_id} does not have access to the file {file_id}", - ) - - async def async_pre_call_hook( # noqa: PLR0915 - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: Dict, - call_type: CallTypesLiteral, - ) -> Union[Exception, str, Dict, None]: - """ - - Detect litellm_proxy/ file_id - - add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}} - """ - ### HANDLE FILE ACCESS ### - ensure user has access to the file - if ( - call_type == CallTypes.afile_content.value - or call_type == CallTypes.afile_delete.value - or call_type == CallTypes.afile_retrieve.value - or call_type == CallTypes.afile_content.value - ): - await self.check_managed_file_id_access(data, user_api_key_dict) - - ### HANDLE TRANSFORMATIONS ### - # Check both completion and acompletion call types - is_completion_call = ( - call_type == CallTypes.completion.value - or call_type == CallTypes.acompletion.value - ) - - if is_completion_call: - messages = data.get("messages") - model = data.get("model", "") - if messages: - file_ids = self.get_file_ids_from_messages(messages) - if file_ids: - # Check user has access to all managed files - await self.check_file_ids_access(file_ids, user_api_key_dict) - - # Check if any files are stored in storage backends and need base64 conversion - # This is needed for Vertex AI/Gemini which requires base64 content - is_vertex_ai = model and ( - "vertex_ai" in model or "gemini" in model.lower() - ) - if is_vertex_ai: - await self._convert_storage_files_to_base64( - messages=messages, - file_ids=file_ids, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - ) - - model_file_id_mapping = await self.get_model_file_id_mapping( - file_ids, user_api_key_dict.parent_otel_span - ) - data["model_file_id_mapping"] = model_file_id_mapping - elif ( - call_type == CallTypes.aresponses.value - or call_type == CallTypes.responses.value - ): - # Handle managed files in responses API input and tools - file_ids = [] - - # Extract file IDs from input parameter - input_data = data.get("input") - if input_data: - file_ids.extend(self.get_file_ids_from_responses_input(input_data)) - - # Extract file IDs from tools parameter (e.g., code_interpreter container) - tools = data.get("tools") - if tools: - file_ids.extend(self.get_file_ids_from_responses_tools(tools)) - - if file_ids: - # Check user has access to all managed files - await self.check_file_ids_access(file_ids, user_api_key_dict) - - model_file_id_mapping = await self.get_model_file_id_mapping( - file_ids, user_api_key_dict.parent_otel_span - ) - data["model_file_id_mapping"] = model_file_id_mapping - - # Check access for file_search vector_store_ids - if tools: - unified_vs_ids = self.get_vector_store_ids_from_file_search_tools(tools) - if unified_vs_ids: - await self.check_vector_store_ids_access( - unified_vs_ids, user_api_key_dict - ) - elif call_type == CallTypes.afile_content.value: - retrieve_file_id = cast(Optional[str], data.get("file_id")) - potential_file_id = ( - _is_base64_encoded_unified_file_id(retrieve_file_id) - if retrieve_file_id - else False - ) - if potential_file_id: - model_id = self.get_model_id_from_unified_file_id(potential_file_id) - if model_id: - data["model"] = model_id - data["file_id"] = self.get_output_file_id_from_unified_file_id( - potential_file_id - ) - elif call_type == CallTypes.acreate_batch.value: - input_file_id = cast(Optional[str], data.get("input_file_id")) - if input_file_id: - model_file_id_mapping = await self.get_model_file_id_mapping( - [input_file_id], user_api_key_dict.parent_otel_span - ) - - data["model_file_id_mapping"] = model_file_id_mapping - elif ( - call_type == CallTypes.aretrieve_batch.value - or call_type == CallTypes.acancel_batch.value - or call_type == CallTypes.acancel_fine_tuning_job.value - or call_type == CallTypes.aretrieve_fine_tuning_job.value - ): - accessor_key: Optional[str] = None - retrieve_object_id: Optional[str] = None - if ( - call_type == CallTypes.aretrieve_batch.value - or call_type == CallTypes.acancel_batch.value - ): - accessor_key = "batch_id" - elif ( - call_type == CallTypes.acancel_fine_tuning_job.value - or call_type == CallTypes.aretrieve_fine_tuning_job.value - ): - accessor_key = "fine_tuning_job_id" - - if accessor_key: - retrieve_object_id = cast(Optional[str], data.get(accessor_key)) - - potential_llm_object_id = ( - _is_base64_encoded_unified_file_id(retrieve_object_id) - if retrieve_object_id - else False - ) - if potential_llm_object_id and retrieve_object_id: - ## VALIDATE USER HAS ACCESS TO THE OBJECT ## - if not await self.can_user_call_unified_object_id( - retrieve_object_id, user_api_key_dict - ): - raise HTTPException( - status_code=403, - detail=f"User {user_api_key_dict.user_id} does not have access to the object {retrieve_object_id}", - ) - - ## for managed batch id - get the model id - potential_model_id = get_model_id_from_unified_batch_id( - potential_llm_object_id - ) - if potential_model_id is None: - raise Exception( - f"LiteLLM Managed {accessor_key} with id={retrieve_object_id} is invalid - does not contain encoded model_id." - ) - data["model"] = potential_model_id - data[accessor_key] = get_batch_id_from_unified_batch_id( - potential_llm_object_id - ) - elif call_type == CallTypes.acreate_fine_tuning_job.value: - input_file_id = cast(Optional[str], data.get("training_file")) - if input_file_id: - model_file_id_mapping = await self.get_model_file_id_mapping( - [input_file_id], user_api_key_dict.parent_otel_span - ) - - return data - - async def async_filter_deployments( - self, - model: str, - healthy_deployments: List, - messages: Optional[List[AllMessageValues]], - request_kwargs: Optional[Dict] = None, - parent_otel_span: Optional[Span] = None, - ) -> List[Dict]: - if request_kwargs is None: - return healthy_deployments - - input_file_id = cast(Optional[str], request_kwargs.get("input_file_id")) - model_file_id_mapping = cast( - Optional[Dict[str, Dict[str, str]]], - request_kwargs.get("model_file_id_mapping"), - ) - allowed_model_ids = [] - if input_file_id and model_file_id_mapping: - model_id_dict = model_file_id_mapping.get(input_file_id, {}) - allowed_model_ids = list(model_id_dict.keys()) - - if len(allowed_model_ids) == 0: - return healthy_deployments - - return [ - deployment - for deployment in healthy_deployments - if deployment.get("model_info", {}).get("id") in allowed_model_ids - ] - - async def async_pre_call_deployment_hook( - self, kwargs: Dict[str, Any], call_type: Optional[CallTypes] - ) -> Optional[dict]: - """ - Allow modifying the request just before it's sent to the deployment. - """ - accessor_key: Optional[str] = None - if call_type and call_type == CallTypes.acreate_batch: - accessor_key = "input_file_id" - elif call_type and call_type == CallTypes.acreate_fine_tuning_job: - accessor_key = "training_file" - else: - return kwargs - - if accessor_key: - input_file_id = cast(Optional[str], kwargs.get(accessor_key)) - model_file_id_mapping = cast( - Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping") - ) - # model_info may be at top-level or nested under litellm_metadata - # (batch/file operations use litellm_metadata) - model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None)) - if model_id is None: - model_id = cast( - Optional[str], - kwargs.get("litellm_metadata", {}) - .get("model_info", {}) - .get("id", None), - ) - mapped_file_id: Optional[str] = None - if input_file_id and model_file_id_mapping and model_id: - mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get( - model_id, None - ) - if mapped_file_id: - kwargs[accessor_key] = mapped_file_id - - return kwargs - - def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]: - """ - Gets file ids from messages - """ - file_ids = [] - for message in messages: - if message.get("role") == "user": - content = message.get("content") - if content: - if isinstance(content, str): - continue - for c in content: - if c.get("type") == "file": - file_object = cast(ChatCompletionFileObject, c) - file_object_file_field = file_object["file"] - file_id = file_object_file_field.get("file_id") - if file_id: - file_ids.append(file_id) - return file_ids - - def get_file_ids_from_responses_input( - self, input: Union[str, List[Dict[str, Any]]] - ) -> List[str]: - """ - Gets file ids from responses API input. - - The input can be: - - A string (no files) - - A list of input items, where each item can have: - - type: "input_file" with file_id - - content: a list that can contain items with type: "input_file" and file_id - """ - file_ids: List[str] = [] - - if isinstance(input, str): - return file_ids - - if not isinstance(input, list): - return file_ids - - for item in input: - if not isinstance(item, dict): - continue - - # Check for direct input_file type - if item.get("type") == "input_file": - file_id = item.get("file_id") - if file_id: - file_ids.append(file_id) - - # Check for input_file in content array - content = item.get("content") - if isinstance(content, list): - for content_item in content: - if ( - isinstance(content_item, dict) - and content_item.get("type") == "input_file" - ): - file_id = content_item.get("file_id") - if file_id: - file_ids.append(file_id) - - return file_ids - - def get_file_ids_from_responses_tools( - self, tools: List[Dict[str, Any]] - ) -> List[str]: - """ - Gets file ids from responses API tools parameter. - - The tools can contain code_interpreter with container.file_ids: - [ - { - "type": "code_interpreter", - "container": {"type": "auto", "file_ids": ["file-123", "file-456"]} - } - ] - """ - file_ids: List[str] = [] - - if not isinstance(tools, list): - return file_ids - - for tool in tools: - if not isinstance(tool, dict): - continue - - # Check for code_interpreter with container file_ids - if tool.get("type") == "code_interpreter": - container = tool.get("container") - if isinstance(container, dict): - container_file_ids = container.get("file_ids") - if isinstance(container_file_ids, list): - for file_id in container_file_ids: - if isinstance(file_id, str): - file_ids.append(file_id) - - return file_ids - - def get_vector_store_ids_from_file_search_tools( - self, tools: List[Dict[str, Any]] - ) -> List[str]: - """ - Extract unified vector_store_ids from file_search tools. - - Only returns IDs that are LiteLLM-managed (base64 unified IDs). - Native provider IDs are skipped — they have no LiteLLM access record. - """ - from litellm.llms.base_llm.managed_resources.utils import ( - is_base64_encoded_unified_id, - ) - - vs_ids: List[str] = [] - if not isinstance(tools, list): - return vs_ids - - for tool in tools: - if not isinstance(tool, dict) or tool.get("type") != "file_search": - continue - vector_store_ids = tool.get("vector_store_ids") - if not isinstance(vector_store_ids, list): - continue - for vs_id in vector_store_ids: - if isinstance(vs_id, str) and is_base64_encoded_unified_id(vs_id): - vs_ids.append(vs_id) - - return vs_ids - - async def check_vector_store_ids_access( - self, - vector_store_ids: List[str], - user_api_key_dict: UserAPIKeyAuth, - ) -> None: - """ - Verify the caller's team can access each LiteLLM-managed vector store. - - Batch-fetches vector stores from DB and checks team_id. - Raises HTTPException(403) on the first access violation. - Non-managed (native) IDs should already be filtered out before calling this. - """ - from litellm.llms.base_llm.managed_resources.utils import ( - extract_unified_uuid_from_unified_id, - ) - from litellm.proxy.auth.auth_checks import ( - get_managed_vector_store_rows_by_uuids, - ) - from litellm.proxy.proxy_server import ( - prisma_client, - proxy_logging_obj, - user_api_key_cache, - ) - - if not vector_store_ids or prisma_client is None: - return - - # Map each unified ID to its internal UUID for a single batch DB fetch - uuid_to_unified: Dict[str, str] = {} - for vs_id in vector_store_ids: - uuid = extract_unified_uuid_from_unified_id(vs_id) - if uuid: - uuid_to_unified[uuid] = vs_id - - if not uuid_to_unified: - return - - rows = await get_managed_vector_store_rows_by_uuids( - uuids=list(uuid_to_unified.keys()), - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - ) - - found_uuids = {row.vector_store_id for row in rows} - - for uuid, original_id in uuid_to_unified.items(): - if uuid not in found_uuids: - raise HTTPException( - status_code=403, - detail=f"Vector store '{original_id}' not found or access denied.", - ) - - caller_team_id = user_api_key_dict.team_id - for row in rows: - vs_team_id = getattr(row, "team_id", None) - if vs_team_id is not None and vs_team_id != caller_team_id: - raise HTTPException( - status_code=403, - detail=( - f"Team '{caller_team_id}' does not have access to vector " - f"store '{row.vector_store_id}'. The store belongs to team " - f"'{vs_team_id}'." - ), - ) - - async def get_model_file_id_mapping( - self, file_ids: List[str], litellm_parent_otel_span: Span - ) -> dict: - """ - Get model-specific file IDs for a list of proxy file IDs. - Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id - - 1. Get all the litellm_proxy/ file_ids from the messages - 2. For each file_id, search for cache keys matching the pattern file_id:* - 3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id - - Example: - { - "litellm_proxy/file_id": { - "model_id": "model_file_id" - } - } - """ - - file_id_mapping: Dict[str, Dict[str, str]] = {} - litellm_managed_file_ids = [] - - for file_id in file_ids: - ## CHECK IF FILE ID IS MANAGED BY LITELM - is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id) - if is_base64_unified_file_id: - litellm_managed_file_ids.append(file_id) - - if litellm_managed_file_ids: - # Get all cache keys matching the pattern file_id:* - for file_id in litellm_managed_file_ids: - # Search for any cache key starting with this file_id - unified_file_object = await self.get_unified_file_id( - file_id, litellm_parent_otel_span - ) - - if unified_file_object: - file_id_mapping[file_id] = unified_file_object.model_mappings - - return file_id_mapping - - async def create_file_for_each_model( - self, - llm_router: Optional[Router], - _create_file_request: CreateFileRequest, - target_model_names_list: List[str], - litellm_parent_otel_span: Span, - ) -> List[OpenAIFileObject]: - if llm_router is None: - raise Exception("LLM Router not initialized. Ensure models added to proxy.") - responses = [] - for model in target_model_names_list: - individual_response = await llm_router.acreate_file( - model=model, **_create_file_request - ) - responses.append(individual_response) - - return responses - - async def acreate_file( - self, - create_file_request: CreateFileRequest, - llm_router: Router, - target_model_names_list: List[str], - litellm_parent_otel_span: Span, - user_api_key_dict: UserAPIKeyAuth, - ) -> OpenAIFileObject: - responses = await self.create_file_for_each_model( - llm_router=llm_router, - _create_file_request=create_file_request, - target_model_names_list=target_model_names_list, - litellm_parent_otel_span=litellm_parent_otel_span, - ) - response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id( - file_objects=responses, - create_file_request=create_file_request, - internal_usage_cache=self.internal_usage_cache, - litellm_parent_otel_span=litellm_parent_otel_span, - target_model_names_list=target_model_names_list, - ) - - ## STORE MODEL MAPPINGS IN DB - model_mappings: Dict[str, str] = {} - - for file_object in responses: - model_file_id_mapping = file_object._hidden_params.get( - "model_file_id_mapping" - ) - if model_file_id_mapping and isinstance(model_file_id_mapping, dict): - model_mappings.update(model_file_id_mapping) - - await self.store_unified_file_id( - file_id=response.id, - file_object=response, - litellm_parent_otel_span=litellm_parent_otel_span, - model_mappings=model_mappings, - user_api_key_dict=user_api_key_dict, - ) - - # Emit Prometheus metrics for managed file creation - prom_logger = self._get_prometheus_logger() - if prom_logger: - first_model = ( - target_model_names_list[0] if target_model_names_list else None - ) - first_provider = "" - if responses: - first_provider = ( - getattr(responses[0], "_hidden_params", {}).get( - "custom_llm_provider" - ) - or "" - ) - prom_logger.record_managed_file_created( - model=first_model or "", - api_provider=first_provider, - user=user_api_key_dict.user_id or "", - user_email=getattr(user_api_key_dict, "user_email", None) or "", - api_key_alias=user_api_key_dict.key_alias or "", - ) - if response.bytes and response.bytes > 0: - prom_logger.record_managed_file_size( - size_bytes=response.bytes, - purpose=response.purpose or "batch", - file_type="input", - model=first_model, - api_provider=first_provider, - user=user_api_key_dict.user_id, - ) - - return response - - @staticmethod - async def return_unified_file_id( - file_objects: List[OpenAIFileObject], - create_file_request: CreateFileRequest, - internal_usage_cache: InternalUsageCache, - litellm_parent_otel_span: Span, - target_model_names_list: List[str], - ) -> OpenAIFileObject: - ## GET THE FILE TYPE FROM THE CREATE FILE REQUEST - file_data = extract_file_data(create_file_request["file"]) - - file_type = file_data["content_type"] - - output_file_id = file_objects[0].id - model_id = file_objects[0]._hidden_params.get("model_id") - - unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format( - file_type, - str(uuid.uuid4()), - ",".join(target_model_names_list), - output_file_id, - model_id, - ) - - # Convert to URL-safe base64 and strip padding - base64_unified_file_id = ( - base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=") - ) - - ## CREATE RESPONSE OBJECT - - response = OpenAIFileObject( - id=base64_unified_file_id, - object="file", - purpose=create_file_request["purpose"], - created_at=file_objects[0].created_at, - bytes=file_objects[0].bytes, - filename=file_objects[0].filename, - status="uploaded", - expires_at=file_objects[0].expires_at, - ) - - return response - - def get_unified_generic_response_id( - self, model_id: str, generic_response_id: str - ) -> str: - unified_generic_response_id = ( - SpecialEnums.LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR.value.format( - model_id, generic_response_id - ) - ) - return ( - base64.urlsafe_b64encode(unified_generic_response_id.encode()) - .decode() - .rstrip("=") - ) - - def get_unified_batch_id(self, batch_id: str, model_id: str) -> str: - unified_batch_id = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format( - model_id, batch_id - ) - return base64.urlsafe_b64encode(unified_batch_id.encode()).decode().rstrip("=") - - def get_unified_output_file_id( - self, output_file_id: str, model_id: str, model_name: Optional[str] - ) -> str: - unified_output_file_id = ( - SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format( - "application/json", - str(uuid.uuid4()), - model_name or "", - output_file_id, - model_id, - ) - ) - return ( - base64.urlsafe_b64encode(unified_output_file_id.encode()) - .decode() - .rstrip("=") - ) - - def get_model_id_from_unified_file_id(self, file_id: str) -> str: - return file_id.split("llm_output_file_model_id,")[1].split(";")[0] - - def get_output_file_id_from_unified_file_id(self, file_id: str) -> str: - return file_id.split("llm_output_file_id,")[1].split(";")[0] - - async def async_post_call_success_hook( - self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes - ) -> Any: - if isinstance(response, LiteLLMBatch): - ## Check if unified_file_id is in the response - unified_file_id = response._hidden_params.get( - "unified_file_id" - ) # managed file id - unified_batch_id = response._hidden_params.get( - "unified_batch_id" - ) # managed batch id - model_id = cast(Optional[str], response._hidden_params.get("model_id")) - model_name = cast(Optional[str], response._hidden_params.get("model_name")) - resolved_model_name = model_name - - # Some providers (e.g. Vertex batch retrieve) do not set model_name on - # the response. In that case, recover target_model_names from the input - # managed file metadata so unified output IDs preserve routing metadata. - if not resolved_model_name and isinstance(unified_file_id, str): - decoded_unified_file_id = ( - _is_base64_encoded_unified_file_id(unified_file_id) - or unified_file_id - ) - target_model_names = get_models_from_unified_file_id( - decoded_unified_file_id - ) - if target_model_names: - resolved_model_name = ",".join(target_model_names) - original_response_id = response.id - - if (unified_batch_id or unified_file_id) and model_id: - response.id = self.get_unified_batch_id( - batch_id=response.id, model_id=model_id - ) - - # Handle both output_file_id and error_file_id - for file_attr in ["output_file_id", "error_file_id"]: - file_id_value = getattr(response, file_attr, None) - if file_id_value and model_id: - original_file_id = file_id_value - unified_file_id = self.get_unified_output_file_id( - output_file_id=original_file_id, - model_id=model_id, - model_name=resolved_model_name, - ) - setattr(response, file_attr, unified_file_id) - - # Use llm_router credentials when available. Without credentials, - # Azure and other auth-required providers return 500/401. - file_object = None - try: - # Import module and use getattr for better testability with mocks - import litellm.proxy.proxy_server as proxy_server_module - - _llm_router = getattr( - proxy_server_module, "llm_router", None - ) - if _llm_router is not None and model_id: - _creds = ( - _llm_router.get_deployment_credentials_with_provider( - model_id - ) - or {} - ) - file_object = await litellm.afile_retrieve( - file_id=original_file_id, - **_creds, - ) - else: - file_object = await litellm.afile_retrieve( - custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", # type: ignore[arg-type] - file_id=original_file_id, - ) - verbose_logger.debug( - f"Successfully retrieved file object for {file_attr}={original_file_id}" - ) - except Exception as e: - verbose_logger.warning( - f"Failed to retrieve file object for {file_attr}={original_file_id}: {str(e)}. Storing with None and will fetch on-demand." - ) - - await self.store_unified_file_id( - file_id=unified_file_id, - file_object=file_object, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - model_mappings={model_id: original_file_id}, - user_api_key_dict=user_api_key_dict, - ) - await self.store_unified_object_id( - unified_object_id=response.id, - file_object=response, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - model_object_id=original_response_id, - file_purpose="batch", - user_api_key_dict=user_api_key_dict, - ) - - # Only record batch creation metric on actual create (not retrieve/cancel). - # unified_file_id in _hidden_params is only set by the create_batch endpoint. - original_unified_file_id = response._hidden_params.get("unified_file_id") - if original_unified_file_id: - prom_logger = self._get_prometheus_logger() - if prom_logger: - batch_provider = "" - if model_name: - try: - from litellm.litellm_core_utils.get_llm_provider_logic import ( - get_llm_provider, - ) - - _, batch_provider, _, _ = get_llm_provider(model=model_name) - except Exception: - if "/" in model_name: - batch_provider = model_name.split("/")[0] - prom_logger.record_managed_batch_created( - model=model_name or "", - api_provider=batch_provider, - user=user_api_key_dict.user_id or "", - user_email=getattr(user_api_key_dict, "user_email", None) or "", - api_key_alias=user_api_key_dict.key_alias or "", - ) - - elif isinstance(response, LiteLLMFineTuningJob): - ## Check if unified_file_id is in the response - unified_file_id = response._hidden_params.get( - "unified_file_id" - ) # managed file id - unified_finetuning_job_id = response._hidden_params.get( - "unified_finetuning_job_id" - ) # managed finetuning job id - model_id = cast(Optional[str], response._hidden_params.get("model_id")) - model_name = cast(Optional[str], response._hidden_params.get("model_name")) - original_response_id = response.id - if (unified_file_id or unified_finetuning_job_id) and model_id: - response.id = self.get_unified_generic_response_id( - model_id=model_id, generic_response_id=response.id - ) - await self.store_unified_object_id( - unified_object_id=response.id, - file_object=response, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - model_object_id=original_response_id, - file_purpose="fine-tune", - user_api_key_dict=user_api_key_dict, - ) - elif isinstance(response, AsyncCursorPage): - """ - For listing files, filter for the ones created by the user - """ - ## check if file object - if hasattr(response, "data") and isinstance(response.data, list): - if all( - isinstance(file_object, FileObject) for file_object in response.data - ): - ## Get all file id's - ## Check which file id's were created by the user - ## Filter the response to only include the files created by the user - ## Return the filtered response - file_ids = [ - file_object.id - for file_object in cast(List[FileObject], response.data) # type: ignore - ] - user_created_file_ids = await self.get_user_created_file_ids( - user_api_key_dict, file_ids - ) - ## Filter the response to only include the files created by the user - response.data = user_created_file_ids # type: ignore - return response - return response - return response - - async def afile_retrieve( - self, file_id: str, litellm_parent_otel_span: Optional[Span], llm_router=None - ) -> OpenAIFileObject: - stored_file_object = await self.get_unified_file_id( - file_id, litellm_parent_otel_span - ) - - # Case 1 : This is not a managed file - if not stored_file_object: - raise Exception(f"LiteLLM Managed File object with id={file_id} not found") - - # Case 2: Managed file and the file object exists in the database - # The stored file_object has the raw provider ID. Replace with the unified ID - # so callers see a consistent ID (matching Case 3 which does response.id = file_id). - if stored_file_object and stored_file_object.file_object: - # Use model_copy to ensure the ID update persists (Pydantic v2 compatibility) - response = stored_file_object.file_object.model_copy(update={"id": file_id}) - return response - - # Case 3: Managed file exists in the database but not the file object (for. e.g the batch task might not have run) - # So we fetch the file object from the provider. We deliberately do not store the result to avoid interfering with batch cost tracking code. - if not llm_router: - raise Exception( - f"LiteLLM Managed File object with id={file_id} has no file_object " - f"and llm_router is required to fetch from provider" - ) - - try: - model_id, model_file_id = next( - iter(stored_file_object.model_mappings.items()) - ) - credentials = ( - llm_router.get_deployment_credentials_with_provider(model_id) or {} - ) - response = await litellm.afile_retrieve( - file_id=model_file_id, **credentials - ) - response.id = file_id # Replace with unified ID - return response - except Exception as e: - raise Exception( - f"Failed to retrieve file {file_id} from provider: {str(e)}" - ) from e - - async def afile_list( - self, - purpose: Optional[OpenAIFilesPurpose], - litellm_parent_otel_span: Optional[Span], - **data: Dict, - ) -> List[OpenAIFileObject]: - """Handled in files_endpoints.py""" - return [] - - def _is_batch_polling_enabled(self) -> bool: - """ - Check if batch cost tracking is actually enabled and running. - Returns: - bool: True if batch cost tracking is active, False otherwise - """ - try: - # Import here to avoid circular dependencies - import litellm.proxy.proxy_server as proxy_server_module - - # Check if the scheduler has the batch cost checking job registered - scheduler = getattr(proxy_server_module, "scheduler", None) - if scheduler is None: - return False - - # Check if the check_batch_cost_job exists in the scheduler - try: - job = scheduler.get_job("check_batch_cost_job") - if job is not None: - return True - except Exception: - # Job not found or scheduler doesn't support get_job - pass - - return False - except Exception as e: - verbose_logger.warning( - f"Error checking batch polling configuration: {e}. Assuming disabled." - ) - return False - - async def _get_batches_referencing_file(self, file_id: str) -> List[Dict[str, Any]]: - """ - Find batches that reference this file and still need cost tracking. - Find batches that are in non-terminal state and have not yet been processed by CheckBatchCost. - Args: - file_id: The unified file ID to check - - Returns: - List of batch objects referencing this file in non-terminal state - (max 10 for error message display) - """ - # Prepare list of file IDs to check (both unified and provider IDs) - file_ids_to_check = [file_id] - - # Get model-specific file IDs for this unified file ID if it's a managed file - try: - model_file_id_mapping = await self.get_model_file_id_mapping( - [file_id], litellm_parent_otel_span=None - ) - - if model_file_id_mapping and file_id in model_file_id_mapping: - # Add all provider file IDs for this unified file - provider_file_ids = list(model_file_id_mapping[file_id].values()) - file_ids_to_check.extend(provider_file_ids) - except Exception as e: - verbose_logger.debug( - f"Could not get model file ID mapping for {file_id}: {e}. " - f"Will only check unified file ID." - ) - MAX_MATCHES_TO_RETURN = 10 - - batches = await self.prisma_client.db.litellm_managedobjecttable.find_many( - where={ - "file_purpose": "batch", - "batch_processed": False, - "status": {"not_in": ["failed", "expired", "cancelled"]}, - }, - take=MAX_MATCHES_TO_RETURN, - order={"created_at": "desc"}, - ) - - referencing_batches = [] - for batch in batches: - try: - # Parse the batch file_object to check for file references - batch_data = ( - json.loads(batch.file_object) - if isinstance(batch.file_object, str) - else batch.file_object - ) - - # Extract file IDs from batch - # Batches typically reference the unified file ID in input_file_id - # Output and error files are generated by the provider - input_file_id = batch_data.get("input_file_id") - output_file_id = batch_data.get("output_file_id") - error_file_id = batch_data.get("error_file_id") - - referenced_file_ids = [ - fid for fid in [input_file_id, output_file_id, error_file_id] if fid - ] - - # Check if any referenced file ID matches the file we're trying to delete - if any(ref_id in file_ids_to_check for ref_id in referenced_file_ids): - referencing_batches.append( - { - "batch_id": batch.unified_object_id, - "status": batch.status, - "created_at": batch.created_at, - } - ) - except Exception as e: - verbose_logger.warning( - f"Error parsing batch object {batch.unified_object_id}: {e}" - ) - continue - - return referencing_batches - - async def _check_file_deletion_allowed(self, file_id: str) -> None: - """ - Check if file deletion should be blocked due to batch references. - - Blocks deletion if: - 1. File is referenced by any batch in non-terminal state, AND - 2. Batch polling is configured (user wants cost tracking) - - Args: - file_id: The unified file ID to check - - Raises: - HTTPException: If file deletion should be blocked - """ - # Check if batch polling is enabled - if not self._is_batch_polling_enabled(): - # Batch polling not configured, allow deletion - return - - # Check if file is referenced by any non-terminal batches - referencing_batches = await self._get_batches_referencing_file(file_id) - - if referencing_batches: - # File is referenced by non-terminal batches and polling is enabled - MAX_BATCHES_IN_ERROR = ( - 5 # Limit batches shown in error message for readability - ) - - # Show up to MAX_BATCHES_IN_ERROR in the error message - batches_to_show = referencing_batches[:MAX_BATCHES_IN_ERROR] - batch_statuses = [ - f"{b['batch_id']}: {b['status']}" for b in batches_to_show - ] - - # Determine the count message - count_message = f"{len(referencing_batches)}" - if ( - len(referencing_batches) >= 10 - ): # MAX_MATCHES_TO_RETURN from _get_batches_referencing_file - count_message = "10+" - - error_message = ( - f"Cannot delete file {file_id}. " - f"The file is referenced by {count_message} batch(es) in non-terminal state" - ) - - # Add specific batch details if not too many - if len(referencing_batches) <= MAX_BATCHES_IN_ERROR: - error_message += f": {', '.join(batch_statuses)}. " - else: - error_message += f" (showing {MAX_BATCHES_IN_ERROR} most recent): {', '.join(batch_statuses)}. " - - error_message += ( - f"To delete this file before complete cost tracking, please delete or cancel the referencing batch(es) first. " - f"Alternatively, wait for all batches to complete and for cost to be computed (batch_processed=true)." - ) - - # Record blocked deletion metric - prom_logger = self._get_prometheus_logger() - if prom_logger: - prom_logger.record_managed_file_deleted(result="blocked") - - raise HTTPException( - status_code=400, - detail=error_message, - ) - - async def afile_delete( - self, - file_id: str, - litellm_parent_otel_span: Optional[Span], - llm_router: Router, - **data: Dict, - ) -> OpenAIFileObject: - - # Check if file deletion should be blocked due to batch references - await self._check_file_deletion_allowed(file_id) - - # file_id = convert_b64_uid_to_unified_uid(file_id) - model_file_id_mapping = await self.get_model_file_id_mapping( - [file_id], litellm_parent_otel_span - ) - - delete_response = None - specific_model_file_id_mapping = model_file_id_mapping.get(file_id) - if specific_model_file_id_mapping: - # Remove conflicting keys from data to avoid duplicate keyword arguments - filtered_data = { - k: v for k, v in data.items() if k not in ("model", "file_id") - } - for model_id, model_file_id in specific_model_file_id_mapping.items(): - delete_response = await llm_router.afile_delete(model=model_id, file_id=model_file_id, **filtered_data) # type: ignore - - stored_file_object = await self.delete_unified_file_id( - file_id, litellm_parent_otel_span - ) - - # Record successful deletion metric only on actual success - if stored_file_object or delete_response: - prom_logger = self._get_prometheus_logger() - if prom_logger: - prom_logger.record_managed_file_deleted(result="success") - - if stored_file_object: - return stored_file_object - elif delete_response: - delete_response.id = file_id - return delete_response - else: - raise Exception(f"LiteLLM Managed File object with id={file_id} not found") - - async def afile_content( - self, - file_id: str, - litellm_parent_otel_span: Optional[Span], - llm_router: Router, - **data: Dict, - ) -> "HttpxBinaryResponseContent": - """ - Get the content of a file from first model that has it - """ - model_file_id_mapping = data.pop("model_file_id_mapping", None) - model_file_id_mapping = ( - model_file_id_mapping - or await self.get_model_file_id_mapping([file_id], litellm_parent_otel_span) - ) - - specific_model_file_id_mapping = model_file_id_mapping.get(file_id) - - if specific_model_file_id_mapping: - exception_dict = {} - for model_id, file_id in specific_model_file_id_mapping.items(): - try: - return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore - except Exception as e: - exception_dict[model_id] = str(e) - raise Exception( - f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}" - ) - else: - raise Exception(f"LiteLLM Managed File object with id={file_id} not found") - - async def _convert_storage_files_to_base64( - self, - messages: List[AllMessageValues], - file_ids: List[str], - litellm_parent_otel_span: Optional[Span], - ) -> None: - """ - Convert files stored in storage backends to base64 format for Vertex AI/Gemini. - - This method checks if any managed files are stored in storage backends, - downloads them, and converts them to base64 format in the messages. - """ - # Check each file_id to see if it's stored in a storage backend - for file_id in file_ids: - # Check if this is a base64 encoded unified file ID - decoded_unified_file_id = _is_base64_encoded_unified_file_id(file_id) - - if not decoded_unified_file_id: - continue - - # Check database for storage backend info - # IMPORTANT: The database stores the base64 encoded unified_file_id (not the decoded version) - # So we query with the original file_id (which is base64 encoded) - db_file = await self.prisma_client.db.litellm_managedfiletable.find_first( - where={"unified_file_id": file_id} - ) - - if not db_file or not db_file.storage_backend or not db_file.storage_url: - continue - - # File is stored in a storage backend, download and convert to base64 - try: - from litellm.llms.base_llm.files.storage_backend_factory import ( - get_storage_backend, - ) - - storage_backend_name = db_file.storage_backend - storage_url = db_file.storage_url - - # Get storage backend (uses same env vars as callback) - try: - storage_backend = get_storage_backend(storage_backend_name) - except ValueError as e: - verbose_logger.warning( - f"Storage backend '{storage_backend_name}' error for file {file_id}: {str(e)}" - ) - continue - - file_content = await storage_backend.download_file(storage_url) - - # Determine content type from file object - content_type = self._get_content_type_from_file_object( - db_file.file_object - ) - - # Convert to base64 - base64_data = base64.b64encode(file_content).decode("utf-8") - base64_data_uri = f"data:{content_type};base64,{base64_data}" - - # Update messages to use base64 instead of file_id - self._update_messages_with_base64_data( - messages, file_id, base64_data_uri, content_type - ) - except Exception as e: - verbose_logger.exception( - f"Error converting file {file_id} from storage backend to base64: {str(e)}" - ) - # Continue with other files even if one fails - continue - - def _get_content_type_from_file_object(self, file_object: Optional[Any]) -> str: - """ - Determine content type from file object. - - Uses the MIME type utility for consistent detection and normalization. - - Args: - file_object: The file object from the database (can be dict, JSON string, or None) - - Returns: - str: MIME type (defaults to "application/octet-stream" if cannot be determined) - """ - # Use utility function for detection - content_type = get_content_type_from_file_object(file_object) - - # Normalize for Gemini/Vertex AI (requires image/jpeg, not image/jpg) - content_type = normalize_mime_type_for_provider(content_type, provider="gemini") - - return content_type - - def _update_messages_with_base64_data( - self, - messages: List[AllMessageValues], - file_id: str, - base64_data_uri: str, - content_type: str, - ) -> None: - """ - Update messages to replace file_id with base64 data URI. - - Args: - messages: List of messages to update - file_id: The file ID to replace - base64_data_uri: The base64 data URI to use as replacement - content_type: The MIME type of the file (e.g., "image/jpeg", "application/pdf") - """ - for message in messages: - if message.get("role") == "user": - content = message.get("content") - if content and isinstance(content, list): - for element in content: - if element.get("type") == "file": - file_element = cast(ChatCompletionFileObject, element) - file_element_file = file_element.get("file", {}) - - if file_element_file.get("file_id") == file_id: - # Replace file_id with base64 data - file_element_file["file_data"] = base64_data_uri - # Set format to help Gemini determine mime type - file_element_file["format"] = content_type - # Remove file_id to ensure only file_data is used - file_element_file.pop("file_id", None) - - verbose_logger.debug( - f"Converted file {file_id} from storage backend to base64 with format {content_type}" - ) diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py b/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py deleted file mode 100644 index 254d816039..0000000000 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_vector_stores.py +++ /dev/null @@ -1,464 +0,0 @@ -# What is this? -## This hook is used to manage vector stores with target_model_names support -## It allows creating vector stores across multiple models and managing them with unified IDs - -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast - -from fastapi import HTTPException - -import litellm -from litellm import Router, verbose_logger -from litellm._uuid import uuid -from litellm.integrations.custom_logger import CustomLogger -from litellm.llms.base_llm.managed_resources import BaseManagedResource -from litellm.llms.base_llm.managed_resources.utils import ( - generate_unified_id_string, - is_base64_encoded_unified_id, -) -from litellm.proxy._types import UserAPIKeyAuth -from litellm.types.vector_stores import ( - VectorStoreCreateOptionalRequestParams, - VectorStoreCreateResponse, -) - -if TYPE_CHECKING: - from opentelemetry.trace import Span as _Span - - from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache - from litellm.proxy.utils import PrismaClient as _PrismaClient - - Span = Union[_Span, Any] - InternalUsageCache = _InternalUsageCache - PrismaClient = _PrismaClient -else: - Span = Any - InternalUsageCache = Any - PrismaClient = Any - - -class _PROXY_LiteLLMManagedVectorStores( - CustomLogger, BaseManagedResource[VectorStoreCreateResponse] -): - """ - Managed vector stores with target_model_names support. - - This class provides functionality to: - - Create vector stores across multiple models - - Retrieve vector stores by unified ID - - Delete vector stores from all models - - List vector stores created by a user - """ - - def __init__( - self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient - ): - CustomLogger.__init__(self) - BaseManagedResource.__init__(self, internal_usage_cache, prisma_client) - - # ============================================================================ - # ABSTRACT METHOD IMPLEMENTATIONS - # ============================================================================ - - @property - def resource_type(self) -> str: - """Return the resource type identifier.""" - return "vector_store" - - @property - def table_name(self) -> str: - """Return the database table name for vector stores.""" - # Prisma converts model name LiteLLM_ManagedVectorStoreTable to litellm_managedvectorstoretable - return "litellm_managedvectorstoretable" - - def get_unified_resource_id_format( - self, - resource_object: VectorStoreCreateResponse, - target_model_names_list: List[str], - ) -> str: - """ - Generate the format string for the unified vector store ID. - - Format: - litellm_proxy:vector_store;unified_id,;target_model_names,;resource_id,;model_id, - """ - # VectorStoreCreateResponse is a TypedDict, so resource_object is a dictionary - # Extract provider resource ID from the response - provider_resource_id = resource_object.get("id", "") - - # Model ID is stored in hidden params if the response object supports it - # For TypedDict responses, we need to check if _hidden_params was added - hidden_params: Dict[str, Any] = {} - if hasattr(resource_object, "_hidden_params"): - hidden_params = getattr(resource_object, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id", "") - - return generate_unified_id_string( - resource_type=self.resource_type, - unified_uuid=str(uuid.uuid4()), - target_model_names=target_model_names_list, - provider_resource_id=provider_resource_id, - model_id=model_id, - ) - - async def create_resource_for_model( - self, - llm_router: Router, - model: str, - request_data: Dict[str, Any], - litellm_parent_otel_span: Span, - ) -> VectorStoreCreateResponse: - """ - Create a vector store for a specific model. - - Args: - llm_router: LiteLLM router instance - model: Model name to create vector store for - request_data: Request data for vector store creation - litellm_parent_otel_span: OpenTelemetry span for tracing - - Returns: - VectorStoreCreateResponse from the provider - """ - # Use the router to create the vector store - response = await llm_router.avector_store_create( - model=model, **request_data - ) - return response - - # ============================================================================ - # VECTOR STORE CRUD OPERATIONS - # ============================================================================ - - async def acreate_vector_store( - self, - create_request: VectorStoreCreateOptionalRequestParams, - llm_router: Router, - target_model_names_list: List[str], - litellm_parent_otel_span: Span, - user_api_key_dict: UserAPIKeyAuth, - ) -> VectorStoreCreateResponse: - """ - Create a vector store across multiple models. - - Args: - create_request: Vector store creation request parameters - llm_router: LiteLLM router instance - target_model_names_list: List of target model names - litellm_parent_otel_span: OpenTelemetry span for tracing - user_api_key_dict: User API key authentication details - - Returns: - VectorStoreCreateResponse with unified ID - """ - verbose_logger.info( - f"Creating managed vector store for models: {target_model_names_list}" - ) - - # Create vector store for each model - # Convert TypedDict to Dict[str, Any] for base class compatibility - request_data_dict: Dict[str, Any] = dict(create_request) - responses = await self.create_resource_for_each_model( - llm_router=llm_router, - request_data=request_data_dict, - target_model_names_list=target_model_names_list, - litellm_parent_otel_span=litellm_parent_otel_span, - ) - - # Generate unified ID - unified_id = self.generate_unified_resource_id( - resource_objects=responses, - target_model_names_list=target_model_names_list, - ) - - # Extract model mappings from responses - model_mappings: Dict[str, str] = {} - for response in responses: - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id") - if model_id: - # VectorStoreCreateResponse is a TypedDict, use dict access - model_mappings[model_id] = response["id"] - - verbose_logger.debug( - f"Created vector stores with model mappings: {model_mappings}" - ) - - # Store in database - await self.store_unified_resource_id( - unified_resource_id=unified_id, - resource_object=responses[0], # Store first response as template - litellm_parent_otel_span=litellm_parent_otel_span, - model_mappings=model_mappings, - user_api_key_dict=user_api_key_dict, - ) - - # Return response with unified ID - # VectorStoreCreateResponse is a TypedDict, so we need to create a new dict with the unified ID - response = responses[0].copy() - response["id"] = unified_id - - verbose_logger.info( - f"Successfully created managed vector store with unified ID: {unified_id}" - ) - - return response - - async def alist_vector_stores( - self, - user_api_key_dict: UserAPIKeyAuth, - limit: Optional[int] = None, - after: Optional[str] = None, - order: Optional[str] = None, - ) -> Dict[str, Any]: - """ - List vector stores created by a user. - - Args: - user_api_key_dict: User API key authentication details - limit: Maximum number of vector stores to return - after: Cursor for pagination - order: Sort order ('asc' or 'desc') - - Returns: - Dictionary with list of vector stores and pagination info - """ - # Use the base class method - return await self.list_user_resources( - user_api_key_dict=user_api_key_dict, - limit=limit, - after=after, - ) - - # ============================================================================ - # ACCESS CONTROL - # ============================================================================ - - async def check_vector_store_access( - self, vector_store_id: str, user_api_key_dict: UserAPIKeyAuth - ) -> bool: - """ - Check if user has access to a vector store. - - Args: - vector_store_id: The unified vector store ID - user_api_key_dict: User API key authentication details - - Returns: - True if user has access, False otherwise - """ - is_unified_id = is_base64_encoded_unified_id(vector_store_id) - - if is_unified_id: - # Check access for managed vector store - return await self.can_user_access_unified_resource_id( - vector_store_id, - user_api_key_dict, - ) - - # Not a managed vector store, allow access - return True - - async def check_managed_vector_store_access( - self, data: Dict, user_api_key_dict: UserAPIKeyAuth - ) -> bool: - """ - Check if user has access to a managed vector store in request data. - - Args: - data: Request data containing vector_store_id - user_api_key_dict: User API key authentication details - - Returns: - True if this is a managed vector store and user has access - - Raises: - HTTPException: If user doesn't have access - """ - vector_store_id = cast(Optional[str], data.get("vector_store_id")) - is_unified_id = ( - is_base64_encoded_unified_id(vector_store_id) - if vector_store_id - else False - ) - - if is_unified_id and vector_store_id: - if await self.can_user_access_unified_resource_id( - vector_store_id, user_api_key_dict - ): - return True - else: - raise HTTPException( - status_code=403, - detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", - ) - - return False - - # ============================================================================ - # PRE-CALL HOOK (For Router Integration) - # ============================================================================ - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: Any, - data: Dict, - call_type: str, - ) -> Union[Exception, str, Dict, None]: - """ - Pre-call hook to handle vector store operations. - - This hook intercepts vector store requests and: - - Validates access for managed vector stores - - Transforms unified IDs to provider-specific IDs - - Adds model routing information - - Args: - user_api_key_dict: User API key authentication details - cache: Cache instance - data: Request data - call_type: Type of call being made - - Returns: - Modified request data or None - """ - from litellm.llms.base_llm.managed_resources.utils import ( - is_base64_encoded_unified_id, - parse_unified_id, - ) - - # Handle vector store search operations - if call_type == "avector_store_search": - vector_store_id = data.get("vector_store_id") - - if vector_store_id: - # Check if it's a managed vector store ID - decoded_id = is_base64_encoded_unified_id(vector_store_id) - - if decoded_id: - verbose_logger.debug( - f"Processing managed vector store search: {vector_store_id}" - ) - - # Check access - has_access = await self.can_user_access_unified_resource_id( - vector_store_id, user_api_key_dict - ) - - if not has_access: - raise HTTPException( - status_code=403, - detail=f"User {user_api_key_dict.user_id} does not have access to vector store {vector_store_id}", - ) - - # Parse the unified ID to extract components - parsed_id = parse_unified_id(vector_store_id) - - if parsed_id: - # Extract the model ID and provider resource ID - model_id = parsed_id.get("model_id") - provider_resource_id = parsed_id.get("provider_resource_id") - target_model_names = parsed_id.get("target_model_names", []) - - verbose_logger.debug( - f"Decoded vector store - model_id: {model_id}, provider_resource_id: {provider_resource_id}, target_model_names: {target_model_names}" - ) - - # Determine which model to use for routing - # Priority: model_id (deployment ID) > first target_model_name - routing_model = None - if model_id: - routing_model = model_id - elif target_model_names and len(target_model_names) > 0: - routing_model = target_model_names[0] - - # Set the model for routing - if routing_model: - data["model"] = routing_model - verbose_logger.info( - f"Routing vector store search to model: {routing_model}" - ) - - # Replace the unified ID with the provider-specific ID - if provider_resource_id: - data["vector_store_id"] = provider_resource_id - verbose_logger.debug( - f"Replaced unified ID with provider resource ID: {provider_resource_id}" - ) - - # Handle vector store retrieve/delete operations - elif call_type in ("avector_store_retrieve", "avector_store_delete"): - await self.check_managed_vector_store_access(data, user_api_key_dict) - - # If it's a managed vector store, we'll handle it in the endpoint - # No need to transform here as the endpoint will route to the hook - - return data - - # ============================================================================ - # POST-CALL HOOK (For Response Transformation) - # ============================================================================ - - async def async_post_call_success_hook( - self, - data: Dict, - user_api_key_dict: UserAPIKeyAuth, - response: Any, - ) -> Any: - """ - Post-call hook to transform responses. - - This hook can be used to transform responses if needed. - For now, it just passes through the response. - - Args: - data: Request data - user_api_key_dict: User API key authentication details - response: Response from the provider - - Returns: - Potentially modified response - """ - # Currently no transformation needed - return response - - # ============================================================================ - # DEPLOYMENT FILTERING - # ============================================================================ - - async def async_filter_deployments( # type: ignore[override] - self, - model: str, - healthy_deployments: List, - messages: Optional[List] = None, - request_kwargs: Optional[Dict] = None, - parent_otel_span: Optional[Span] = None, - ) -> List[Dict]: - """ - Filter deployments based on vector store availability. - - This is used by the router to select only deployments that have - the vector store available. - - Note: This method signature is a compromise between CustomLogger and BaseManagedResource - parent classes which have incompatible signatures. The type: ignore[override] is necessary - due to this multiple inheritance conflict. - - Args: - model: Model name - healthy_deployments: List of healthy deployments - messages: Messages (unused for vector stores, required by CustomLogger interface) - request_kwargs: Request kwargs containing vector_store_id and mappings - parent_otel_span: OpenTelemetry span for tracing - - Returns: - Filtered list of deployments - """ - return await BaseManagedResource.async_filter_deployments( - self, - model=model, - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, - parent_otel_span=parent_otel_span, - resource_id_key="vector_store_id", - ) diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/__init__.py b/enterprise/litellm_enterprise/proxy/management_endpoints/__init__.py deleted file mode 100644 index 0791a061b4..0000000000 --- a/enterprise/litellm_enterprise/proxy/management_endpoints/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from fastapi import APIRouter - -from .internal_user_endpoints import router as internal_user_endpoints_router -from .project_endpoints import router as project_endpoints_router - -management_endpoints_router = APIRouter() -management_endpoints_router.include_router(internal_user_endpoints_router) -management_endpoints_router.include_router(project_endpoints_router) - -__all__ = ["management_endpoints_router"] diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py b/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py deleted file mode 100644 index 2f53f9e928..0000000000 --- a/enterprise/litellm_enterprise/proxy/management_endpoints/internal_user_endpoints.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Enterprise internal user management endpoints -""" - - -from fastapi import APIRouter, Depends, HTTPException - -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.management_endpoints.internal_user_endpoints import user_api_key_auth - -router = APIRouter() - - -@router.get( - "/user/available_users", - tags=["Internal User management"], - dependencies=[Depends(user_api_key_auth)], -) -async def available_enterprise_users( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - For keys with `max_users` set, return the list of users that are allowed to use the key. - """ - from litellm.proxy._types import CommonProxyErrors, EnterpriseLicenseData - from litellm.proxy.proxy_server import ( - premium_user, - premium_user_data, - prisma_client, - ) - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - if not premium_user: - # check if SSO is enabled - show 5 user limit - from litellm.proxy.auth.auth_utils import _has_user_setup_sso - - if _has_user_setup_sso(): - premium_user_data = EnterpriseLicenseData( - max_users=5, - ) - - # Count number of rows in LiteLLM_UserTable - user_count = await prisma_client.db.litellm_usertable.count() - team_count = await prisma_client.db.litellm_teamtable.count() - - if ( - not premium_user_data - or premium_user_data is not None - and "max_users" not in premium_user_data - ): - max_users = None - else: - max_users = premium_user_data.get("max_users") - - if premium_user_data and "max_teams" in premium_user_data: - max_teams = premium_user_data.get("max_teams") - else: - max_teams = None - - return { - "total_users": max_users, - "total_teams": max_teams, - "total_users_used": user_count, - "total_teams_used": team_count, - "total_teams_remaining": (max_teams - team_count if max_teams else None), - "total_users_remaining": (max_users - user_count if max_users else None), - } diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py b/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py deleted file mode 100644 index 794568b210..0000000000 --- a/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Optional - -from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable - - -def add_team_member_key_duration( - team_table: Optional[LiteLLM_TeamTable], - data: GenerateKeyRequest, -) -> GenerateKeyRequest: - if team_table is None: - return data - - if data.user_id is None: # only apply for team member keys, not service accounts - return data - - if ( - team_table.metadata is not None - and team_table.metadata.get("team_member_key_duration") is not None - ): - data.duration = team_table.metadata["team_member_key_duration"] - - return data - - -def add_team_organization_id( - team_table: Optional[LiteLLM_TeamTable], - data: GenerateKeyRequest, -) -> GenerateKeyRequest: - if team_table is None: - return data - setattr(data, "organization_id", team_table.organization_id) - return data - - -def apply_enterprise_key_management_params( - data: GenerateKeyRequest, - team_table: Optional[LiteLLM_TeamTable], -) -> GenerateKeyRequest: - - data = add_team_member_key_duration(team_table, data) - data = add_team_organization_id(team_table, data) - return data diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/project_endpoints.py b/enterprise/litellm_enterprise/proxy/management_endpoints/project_endpoints.py deleted file mode 100644 index 75229bacc8..0000000000 --- a/enterprise/litellm_enterprise/proxy/management_endpoints/project_endpoints.py +++ /dev/null @@ -1,963 +0,0 @@ -""" -Endpoints for /project operations - -/project/new -/project/update -/project/delete -/project/info -/project/list -""" - -#### PROJECT MANAGEMENT #### - -import json -from typing import List, Optional, Union - -from fastapi import APIRouter, Depends, HTTPException, Request - -from litellm._logging import verbose_proxy_logger -from litellm._uuid import uuid -from litellm.proxy._types import * -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.management_endpoints.common_utils import _set_object_metadata_field -from litellm.proxy.management_helpers.utils import ( - management_endpoint_wrapper, -) -from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy - -router = APIRouter() - - -async def _check_user_permission_for_project( - user_api_key_dict: UserAPIKeyAuth, - team_id: Optional[str], - prisma_client: PrismaClient, - require_admin: bool = False, - team_object: Optional[LiteLLM_TeamTable] = None, -) -> bool: - """ - Check if user has permission to manage a project. - - Returns True if user is proxy admin or team admin (when team_id provided). - If require_admin=True, only proxy admins are allowed. - - If team_object is provided, it will be used instead of fetching from DB - (avoids duplicate DB queries when team was already fetched for validation). - """ - is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN - - if require_admin: - return is_proxy_admin - - if is_proxy_admin: - return True - - if not team_id or not user_api_key_dict.user_id: - return False - - team = team_object - if team is None: - team = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) - - if team and team.admins: - return user_api_key_dict.user_id in team.admins - - return False - - -async def _validate_team_exists( - team_id: str, - prisma_client: PrismaClient, -): - """Validate that a team exists. Returns the team row.""" - team = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id}, - ) - - if team is None: - raise ProxyException( - message=f"Team not found, team_id={team_id}", - type="not_found", - code=404, - param="team_id", - ) - - return team - - -def _check_team_project_limits( - team_object: LiteLLM_TeamTable, - data: Union[NewProjectRequest, UpdateProjectRequest], -) -> None: - """ - Check that project limits respect its parent Team's limits. - - Mirrors _check_org_team_limits() from team_endpoints.py. - - Validates: - - Project models are a subset of Team models - - Project max_budget <= Team max_budget - - Project tpm_limit <= Team tpm_limit - - Project rpm_limit <= Team rpm_limit - - Budget values are non-negative - - soft_budget < max_budget - """ - # --- Budget non-negativity checks --- - if data.max_budget is not None and data.max_budget < 0: - raise HTTPException( - status_code=400, - detail={ - "error": f"max_budget cannot be negative. Received: {data.max_budget}" - }, - ) - if data.soft_budget is not None and data.soft_budget < 0: - raise HTTPException( - status_code=400, - detail={ - "error": f"soft_budget cannot be negative. Received: {data.soft_budget}" - }, - ) - - # --- soft_budget < max_budget --- - if data.soft_budget is not None and data.max_budget is not None: - if data.soft_budget >= data.max_budget: - raise HTTPException( - status_code=400, - detail={ - "error": f"soft_budget ({data.soft_budget}) must be strictly lower than max_budget ({data.max_budget})" - }, - ) - - # --- Validate project models are a subset of team models --- - project_models = getattr(data, "models", None) - team_models = team_object.models or [] - if project_models and len(team_models) > 0: - # If team has 'all-proxy-models', skip validation as it allows all models - if SpecialModelNames.all_proxy_models.value not in team_models: - for m in project_models: - if m not in team_models: - raise HTTPException( - status_code=400, - detail={ - "error": f"Model '{m}' not in team's allowed models. Team allowed models={team_models}. Team: {team_object.team_id}" - }, - ) - - # --- Validate project max_budget <= team max_budget --- - # Team stores budget fields directly (max_budget, tpm_limit, rpm_limit) - # unlike Project which uses a separate LiteLLM_BudgetTable relation - if ( - data.max_budget is not None - and team_object.max_budget is not None - and data.max_budget > team_object.max_budget - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"Project max_budget ({data.max_budget}) exceeds team's max_budget ({team_object.max_budget}). Team: {team_object.team_id}" - }, - ) - - # --- Validate project tpm_limit <= team tpm_limit --- - if ( - data.tpm_limit is not None - and team_object.tpm_limit is not None - and data.tpm_limit > team_object.tpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"Project tpm_limit ({data.tpm_limit}) exceeds team's tpm_limit ({team_object.tpm_limit}). Team: {team_object.team_id}" - }, - ) - - # --- Validate project rpm_limit <= team rpm_limit --- - if ( - data.rpm_limit is not None - and team_object.rpm_limit is not None - and data.rpm_limit > team_object.rpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"Project rpm_limit ({data.rpm_limit}) exceeds team's rpm_limit ({team_object.rpm_limit}). Team: {team_object.team_id}" - }, - ) - - -async def _create_budget_for_project( - data: NewProjectRequest, - user_id: Optional[str], - litellm_proxy_admin_name: str, - prisma_client: PrismaClient, -) -> str: - """Create a budget for the project and return budget_id.""" - budget_params = LiteLLM_BudgetTable.model_fields.keys() - _json_data = data.json(exclude_none=True) - _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} - budget_row = LiteLLM_BudgetTable(**_budget_data) - - new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - - _budget = await prisma_client.db.litellm_budgettable.create( - data={ - **new_budget, - "created_by": user_id or litellm_proxy_admin_name, - "updated_by": user_id or litellm_proxy_admin_name, - } - ) - - return _budget.budget_id - - -async def _set_project_object_permission( - data: NewProjectRequest, - prisma_client: Optional[PrismaClient], -) -> Optional[str]: - """ - Creates the LiteLLM_ObjectPermissionTable record for the project. - Returns the object_permission_id if created, otherwise None. - """ - if prisma_client is None: - return None - - if data.object_permission is not None: - created_object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.create( - data=data.object_permission.model_dump(exclude_none=True), - ) - ) - del data.object_permission - return created_object_permission.object_permission_id - return None - - -def _remove_budget_fields_from_project_data(project_data: dict) -> dict: - """ - Remove budget fields from project data. - Budget fields belong to LiteLLM_BudgetTable, not LiteLLM_ProjectTable. - Keep budget_id as it's a foreign key. - - Following the pattern from organization_endpoints.py - """ - budget_fields = LiteLLM_BudgetTable.model_fields.keys() - for field in list(budget_fields): - if field != "budget_id": # Keep the foreign key - project_data.pop(field, None) - return project_data - - -@router.post( - "/project/new", - tags=["project management"], - dependencies=[Depends(user_api_key_auth)], - response_model=NewProjectResponse, -) -@management_endpoint_wrapper -async def new_project( - data: NewProjectRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Create a new project. Projects sit between teams and keys in the hierarchy. - - Only admins or team admins can create projects. - - # Parameters - - - project_alias: *Optional[str]* - The name of the project. - - description: *Optional[str]* - Description of the project's purpose and use case. - - team_id: *str* - The team id that this project belongs to. Required. - - models: *List* - The models the project has access to. - - budget_id: *Optional[str]* - The id for a budget (tpm/rpm/max budget) for the project. - ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### - - max_budget: *Optional[float]* - Max budget for project - - tpm_limit: *Optional[int]* - Max tpm limit for project - - rpm_limit: *Optional[int]* - Max rpm limit for project - - max_parallel_requests: *Optional[int]* - Max parallel requests for project - - soft_budget: *Optional[float]* - Get a slack alert when this soft budget is reached. Don't block requests. - - model_max_budget: *Optional[dict]* - Max budget for a specific model. Example: {"gpt-4": 100.0, "gpt-3.5-turbo": 50.0} - - model_rpm_limit: *Optional[dict]* - RPM limits per model. Example: {"gpt-4": 1000, "gpt-3.5-turbo": 5000} - - model_tpm_limit: *Optional[dict]* - TPM limits per model. Example: {"gpt-4": 50000, "gpt-3.5-turbo": 100000} - - budget_duration: *Optional[str]* - Frequency of reseting project budget - - metadata: *Optional[dict]* - Metadata for project, store information for project. Example metadata - {"use_case_id": "SNOW-12345", "responsible_ai_id": "RAI-67890"} - - tags: *Optional[list]* - Tags for the project. Example: ["production", "api"] - - blocked: *bool* - Flag indicating if the project is blocked or not - will stop all calls from keys with this project_id. - - object_permission: Optional[LiteLLM_ObjectPermissionBase] - project-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission. - - Example 1: Create new project **without** a budget_id, with model-specific limits - - ```bash - curl --location 'http://0.0.0.0:4000/project/new' \\ - --header 'Authorization: Bearer sk-1234' \\ - --header 'Content-Type: application/json' \\ - --data '{ - "project_alias": "flight-search-assistant", - "description": "AI-powered flight search and booking assistant", - "team_id": "team-123", - "models": ["gpt-4", "gpt-3.5-turbo"], - "max_budget": 100, - "model_rpm_limit": { - "gpt-4": 1000, - "gpt-3.5-turbo": 5000 - }, - "model_tpm_limit": { - "gpt-4": 50000, - "gpt-3.5-turbo": 100000 - }, - "metadata": { - "use_case_id": "SNOW-12345", - "responsible_ai_id": "RAI-67890" - } - }' - ``` - - Example 2: Create new project **with** a budget_id - - ```bash - curl --location 'http://0.0.0.0:4000/project/new' \\ - --header 'Authorization: Bearer sk-1234' \\ - --header 'Content-Type: application/json' \\ - --data '{ - "project_alias": "hotel-recommendations", - "description": "Personalized hotel recommendation engine", - "team_id": "team-123", - "models": ["claude-3-sonnet"], - "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689", - "metadata": { - "use_case_id": "SNOW-54321" - } - }' - ``` - """ - from litellm.proxy.proxy_server import ( - litellm_proxy_admin_name, - premium_user, - prisma_client, - ) - - try: - if getattr(data, "tags", None) is not None and not premium_user: - raise HTTPException( - status_code=403, - detail={ - "error": "Only premium users can add tags to projects. " - + CommonProxyErrors.not_premium_user.value - }, - ) - - if not premium_user: - raise HTTPException( - status_code=403, - detail={ - "error": "Project management is an enterprise feature. " - + CommonProxyErrors.not_premium_user.value - }, - ) - - # ADD METADATA FIELDS - for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: - if getattr(data, field, None) is not None: - _set_object_metadata_field( - object_data=data, - field_name=field, - value=getattr(data, field), - ) - delattr(data, field) - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Validate team exists and get team object with budget - team_object = await _validate_team_exists( - team_id=data.team_id, prisma_client=prisma_client - ) - - # Validate project limits against team limits - _check_team_project_limits( - team_object=LiteLLM_TeamTable(**team_object.model_dump()), - data=data, - ) - - # Check if user has permission to create projects for this team - # only team admins can create projects for their team - has_permission = await _check_user_permission_for_project( - user_api_key_dict=user_api_key_dict, - team_id=data.team_id, - prisma_client=prisma_client, - team_object=LiteLLM_TeamTable(**team_object.model_dump()), - ) - - if not has_permission: - raise HTTPException( - status_code=403, - detail={ - "error": f"Only admins or team admins can create projects. Your role is {user_api_key_dict.user_role}" - }, - ) - - # Generate project_id if not provided - if data.project_id is None: - data.project_id = str(uuid.uuid4()) - else: - # Check if project_id already exists - existing_project = await prisma_client.db.litellm_projecttable.find_unique( - where={"project_id": data.project_id} - ) - if existing_project is not None: - raise ProxyException( - message=f"Project id = {data.project_id} already exists. Please use a different project id.", - type="bad_request", - code=400, - param="project_id", - ) - - # Create budget if not provided - if data.budget_id is None: - data.budget_id = await _create_budget_for_project( - data=data, - user_id=user_api_key_dict.user_id, - litellm_proxy_admin_name=litellm_proxy_admin_name, - prisma_client=prisma_client, - ) - - ## Handle Object Permission - MCP, Vector Stores etc. - object_permission_id = await _set_project_object_permission( - data=data, - prisma_client=prisma_client, - ) - - # Create project row (following organization_endpoints.py pattern) - project_row = LiteLLM_ProjectTable( - **data.json(exclude_none=True), - object_permission_id=object_permission_id, - created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - ) - - for field in LiteLLM_ManagementEndpoint_MetadataFields: - if getattr(data, field, None) is not None: - _set_object_metadata_field( - object_data=project_row, - field_name=field, - value=getattr(data, field), - ) - - new_project_row = prisma_client.jsonify_object( - project_row.json(exclude_none=True) - ) - - # Remove budget fields (following organization_endpoints.py pattern) - new_project_row = _remove_budget_fields_from_project_data(new_project_row) - - verbose_proxy_logger.info( - f"new_project_row: {json.dumps(new_project_row, indent=2)}" - ) - response = await prisma_client.db.litellm_projecttable.create( - data={ - **new_project_row, # type: ignore - }, - include={"litellm_budget_table": True}, - ) - - return response - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.management_endpoints.project_endpoints.new_project(): Exception occured - {}".format( - str(e) - ) - ) - raise handle_exception_on_proxy(e) - - -@router.post( - "/project/update", - tags=["project management"], - dependencies=[Depends(user_api_key_auth)], - response_model=LiteLLM_ProjectTable, -) -@management_endpoint_wrapper -async def update_project( # noqa: PLR0915 - data: UpdateProjectRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Update a project - - Parameters: - - project_id: *str* - The project id to update. Required. - - project_alias: *Optional[str]* - Updated name for the project - - description: *Optional[str]* - Updated description for the project - - team_id: *Optional[str]* - Updated team_id for the project - - metadata: *Optional[dict]* - Updated metadata for project - - models: *Optional[list]* - Updated list of models for the project - - blocked: *Optional[bool]* - Updated blocked status - - max_budget: *Optional[float]* - Updated max budget - - tpm_limit: *Optional[int]* - Updated tpm limit - - rpm_limit: *Optional[int]* - Updated rpm limit - - model_rpm_limit: *Optional[dict]* - Updated RPM limits per model - - model_tpm_limit: *Optional[dict]* - Updated TPM limits per model - - budget_duration: *Optional[str]* - Updated budget duration - - tags: *Optional[list]* - Updated list of tags for the project - - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Updated object permission - - Example: - ```bash - curl --location 'http://0.0.0.0:4000/project/update' \\ - --header 'Authorization: Bearer sk-1234' \\ - --header 'Content-Type: application/json' \\ - --data '{ - "project_id": "project-123", - "description": "Updated flight search system with enhanced capabilities", - "max_budget": 200, - "model_rpm_limit": { - "gpt-4": 2000, - "gpt-3.5-turbo": 10000 - }, - "metadata": { - "use_case_id": "SNOW-12345", - "status": "active" - } - }' - ``` - """ - from litellm.proxy.proxy_server import ( - litellm_proxy_admin_name, - premium_user, - prisma_client, - ) - - try: - if getattr(data, "tags", None) is not None and not premium_user: - raise HTTPException( - status_code=403, - detail={ - "error": "Only premium users can add tags to projects. " - + CommonProxyErrors.not_premium_user.value - }, - ) - - if not premium_user: - raise HTTPException( - status_code=403, - detail={ - "error": "Project management is an enterprise feature. " - + CommonProxyErrors.not_premium_user.value - }, - ) - - # ADD METADATA FIELDS - for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: - if getattr(data, field, None) is not None: - _set_object_metadata_field( - object_data=data, - field_name=field, - value=getattr(data, field), - ) - delattr(data, field) - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - if data.project_id is None: - raise HTTPException( - status_code=400, - detail={"error": "project_id is required"}, - ) - - # Fetch existing project - existing_project = await prisma_client.db.litellm_projecttable.find_unique( - where={"project_id": data.project_id} - ) - - if existing_project is None: - raise ProxyException( - message=f"Project not found, project_id={data.project_id}", - type="not_found", - code=404, - param="project_id", - ) - - # Permission to *edit* the project must be evaluated against the - # project's CURRENT team. Sourcing the team from `data.team_id` - # would let an admin of any team pass the check by supplying their - # own team_id, hijacking the project (VERIA-55). - target_team_id = data.team_id or existing_project.team_id - target_team_obj = None - if target_team_id is not None: - target_team_obj = await _validate_team_exists( - team_id=target_team_id, prisma_client=prisma_client - ) - - has_permission = await _check_user_permission_for_project( - user_api_key_dict=user_api_key_dict, - team_id=existing_project.team_id, - prisma_client=prisma_client, - ) - - if not has_permission: - raise HTTPException( - status_code=403, - detail={"error": "Only admins or team admins can update projects"}, - ) - - # Reassigning to a different team also requires admin rights on the - # destination team — otherwise a team admin could shed projects into - # an unsuspecting team's namespace. - if data.team_id is not None and data.team_id != existing_project.team_id: - can_assign_to_target = await _check_user_permission_for_project( - user_api_key_dict=user_api_key_dict, - team_id=data.team_id, - prisma_client=prisma_client, - team_object=( - LiteLLM_TeamTable(**target_team_obj.model_dump()) - if target_team_obj - else None - ), - ) - if not can_assign_to_target: - raise HTTPException( - status_code=403, - detail={ - "error": "Cannot reassign project to a team you are not an admin of" - }, - ) - - # Validate project limits against team limits - if target_team_obj is not None: - _check_team_project_limits( - team_object=LiteLLM_TeamTable(**target_team_obj.model_dump()), - data=data, - ) - - # Prepare update data - update_data = data.json(exclude_none=True, exclude={"project_id"}) - update_data = prisma_client.jsonify_object(update_data) - update_data["updated_by"] = ( - user_api_key_dict.user_id or litellm_proxy_admin_name - ) - - # Handle budget updates - budget_fields = LiteLLM_BudgetTable.model_fields.keys() - budget_updates = {k: v for k, v in update_data.items() if k in budget_fields} - - if budget_updates and existing_project.budget_id: - # Update existing budget - await prisma_client.db.litellm_budgettable.update( - where={"budget_id": existing_project.budget_id}, - data={ - **budget_updates, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - }, - ) - # Remove budget fields from project update - for field in budget_updates.keys(): - update_data.pop(field, None) - - # Handle object permissions - if "object_permission" in update_data: - object_permission_data = update_data.pop("object_permission") - if object_permission_data: - if existing_project.object_permission_id: - # Update existing permission - await prisma_client.db.litellm_objectpermissiontable.update( - where={ - "object_permission_id": existing_project.object_permission_id - }, - data=object_permission_data, - ) - else: - # Create new permission - created_permission = ( - await prisma_client.db.litellm_objectpermissiontable.create( - data=object_permission_data, - ) - ) - update_data["object_permission_id"] = ( - created_permission.object_permission_id - ) - - # Handle metadata fields - for field in LiteLLM_ManagementEndpoint_MetadataFields: - if field in update_data: - if update_data.get("metadata") is None: - update_data["metadata"] = {} - update_data["metadata"][field] = update_data.pop(field) - - # Remove budget fields (following organization_endpoints.py pattern) - update_data = _remove_budget_fields_from_project_data(update_data) - - # Update project - updated_project = await prisma_client.db.litellm_projecttable.update( - where={"project_id": data.project_id}, - data=update_data, - include={"litellm_budget_table": True, "object_permission": True}, - ) - - return updated_project - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.management_endpoints.project_endpoints.update_project(): Exception occured - {}".format( - str(e) - ) - ) - raise handle_exception_on_proxy(e) - - -@router.delete( - "/project/delete", - tags=["project management"], - dependencies=[Depends(user_api_key_auth)], - response_model=List[LiteLLM_ProjectTable], -) -@management_endpoint_wrapper -async def delete_project( - data: DeleteProjectRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Delete projects - - Parameters: - - project_ids: *List[str]* - List of project ids to delete - - Example: - ```bash - curl --location --request DELETE 'http://0.0.0.0:4000/project/delete' \\ - --header 'Authorization: Bearer sk-1234' \\ - --header 'Content-Type: application/json' \\ - --data '{ - "project_ids": ["project-123", "project-456"] - }' - ``` - """ - from litellm.proxy.proxy_server import premium_user, prisma_client - - try: - if not premium_user: - raise HTTPException( - status_code=403, - detail={ - "error": "Project management is an enterprise feature. " - + CommonProxyErrors.not_premium_user.value - }, - ) - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Check if user is admin (only admins can delete projects) - has_permission = await _check_user_permission_for_project( - user_api_key_dict=user_api_key_dict, - team_id=None, - prisma_client=prisma_client, - require_admin=True, - ) - - if not has_permission: - raise HTTPException( - status_code=403, - detail={"error": "Only admins can delete projects"}, - ) - - deleted_projects = [] - - for project_id in data.project_ids: - # Check if project exists - existing_project = await prisma_client.db.litellm_projecttable.find_unique( - where={"project_id": project_id} - ) - - if existing_project is None: - raise ProxyException( - message=f"Project not found, project_id={project_id}", - type="not_found", - code=404, - param="project_ids", - ) - - # Check if there are any keys associated with this project - associated_keys = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"project_id": project_id} - ) - ) - - if len(associated_keys) > 0: - raise ProxyException( - message=f"Cannot delete project {project_id}. {len(associated_keys)} key(s) are associated with it. Please delete or reassign the keys first.", - type="bad_request", - code=400, - param="project_ids", - ) - - # Delete the project - deleted_project = await prisma_client.db.litellm_projecttable.delete( - where={"project_id": project_id} - ) - - deleted_projects.append(deleted_project) - - return deleted_projects - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.management_endpoints.project_endpoints.delete_project(): Exception occured - {}".format( - str(e) - ) - ) - raise handle_exception_on_proxy(e) - - -@router.get( - "/project/info", - tags=["project management"], - dependencies=[Depends(user_api_key_auth)], - response_model=LiteLLM_ProjectTable, -) -async def project_info( - project_id: str, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Get information about a specific project - - Parameters: - - project_id: *str* - The project id to fetch info for - - Example: - ```bash - curl --location 'http://0.0.0.0:4000/project/info?project_id=project-123' \\ - --header 'Authorization: Bearer sk-1234' - ``` - """ - from litellm.proxy.proxy_server import prisma_client - - try: - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Fetch project - project = await prisma_client.db.litellm_projecttable.find_unique( - where={"project_id": project_id}, - include={"litellm_budget_table": True, "object_permission": True}, - ) - - if project is None: - raise ProxyException( - message=f"Project not found, project_id={project_id}", - type="not_found", - code=404, - param="project_id", - ) - - # Check if user has access to this project (admin or team member) - is_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN - is_team_member = False - - if project.team_id and user_api_key_dict.user_id: - team = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": project.team_id} - ) - if team: - caller_user_id = user_api_key_dict.user_id - for m in team.members_with_roles or []: - m_user_id = ( - m.get("user_id") - if isinstance(m, dict) - else getattr(m, "user_id", None) - ) - if m_user_id == caller_user_id: - is_team_member = True - break - - if not (is_admin or is_team_member): - raise HTTPException( - status_code=403, - detail={"error": "You don't have access to this project"}, - ) - - return project - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.management_endpoints.project_endpoints.project_info(): Exception occured - {}".format( - str(e) - ) - ) - raise handle_exception_on_proxy(e) - - -@router.get( - "/project/list", - tags=["project management"], - dependencies=[Depends(user_api_key_auth)], - response_model=List[LiteLLM_ProjectTable], -) -async def list_projects( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - List all projects that the user has access to - - Example: - ```bash - curl --location 'http://0.0.0.0:4000/project/list' \\ - --header 'Authorization: Bearer sk-1234' - ``` - """ - from litellm.proxy.proxy_server import prisma_client - - try: - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - # If proxy admin, get all projects - if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN: - projects = await prisma_client.db.litellm_projecttable.find_many( - include={"litellm_budget_table": True, "object_permission": True} - ) - else: - # Look up the user's team memberships via the reverse-index on - # LiteLLM_UserTable.teams (maintained by team_member_add alongside - # members_with_roles). This avoids a full scan of all team rows. - user_record = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - ) - user_team_ids = ( - user_record.teams - if user_record is not None and user_record.teams - else [] - ) - - projects = await prisma_client.db.litellm_projecttable.find_many( - where={"team_id": {"in": user_team_ids}}, - include={"litellm_budget_table": True, "object_permission": True}, - ) - - return projects - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.management_endpoints.project_endpoints.list_projects(): Exception occured - {}".format( - str(e) - ) - ) - raise handle_exception_on_proxy(e) diff --git a/enterprise/litellm_enterprise/proxy/proxy_server.py b/enterprise/litellm_enterprise/proxy/proxy_server.py deleted file mode 100644 index 79d3ebdf9e..0000000000 --- a/enterprise/litellm_enterprise/proxy/proxy_server.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -from typing import Optional - -from litellm_enterprise.types.proxy.proxy_server import CustomAuthSettings - -custom_auth_settings: Optional[CustomAuthSettings] = None - - -class EnterpriseProxyConfig: - async def load_custom_auth_settings( - self, general_settings: dict - ) -> CustomAuthSettings: - custom_auth_settings = general_settings.get("custom_auth_settings", None) - if custom_auth_settings is not None: - custom_auth_settings = CustomAuthSettings( - mode=custom_auth_settings.get("mode"), - ) - return custom_auth_settings - - async def load_enterprise_config(self, general_settings: dict) -> None: - global custom_auth_settings - custom_auth_settings = await self.load_custom_auth_settings(general_settings) - return None - - @staticmethod - def get_custom_docs_description() -> Optional[str]: - from litellm.proxy.proxy_server import premium_user - - docs_description: Optional[str] = None - if premium_user: - # check if premium_user has custom_docs_description - docs_description = os.getenv("DOCS_DESCRIPTION") - - return docs_description diff --git a/enterprise/litellm_enterprise/proxy/readme.md b/enterprise/litellm_enterprise/proxy/readme.md deleted file mode 100644 index 60b07cf49a..0000000000 --- a/enterprise/litellm_enterprise/proxy/readme.md +++ /dev/null @@ -1,11 +0,0 @@ -# LiteLLM Proxy Enterprise Features - Readme - -## Overview - -This directory contains enterprise features used on the LiteLLM proxy. - -## Format - -Create a file for every group of endpoints (e.g. `key_management_endpoints.py`, `user_management_endpoints.py`, etc.) - -If there is a broader semantic group of endpoints, create a folder for that group (e.g. `management_endpoints`, `auth_endpoints`, etc.) diff --git a/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/__init__.py b/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/__init__.py deleted file mode 100644 index 296d964f85..0000000000 --- a/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import ui_settings_extensions # side-effect: registers extra UI settings fields - -__all__ = ["ui_settings_extensions"] diff --git a/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/ui_settings_extensions.py b/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/ui_settings_extensions.py deleted file mode 100644 index e61611ae59..0000000000 --- a/enterprise/litellm_enterprise/proxy/ui_crud_endpoints/ui_settings_extensions.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Enterprise-only UI settings fields. - -Registers additional fields onto the OSS ``UISettings`` model at import time. -Importing this module has the side effect of extending both the GET schema -and the PATCH allowlist served by ``/get/ui_settings`` and -``/update/ui_settings``. -""" - -from pydantic.fields import FieldInfo - -from litellm.proxy.ui_crud_endpoints.proxy_setting_endpoints import ( - register_extra_ui_setting, -) - -register_extra_ui_setting( - "enable_projects_ui", - bool, - FieldInfo( - default=False, - description=( - "If enabled, shows the Projects feature in the UI sidebar and " - "the project field in key management." - ), - ), -) diff --git a/enterprise/litellm_enterprise/proxy/utils.py b/enterprise/litellm_enterprise/proxy/utils.py deleted file mode 100644 index 227ea0a9ff..0000000000 --- a/enterprise/litellm_enterprise/proxy/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional, Union - -from litellm.secret_managers.main import str_to_bool - - -def _should_block_robots(): - """ - Returns True if the robots.txt file should block web crawlers - - Controlled by - - ```yaml - general_settings: - block_robots: true - ``` - """ - from litellm.proxy.proxy_server import ( - CommonProxyErrors, - general_settings, - premium_user, - ) - - _block_robots: Union[bool, str] = general_settings.get("block_robots", False) - block_robots: Optional[bool] = None - if isinstance(_block_robots, bool): - block_robots = _block_robots - elif isinstance(_block_robots, str): - block_robots = str_to_bool(_block_robots) - if block_robots is True: - if premium_user is not True: - raise ValueError( - f"Blocking web crawlers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}" - ) - return True - return False diff --git a/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py b/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py deleted file mode 100644 index 5e79959986..0000000000 --- a/enterprise/litellm_enterprise/proxy/vector_stores/endpoints.py +++ /dev/null @@ -1,347 +0,0 @@ -""" -VECTOR STORE MANAGEMENT - -All /vector_store management endpoints - -/vector_store/new -/vector_store/delete -/vector_store/list -""" - -import copy -import json -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.litellm_core_utils.safe_json_dumps import safe_dumps -from litellm.proxy._types import ( - LiteLLM_ManagedVectorStoresTable, - ResponseLiteLLM_ManagedVectorStore, - UserAPIKeyAuth, -) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.types.vector_stores import ( - LiteLLM_ManagedVectorStore, - LiteLLM_ManagedVectorStoreListResponse, - VectorStoreDeleteRequest, - VectorStoreInfoRequest, - VectorStoreUpdateRequest, -) -from litellm.vector_stores.vector_store_registry import VectorStoreRegistry - -router = APIRouter() - - -######################################################## -# Management Endpoints -######################################################## -@router.post( - "/vector_store/new", - tags=["vector store management"], - dependencies=[Depends(user_api_key_auth)], -) -async def new_vector_store( - vector_store: LiteLLM_ManagedVectorStore, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Create a new vector store. - - Parameters: - - vector_store_id: str - Unique identifier for the vector store - - custom_llm_provider: str - Provider of the vector store - - vector_store_name: Optional[str] - Name of the vector store - - vector_store_description: Optional[str] - Description of the vector store - - vector_store_metadata: Optional[Dict] - Additional metadata for the vector store - """ - from litellm.proxy.proxy_server import prisma_client - from litellm.types.router import GenericLiteLLMParams - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - # Check if vector store already exists - existing_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": vector_store.get("vector_store_id")} - ) - ) - if existing_vector_store is not None: - raise HTTPException( - status_code=400, - detail=f"Vector store with ID {vector_store.get('vector_store_id')} already exists", - ) - - if vector_store.get("vector_store_metadata") is not None: - vector_store["vector_store_metadata"] = safe_dumps( - vector_store.get("vector_store_metadata") - ) - - # Safely handle JSON serialization of litellm_params - litellm_params_json: Optional[str] = None - _input_litellm_params: dict = vector_store.get("litellm_params", {}) or {} - if _input_litellm_params is not None: - litellm_params_dict = GenericLiteLLMParams( - **_input_litellm_params - ).model_dump(exclude_none=True) - litellm_params_json = safe_dumps(litellm_params_dict) - del vector_store["litellm_params"] - - _new_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.create( - data={ - **vector_store, - "litellm_params": litellm_params_json, - } - ) - ) - - new_vector_store: LiteLLM_ManagedVectorStore = LiteLLM_ManagedVectorStore( - **_new_vector_store.model_dump() - ) - - # Add vector store to registry - if litellm.vector_store_registry is not None: - litellm.vector_store_registry.add_vector_store_to_registry( - vector_store=new_vector_store - ) - - return { - "status": "success", - "message": f"Vector store {vector_store.get('vector_store_id')} created successfully", - "vector_store": new_vector_store, - } - except Exception as e: - verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get( - "/vector_store/list", - tags=["vector store management"], - dependencies=[Depends(user_api_key_auth)], - response_model=LiteLLM_ManagedVectorStoreListResponse, -) -async def list_vector_stores( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - page: int = 1, - page_size: int = 100, -): - """ - List all available vector stores with optional filtering and pagination. - Combines both in-memory vector stores and those stored in the database. - - Parameters: - - page: int - Page number for pagination (default: 1) - - page_size: int - Number of items per page (default: 100) - """ - from litellm.proxy.proxy_server import prisma_client - - try: - # Get vector stores from database (source of truth) - # Only return what's in the database to ensure consistency across instances - vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( - prisma_client=prisma_client - ) - - # Also clean up in-memory registry to remove any deleted vector stores - if litellm.vector_store_registry is not None: - db_vector_store_ids = { - vs.get("vector_store_id") - for vs in vector_stores_from_db - if vs.get("vector_store_id") - } - # Remove any in-memory vector stores that no longer exist in database - vector_stores_to_remove = [] - for vs in litellm.vector_store_registry.vector_stores: - vs_id = vs.get("vector_store_id") - if vs_id and vs_id not in db_vector_store_ids: - vector_stores_to_remove.append(vs_id) - for vs_id in vector_stores_to_remove: - litellm.vector_store_registry.delete_vector_store_from_registry( - vector_store_id=vs_id - ) - verbose_proxy_logger.debug( - f"Removed deleted vector store {vs_id} from in-memory registry" - ) - - # Use database as single source of truth for listing - combined_vector_stores: List[LiteLLM_ManagedVectorStore] = vector_stores_from_db - - total_count = len(combined_vector_stores) - total_pages = (total_count + page_size - 1) // page_size - - # Format response using LiteLLM_ManagedVectorStoreListResponse - response = LiteLLM_ManagedVectorStoreListResponse( - object="list", - data=combined_vector_stores, - total_count=total_count, - current_page=page, - total_pages=total_pages, - ) - - return response - except Exception as e: - verbose_proxy_logger.exception(f"Error listing vector stores: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/vector_store/delete", - tags=["vector store management"], - dependencies=[Depends(user_api_key_auth)], -) -async def delete_vector_store( - data: VectorStoreDeleteRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Delete a vector store. - - Parameters: - - vector_store_id: str - ID of the vector store to delete - """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - # Check if vector store exists - existing_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": data.vector_store_id} - ) - ) - if existing_vector_store is None: - raise HTTPException( - status_code=404, - detail=f"Vector store with ID {data.vector_store_id} not found", - ) - - # Delete vector store - await prisma_client.db.litellm_managedvectorstorestable.delete( - where={"vector_store_id": data.vector_store_id} - ) - - # Delete vector store from registry - if litellm.vector_store_registry is not None: - litellm.vector_store_registry.delete_vector_store_from_registry( - vector_store_id=data.vector_store_id - ) - - return {"message": f"Vector store {data.vector_store_id} deleted successfully"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/vector_store/info", - tags=["vector store management"], - dependencies=[Depends(user_api_key_auth)], - response_model=ResponseLiteLLM_ManagedVectorStore, -) -async def get_vector_store_info( - data: VectorStoreInfoRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """Return a single vector store's details""" - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - if litellm.vector_store_registry is not None: - vector_store = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( - vector_store_id=data.vector_store_id - ) - if vector_store is not None: - vector_store_metadata = vector_store.get("vector_store_metadata") - # Parse metadata if it's a JSON string - parsed_metadata: Optional[dict] = None - if isinstance(vector_store_metadata, str): - parsed_metadata = json.loads(vector_store_metadata) - elif isinstance(vector_store_metadata, dict): - parsed_metadata = vector_store_metadata - - vector_store_pydantic_obj = LiteLLM_ManagedVectorStoresTable( - vector_store_id=vector_store.get("vector_store_id") or "", - custom_llm_provider=vector_store.get("custom_llm_provider") or "", - vector_store_name=vector_store.get("vector_store_name") or None, - vector_store_description=vector_store.get( - "vector_store_description" - ) - or None, - vector_store_metadata=parsed_metadata, - created_at=vector_store.get("created_at") or None, - updated_at=vector_store.get("updated_at") or None, - litellm_credential_name=vector_store.get("litellm_credential_name"), - litellm_params=vector_store.get("litellm_params") or None, - team_id=vector_store.get("team_id"), - user_id=vector_store.get("user_id"), - ) - return {"vector_store": vector_store_pydantic_obj} - - vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": data.vector_store_id} - ) - ) - if vector_store is None: - raise HTTPException( - status_code=404, - detail=f"Vector store with ID {data.vector_store_id} not found", - ) - - vector_store_dict = vector_store.model_dump() # type: ignore[attr-defined] - return {"vector_store": vector_store_dict} - except Exception as e: - verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/vector_store/update", - tags=["vector store management"], - dependencies=[Depends(user_api_key_auth)], -) -async def update_vector_store( - data: VectorStoreUpdateRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """Update vector store details""" - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail="Database not connected") - - try: - update_data = data.model_dump(exclude_unset=True) - vector_store_id = update_data.pop("vector_store_id") - if update_data.get("vector_store_metadata") is not None: - update_data["vector_store_metadata"] = safe_dumps( - update_data["vector_store_metadata"] - ) - - updated = await prisma_client.db.litellm_managedvectorstorestable.update( - where={"vector_store_id": vector_store_id}, - data=update_data, - ) - - updated_vs = LiteLLM_ManagedVectorStore(**updated.model_dump()) - - if litellm.vector_store_registry is not None: - litellm.vector_store_registry.update_vector_store_in_registry( - vector_store_id=vector_store_id, - updated_data=updated_vs, - ) - - return {"vector_store": updated_vs} - except Exception as e: - verbose_proxy_logger.exception(f"Error updating vector store: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) diff --git a/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py b/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py deleted file mode 100644 index 380b0a6fac..0000000000 --- a/enterprise/litellm_enterprise/types/enterprise_callbacks/send_emails.py +++ /dev/null @@ -1,66 +0,0 @@ -import enum -from typing import Dict, List, Optional - -from pydantic import BaseModel, Field - -from litellm.proxy._types import WebhookEvent - - -class EmailParams(BaseModel): - logo_url: str - support_contact: str - base_url: str - recipient_email: str - subject: str - signature: str - - -class SendKeyCreatedEmailEvent(WebhookEvent): - virtual_key: str - """ - The virtual key that was created - this will be sk-123xxx, since we will be emailing this to the user to start using the key - """ - - -class SendKeyRotatedEmailEvent(WebhookEvent): - virtual_key: str - key_alias: Optional[str] = None - """ - The virtual key that was rotated - this will be sk-123xxx, since we will be emailing this to the user to start using the new key - """ - - -class EmailEvent(str, enum.Enum): - virtual_key_created = "Virtual Key Created" - new_user_invitation = "New User Invitation" - virtual_key_rotated = "Virtual Key Rotated" - soft_budget_crossed = "Soft Budget Crossed" - max_budget_alert = "Max Budget Alert" - -class EmailEventSettings(BaseModel): - event: EmailEvent - enabled: bool -class EmailEventSettingsUpdateRequest(BaseModel): - settings: List[EmailEventSettings] -class EmailEventSettingsResponse(BaseModel): - settings: List[EmailEventSettings] -class DefaultEmailSettings(BaseModel): - """Default settings for email events""" - settings: Dict[EmailEvent, bool] = Field( - default_factory=lambda: { - EmailEvent.virtual_key_created: True, # On by default - EmailEvent.new_user_invitation: True, # On by default - EmailEvent.virtual_key_rotated: True, # On by default - EmailEvent.soft_budget_crossed: True, # On by default - EmailEvent.max_budget_alert: True, # On by default - } - ) - def to_dict(self) -> Dict[str, bool]: - """Convert to dictionary with string keys for storage""" - return {event.value: enabled for event, enabled in self.settings.items()} - @classmethod - def get_defaults(cls) -> Dict[str, bool]: - """Get the default settings as a dictionary with string keys""" - return cls().to_dict() \ No newline at end of file diff --git a/enterprise/litellm_enterprise/types/proxy/audit_logging_endpoints.py b/enterprise/litellm_enterprise/types/proxy/audit_logging_endpoints.py deleted file mode 100644 index 4615bde2b1..0000000000 --- a/enterprise/litellm_enterprise/types/proxy/audit_logging_endpoints.py +++ /dev/null @@ -1,30 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field - - -class AuditLogResponse(BaseModel): - """Response model for a single audit log entry""" - - id: str - updated_at: datetime - changed_by: str - changed_by_api_key: str - action: str - table_name: str - object_id: str - before_value: Optional[Dict[str, Any]] = None - updated_values: Optional[Dict[str, Any]] = None - - -class PaginatedAuditLogResponse(BaseModel): - """Response model for paginated audit logs""" - - audit_logs: List[AuditLogResponse] - total: int = Field( - ..., description="Total number of audit logs matching the filters" - ) - page: int = Field(..., description="Current page number") - page_size: int = Field(..., description="Number of items per page") - total_pages: int = Field(..., description="Total number of pages") diff --git a/enterprise/litellm_enterprise/types/proxy/proxy_server.py b/enterprise/litellm_enterprise/types/proxy/proxy_server.py deleted file mode 100644 index f1a1f2639e..0000000000 --- a/enterprise/litellm_enterprise/types/proxy/proxy_server.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Literal - -from typing_extensions import TypedDict - - -class CustomAuthSettings(TypedDict): - mode: Literal["on", "off", "auto"] diff --git a/enterprise/pyproject.toml b/enterprise/pyproject.toml deleted file mode 100644 index d043244843..0000000000 --- a/enterprise/pyproject.toml +++ /dev/null @@ -1,33 +0,0 @@ -[project] -name = "litellm-enterprise" -version = "0.1.42" -description = "Package for LiteLLM Enterprise features" -readme = "README.md" -requires-python = ">=3.9" -license = "LicenseRef-Proprietary" -license-files = ["LICENSE.md"] -authors = [ - { name = "BerriAI" }, -] - -[project.urls] -Homepage = "https://litellm.ai" -Repository = "https://github.com/BerriAI/litellm" -Documentation = "https://docs.litellm.ai" - -[build-system] -requires = ["uv_build==0.11.8"] -build-backend = "uv_build" - -[tool.uv] -required-version = ">=0.10.9" - -[tool.uv.build-backend] -module-root = "" - -[tool.commitizen] -version = "0.1.42" -version_files = [ - "pyproject.toml:^version", - "../pyproject.toml:litellm-enterprise==", -] diff --git a/litellm/proxy/enterprise b/litellm/proxy/enterprise deleted file mode 120000 index 6ee73080d0..0000000000 --- a/litellm/proxy/enterprise +++ /dev/null @@ -1 +0,0 @@ -../../enterprise \ No newline at end of file diff --git a/tests/enterprise/conftest.py b/tests/enterprise/conftest.py deleted file mode 100644 index 0365bbbcfa..0000000000 --- a/tests/enterprise/conftest.py +++ /dev/null @@ -1,78 +0,0 @@ -# conftest.py - -import asyncio -import importlib -import os -import sys - -import pytest - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import litellm - - -@pytest.fixture(scope="session") -def event_loop(): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - yield loop - loop.close() - - - - -@pytest.fixture(scope="function", autouse=True) -def setup_and_teardown(): - """ - This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. - """ - curr_dir = os.getcwd() # Get the current working directory - sys.path.insert( - 0, os.path.abspath("../..") - ) # Adds the project directory to the system path - - import litellm - from litellm import Router - - importlib.reload(litellm) - - try: - if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"): - import litellm.proxy.proxy_server - - importlib.reload(litellm.proxy.proxy_server) - except Exception as e: - print(f"Error reloading litellm.proxy.proxy_server: {e}") - - litellm.in_memory_llm_clients_cache.flush_cache() - - import asyncio - - loop = asyncio.get_event_loop_policy().new_event_loop() - asyncio.set_event_loop(loop) - print(litellm) - # from litellm import Router, completion, aembedding, acompletion, embedding - yield - - # Teardown code (executes after the yield point) - loop.close() # Close the loop created earlier - asyncio.set_event_loop(None) # Remove the reference to the loop - - -def pytest_collection_modifyitems(config, items): - # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests - custom_logger_tests = [ - item for item in items if "custom_logger" in item.parent.name - ] - other_tests = [item for item in items if "custom_logger" not in item.parent.name] - - # Sort tests based on their names - custom_logger_tests.sort(key=lambda x: x.name) - other_tests.sort(key=lambda x: x.name) - - # Reorder the items list - items[:] = custom_logger_tests + other_tests diff --git a/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py b/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py deleted file mode 100644 index d0ad1cc8f8..0000000000 --- a/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py +++ /dev/null @@ -1,2430 +0,0 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath("../..")) - -import asyncio -import logging -from datetime import datetime, timedelta, timezone -from unittest.mock import MagicMock, call, patch - -import pytest -from prometheus_client import REGISTRY - -import litellm -from litellm._logging import verbose_logger -from litellm.types.utils import ( - StandardLoggingHiddenParams, - StandardLoggingMetadata, - StandardLoggingModelInformation, - StandardLoggingPayload, -) - -try: - from litellm.integrations.prometheus import ( - PrometheusLogger, - UserAPIKeyLabelValues, - get_custom_labels_from_metadata, - ) -except Exception: - PrometheusLogger = None -from litellm.proxy._types import UserAPIKeyAuth - -verbose_logger.setLevel(logging.DEBUG) - -litellm.set_verbose = True - - -@pytest.fixture -def prometheus_logger() -> PrometheusLogger: - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - return PrometheusLogger() - - -def create_standard_logging_payload() -> StandardLoggingPayload: - return StandardLoggingPayload( - id="test_id", - call_type="completion", - stream=False, - response_cost=0.1, - response_cost_failure_debug_info=None, - status="success", - total_tokens=30, - prompt_tokens=20, - completion_tokens=10, - startTime=1234567890.0, - endTime=1234567891.0, - completionStartTime=1234567890.5, - model_map_information=StandardLoggingModelInformation( - model_map_key="gpt-5-mini", model_map_value=None - ), - model="gpt-5-mini", - model_id="model-123", - model_group="openai-gpt", - custom_llm_provider="openai", - api_base="https://api.openai.com", - metadata=StandardLoggingMetadata( - user_api_key_hash="test_hash", - user_api_key_alias="test_alias", - user_api_key_team_id="test_team", - user_api_key_user_id="test_user", - user_api_key_user_email="test@example.com", - user_api_key_team_alias="test_team_alias", - user_api_key_org_id=None, - spend_logs_metadata=None, - requester_ip_address="127.0.0.1", - requester_metadata=None, - user_api_key_end_user_id="test_end_user", - ), - cache_hit=False, - cache_key=None, - saved_cache_cost=0.0, - request_tags=[], - end_user=None, - requester_ip_address="127.0.0.1", - messages=[{"role": "user", "content": "Hello, world!"}], - response={"choices": [{"message": {"content": "Hi there!"}}]}, - error_str=None, - model_parameters={"stream": True}, - hidden_params=StandardLoggingHiddenParams( - model_id="model-123", - cache_key=None, - api_base="https://api.openai.com", - response_cost="0.1", - additional_headers=None, - ), - ) - - -def test_safe_get_remaining_budget(prometheus_logger): - assert prometheus_logger._safe_get_remaining_budget(100, 30) == 70 - assert prometheus_logger._safe_get_remaining_budget(100, None) == 100 - assert prometheus_logger._safe_get_remaining_budget(None, 30) == float("inf") - assert prometheus_logger._safe_get_remaining_budget(None, None) == float("inf") - - -@pytest.mark.asyncio -async def test_async_log_success_event(prometheus_logger): - standard_logging_object = create_standard_logging_payload() - kwargs = { - "model": "gpt-5-mini", - "stream": True, - "litellm_params": { - "metadata": { - "user_api_key": "test_key", - "user_api_key_user_id": "test_user", - "user_api_key_team_id": "test_team", - "user_api_key_end_user_id": "test_end_user", - } - }, - "start_time": datetime.now(), - "completion_start_time": datetime.now(), - "api_call_start_time": datetime.now(), - "end_time": datetime.now() + timedelta(seconds=1), - "standard_logging_object": standard_logging_object, - } - response_obj = MagicMock() - - # Mock the prometheus client methods - - # High Level Metrics - request/spend - prometheus_logger.litellm_requests_metric = MagicMock() - prometheus_logger.litellm_spend_metric = MagicMock() - - # Token Metrics - prometheus_logger.litellm_tokens_metric = MagicMock() - prometheus_logger.litellm_input_tokens_metric = MagicMock() - prometheus_logger.litellm_output_tokens_metric = MagicMock() - - # Remaining Budget Metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock() - - # Virtual Key Rate limit Metrics - prometheus_logger.litellm_remaining_api_key_requests_for_model = MagicMock() - prometheus_logger.litellm_remaining_api_key_tokens_for_model = MagicMock() - - # Latency Metrics - prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock() - prometheus_logger.litellm_llm_api_latency_metric = MagicMock() - prometheus_logger.litellm_request_total_latency_metric = MagicMock() - - await prometheus_logger.async_log_success_event( - kwargs, response_obj, kwargs["start_time"], kwargs["end_time"] - ) - - # Assert that the metrics were incremented - prometheus_logger.litellm_requests_metric.labels.assert_called() - prometheus_logger.litellm_spend_metric.labels.assert_called() - - # Token Metrics - prometheus_logger.litellm_tokens_metric.labels.assert_called() - prometheus_logger.litellm_input_tokens_metric.labels.assert_called() - prometheus_logger.litellm_output_tokens_metric.labels.assert_called() - - # Remaining Budget Metrics - prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called() - prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called() - - # Virtual Key Rate limit Metrics - prometheus_logger.litellm_remaining_api_key_requests_for_model.labels.assert_called() - prometheus_logger.litellm_remaining_api_key_tokens_for_model.labels.assert_called() - - # Latency Metrics - prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called() - prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called() - prometheus_logger.litellm_request_total_latency_metric.labels.assert_called() - - -def test_increment_token_metrics(prometheus_logger): - """ - Test the increment_token_metrics method - - input, output, and total tokens metrics are incremented by the values in the standard logging payload - """ - prometheus_logger.litellm_tokens_metric = MagicMock() - prometheus_logger.litellm_input_tokens_metric = MagicMock() - prometheus_logger.litellm_output_tokens_metric = MagicMock() - - standard_logging_payload = create_standard_logging_payload() - standard_logging_payload["total_tokens"] = 100 - standard_logging_payload["prompt_tokens"] = 50 - standard_logging_payload["completion_tokens"] = 50 - - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - **standard_logging_payload, - ) - - prometheus_logger._increment_token_metrics( - standard_logging_payload, - end_user_id="user1", - user_api_key="key1", - user_api_key_alias="alias1", - model="gpt-5-mini", - user_api_team="team1", - user_api_team_alias="team_alias1", - user_id="user1", - enum_values=enum_values, - ) - - prometheus_logger.litellm_tokens_metric.labels.assert_called_once_with( - end_user=None, - user=None, - user_email=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model=None, - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_tokens_metric.labels().inc.assert_called_once_with(100) - - prometheus_logger.litellm_input_tokens_metric.labels.assert_called_once_with( - end_user=None, - user=None, - user_email=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model=None, - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_input_tokens_metric.labels().inc.assert_called_once_with( - 50 - ) - - prometheus_logger.litellm_output_tokens_metric.labels.assert_called_once_with( - end_user=None, - user=None, - user_email=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model=None, - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_output_tokens_metric.labels().inc.assert_called_once_with( - 50 - ) - - -@pytest.mark.asyncio -async def test_increment_remaining_budget_metrics(prometheus_logger): - """ - Test the increment_remaining_budget_metrics method - - - team and api key remaining budget metrics are set to the difference between max budget and spend - - team and api key max budget metrics are set to their respective max budgets - - team and api key remaining hours metrics are set based on budget reset timestamps - """ - # Mock all budget-related metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock() - prometheus_logger.litellm_team_max_budget_metric = MagicMock() - prometheus_logger.litellm_api_key_max_budget_metric = MagicMock() - prometheus_logger.litellm_team_budget_remaining_hours_metric = MagicMock() - prometheus_logger.litellm_api_key_budget_remaining_hours_metric = MagicMock() - - # Create a future budget reset time for testing - future_reset_time_team = datetime.now() + timedelta(hours=10) - future_reset_time_key = datetime.now() + timedelta(hours=12) - # Mock the get_team_object and get_key_object functions to return objects with budget reset times - with ( - patch("litellm.proxy.auth.auth_checks.get_team_object") as mock_get_team, - patch("litellm.proxy.auth.auth_checks.get_key_object") as mock_get_key, - ): - mock_get_team.return_value = MagicMock(budget_reset_at=future_reset_time_team) - mock_get_key.return_value = MagicMock(budget_reset_at=future_reset_time_key) - - litellm_params = { - "metadata": { - "user_api_key_team_spend": 50, - "user_api_key_team_max_budget": 100, - "user_api_key_spend": 25, - "user_api_key_max_budget": 75, - } - } - - await prometheus_logger._increment_remaining_budget_metrics( - user_api_team="team1", - user_api_team_alias="team_alias1", - user_api_key="key1", - user_api_key_alias="alias1", - litellm_params=litellm_params, - response_cost=10, - ) - - # Test remaining budget metrics - prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called_once_with( - team="team1", team_alias="team_alias1" - ) - prometheus_logger.litellm_remaining_team_budget_metric.labels().set.assert_called_once_with( - 40 # 100 - (50 + 10) - ) - - prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called_once_with( - hashed_api_key="key1", api_key_alias="alias1" - ) - prometheus_logger.litellm_remaining_api_key_budget_metric.labels().set.assert_called_once_with( - 40 # 75 - (25 + 10) - ) - - # Test max budget metrics - prometheus_logger.litellm_team_max_budget_metric.labels.assert_called_once_with( - team="team1", team_alias="team_alias1" - ) - prometheus_logger.litellm_team_max_budget_metric.labels().set.assert_called_once_with( - 100 - ) - - prometheus_logger.litellm_api_key_max_budget_metric.labels.assert_called_once_with( - hashed_api_key="key1", api_key_alias="alias1" - ) - prometheus_logger.litellm_api_key_max_budget_metric.labels().set.assert_called_once_with( - 75 - ) - - # Test remaining hours metrics - prometheus_logger.litellm_team_budget_remaining_hours_metric.labels.assert_called_once_with( - team="team1", team_alias="team_alias1" - ) - # The remaining hours should be approximately 10 (with some small difference due to test execution time) - remaining_hours_call = prometheus_logger.litellm_team_budget_remaining_hours_metric.labels().set.call_args[ - 0 - ][ - 0 - ] - assert 9.9 <= remaining_hours_call <= 10.0 - - prometheus_logger.litellm_api_key_budget_remaining_hours_metric.labels.assert_called_once_with( - hashed_api_key="key1", api_key_alias="alias1" - ) - # The remaining hours should be approximately 10 (with some small difference due to test execution time) - remaining_hours_call = prometheus_logger.litellm_api_key_budget_remaining_hours_metric.labels().set.call_args[ - 0 - ][ - 0 - ] - assert 11.9 <= remaining_hours_call <= 12.0 - - -def test_set_latency_metrics(prometheus_logger): - """ - Test the set_latency_metrics method - - time to first token, llm api latency, and request total latency metrics are set to the values in the standard logging payload - """ - standard_logging_payload = create_standard_logging_payload() - prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock() - prometheus_logger.litellm_llm_api_latency_metric = MagicMock() - prometheus_logger.litellm_request_total_latency_metric = MagicMock() - - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - requested_model=standard_logging_payload["model_group"], - user=standard_logging_payload["metadata"]["user_api_key_user_id"], - **standard_logging_payload, - ) - - now = datetime.now() - kwargs = { - "end_time": now, # when the request ends - "start_time": now - timedelta(seconds=2), # when the request starts - "api_call_start_time": now - timedelta(seconds=1.5), # when the api call starts - "completion_start_time": now - - timedelta(seconds=1), # when the completion starts - "stream": True, - } - - prometheus_logger._set_latency_metrics( - kwargs=kwargs, - model="gpt-5-mini", - user_api_key="key1", - user_api_key_alias="alias1", - user_api_team="team1", - user_api_team_alias="team_alias1", - enum_values=enum_values, - ) - - # completion_start_time - api_call_start_time - prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called_once_with( - end_user=None, - user="test_user", - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model="openai-gpt", - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels().observe.assert_called_once_with( - 0.5 - ) - - # end_time - api_call_start_time - prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called_once_with( - end_user=None, - user="test_user", - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model="openai-gpt", - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_llm_api_latency_metric.labels().observe.assert_called_once_with( - 1.5 - ) - - # total latency for the request - prometheus_logger.litellm_request_total_latency_metric.labels.assert_called_once_with( - end_user=None, - user="test_user", - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model="openai-gpt", - model="gpt-5-mini", - model_id="model-123", - ) - prometheus_logger.litellm_request_total_latency_metric.labels().observe.assert_called_once_with( - 2.0 - ) - - -def test_set_latency_metrics_missing_timestamps(prometheus_logger): - """ - Test that _set_latency_metrics handles missing timestamp values gracefully - """ - # Mock all metrics used in the method - prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock() - prometheus_logger.litellm_llm_api_latency_metric = MagicMock() - prometheus_logger.litellm_request_total_latency_metric = MagicMock() - - standard_logging_payload = create_standard_logging_payload() - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - ) - - # Test case where completion_start_time is None - kwargs = { - "end_time": datetime.now(), - "start_time": datetime.now() - timedelta(seconds=2), - "api_call_start_time": datetime.now() - timedelta(seconds=1.5), - "completion_start_time": None, # Missing completion start time - "stream": True, - } - - # This should not raise an exception - prometheus_logger._set_latency_metrics( - kwargs=kwargs, - model="gpt-5-mini", - user_api_key="key1", - user_api_key_alias="alias1", - user_api_team="team1", - user_api_team_alias="team_alias1", - enum_values=enum_values, - ) - - # Verify time to first token metric was not called due to missing completion_start_time - prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_not_called() - - # Other metrics should still be called - prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called_once() - prometheus_logger.litellm_request_total_latency_metric.labels.assert_called_once() - - -def test_set_latency_metrics_missing_api_call_start(prometheus_logger): - """ - Test that _set_latency_metrics handles missing api_call_start_time gracefully - """ - # Mock all metrics used in the method - prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock() - prometheus_logger.litellm_llm_api_latency_metric = MagicMock() - prometheus_logger.litellm_request_total_latency_metric = MagicMock() - - standard_logging_payload = create_standard_logging_payload() - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - ) - - # Test case where api_call_start_time is None - kwargs = { - "end_time": datetime.now(), - "start_time": datetime.now() - timedelta(seconds=2), - "api_call_start_time": None, # Missing API call start time - "completion_start_time": datetime.now() - timedelta(seconds=1), - "stream": True, - } - - # This should not raise an exception - prometheus_logger._set_latency_metrics( - kwargs=kwargs, - model="gpt-5-mini", - user_api_key="key1", - user_api_key_alias="alias1", - user_api_team="team1", - user_api_team_alias="team_alias1", - enum_values=enum_values, - ) - - # Verify API latency metrics were not called due to missing api_call_start_time - prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_not_called() - prometheus_logger.litellm_llm_api_latency_metric.labels.assert_not_called() - - # Total request latency should still be called - prometheus_logger.litellm_request_total_latency_metric.labels.assert_called_once() - - -def test_increment_top_level_request_and_spend_metrics(prometheus_logger): - """ - Test the increment_top_level_request_and_spend_metrics method - - - litellm_requests_metric is incremented by 1 - - litellm_spend_metric is incremented by the response cost in the standard logging payload - """ - standard_logging_payload = create_standard_logging_payload() - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - **standard_logging_payload, - ) - prometheus_logger.litellm_requests_metric = MagicMock() - prometheus_logger.litellm_spend_metric = MagicMock() - - prometheus_logger._increment_top_level_request_and_spend_metrics( - end_user_id="user1", - user_api_key="key1", - user_api_key_alias="alias1", - model="gpt-5-mini", - user_api_team="team1", - user_api_team_alias="team_alias1", - user_id="user1", - response_cost=0.1, - enum_values=enum_values, - ) - - prometheus_logger.litellm_requests_metric.labels.assert_called_once_with( - end_user=None, - user=None, - user_email=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - model="gpt-5-mini", - model_id="model-123", - api_provider="openai", - client_ip=None, - user_agent=None, - ) - prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once() - - # The spend metric uses keyword arguments (same as requests metric) - prometheus_logger.litellm_spend_metric.labels.assert_called_once_with( - end_user=None, - user=None, - user_email=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - model="gpt-5-mini", - model_id="model-123", - api_provider="openai", - client_ip=None, - user_agent=None, - ) - prometheus_logger.litellm_spend_metric.labels().inc.assert_called_once_with(0.1) - - -@pytest.mark.asyncio -async def test_async_log_failure_event(prometheus_logger): - # NOTE: almost all params for this metric are read from standard logging payload - standard_logging_object = create_standard_logging_payload() - kwargs = { - "model": "gpt-5-mini", - "litellm_params": { - "custom_llm_provider": "openai", - }, - "start_time": datetime.now(), - "completion_start_time": datetime.now(), - "api_call_start_time": datetime.now(), - "end_time": datetime.now() + timedelta(seconds=1), - "standard_logging_object": standard_logging_object, - "exception": Exception("Test error"), - } - response_obj = MagicMock() - - # Mock the metrics - prometheus_logger.litellm_llm_api_failed_requests_metric = MagicMock() - prometheus_logger.litellm_deployment_failure_responses = MagicMock() - prometheus_logger.litellm_deployment_total_requests = MagicMock() - prometheus_logger.set_deployment_partial_outage = MagicMock() - - await prometheus_logger.async_log_failure_event( - kwargs, response_obj, kwargs["start_time"], kwargs["end_time"] - ) - - # litellm_llm_api_failed_requests_metric incremented - # Labels: end_user, hashed_api_key, api_key_alias, model, team, team_alias, user, model_id - prometheus_logger.litellm_llm_api_failed_requests_metric.labels.assert_called_once_with( - end_user=None, - hashed_api_key="test_hash", - api_key_alias="test_alias", - model="gpt-5-mini", - team="test_team", - team_alias="test_team_alias", - user="test_user", - model_id="model-123", - ) - prometheus_logger.litellm_llm_api_failed_requests_metric.labels().inc.assert_called_once() - - # deployment should be marked in partial outage - prometheus_logger.set_deployment_partial_outage.assert_called_once_with( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - ) - - # deployment failure responses incremented - verify key labels are populated - prometheus_logger.litellm_deployment_failure_responses.labels.assert_called_once() - actual_failure_labels = ( - prometheus_logger.litellm_deployment_failure_responses.labels.call_args.kwargs - ) - expected_failure_labels = { - "litellm_model_name": "gpt-5-mini", - "model_id": "model-123", - "api_base": "https://api.openai.com", - "api_provider": "openai", - "exception_class": "Exception", - "requested_model": "openai-gpt", - "hashed_api_key": "test_hash", - "api_key_alias": "test_alias", - "team": "test_team", - "team_alias": "test_team_alias", - } - for key, expected_val in expected_failure_labels.items(): - assert key in actual_failure_labels, f"Missing label {key}" - assert ( - actual_failure_labels[key] == expected_val - ), f"Label {key}: expected {expected_val!r}, got {actual_failure_labels[key]!r}" - assert actual_failure_labels.get("exception_status") in ("None", None) - assert actual_failure_labels.get("client_ip") == "127.0.0.1" - prometheus_logger.litellm_deployment_failure_responses.labels().inc.assert_called_once() - - # deployment total requests incremented - verify key labels are populated - prometheus_logger.litellm_deployment_total_requests.labels.assert_called_once() - actual_total_labels = ( - prometheus_logger.litellm_deployment_total_requests.labels.call_args.kwargs - ) - expected_total_labels = { - "litellm_model_name": "gpt-5-mini", - "model_id": "model-123", - "api_base": "https://api.openai.com", - "api_provider": "openai", - "requested_model": "openai-gpt", - "hashed_api_key": "test_hash", - "api_key_alias": "test_alias", - "team": "test_team", - "team_alias": "test_team_alias", - } - for key, expected_val in expected_total_labels.items(): - assert key in actual_total_labels, f"Missing label {key}" - assert ( - actual_total_labels[key] == expected_val - ), f"Label {key}: expected {expected_val!r}, got {actual_total_labels[key]!r}" - assert actual_total_labels.get("client_ip") == "127.0.0.1" - prometheus_logger.litellm_deployment_total_requests.labels().inc.assert_called_once() - - -@pytest.mark.asyncio -async def test_async_log_failure_event_litellm_side_rate_limit(prometheus_logger): - """LiteLLM-side reject (no deployment picked) routes the requested model - into `requested_model` and skips the partial-outage flag.""" - standard_logging_object = create_standard_logging_payload() - standard_logging_object["model_id"] = "" - standard_logging_object["model_group"] = "" - standard_logging_object["api_base"] = "" - - rate_limit_exc = Exception("LiteLLM rate limit exceeded") - rate_limit_exc.status_code = 429 - kwargs = { - "model": "us/azure/openai/gpt-5-mini", - "litellm_params": {}, - "start_time": datetime.now(), - "completion_start_time": datetime.now(), - "api_call_start_time": datetime.now(), - "end_time": datetime.now() + timedelta(seconds=1), - "standard_logging_object": standard_logging_object, - "exception": rate_limit_exc, - } - - prometheus_logger.litellm_llm_api_failed_requests_metric = MagicMock() - prometheus_logger.litellm_deployment_failure_responses = MagicMock() - prometheus_logger.litellm_deployment_total_requests = MagicMock() - prometheus_logger.set_deployment_partial_outage = MagicMock() - - await prometheus_logger.async_log_failure_event( - kwargs, MagicMock(), kwargs["start_time"], kwargs["end_time"] - ) - - prometheus_logger.set_deployment_partial_outage.assert_not_called() - - prometheus_logger.litellm_deployment_failure_responses.labels.assert_called_once() - actual_failure_labels = ( - prometheus_logger.litellm_deployment_failure_responses.labels.call_args.kwargs - ) - assert actual_failure_labels["requested_model"] == "us/azure/openai/gpt-5-mini" - assert actual_failure_labels["litellm_model_name"] == "" - assert actual_failure_labels["model_id"] == "" - assert actual_failure_labels["api_base"] == "" - assert actual_failure_labels["api_provider"] == "" - assert actual_failure_labels["exception_status"] == "429" - - -@pytest.mark.asyncio -async def test_async_post_call_failure_hook(prometheus_logger): - """ - Test for the async_post_call_failure_hook method - - it should increment the litellm_proxy_failed_requests_metric and litellm_proxy_total_requests_metric - """ - # Opt into the unified rate-limit labels so this test exercises the - # full label set surfaced when `prometheus_emit_rate_limit_labels` is on. - # The logger caches each metric's label set at construction time (so the - # labels passed to ``counter.labels(...)`` stay in lock step with the - # labels used to register the metric), so we must invalidate the cache - # after flipping the toggle for the cache to pick up the new label set. - original_emit = litellm.prometheus_emit_rate_limit_labels - litellm.prometheus_emit_rate_limit_labels = True - prometheus_logger._cached_metric_labels.clear() - - # Mock the prometheus metrics - prometheus_logger.litellm_proxy_failed_requests_metric = MagicMock() - prometheus_logger.litellm_proxy_total_requests_metric = MagicMock() - - # Create test data - request_data = {"model": "gpt-5-mini"} - - original_exception = litellm.RateLimitError( - message="Test error", llm_provider="openai", model="gpt-5-mini" - ) - - user_api_key_dict = UserAPIKeyAuth( - api_key="test_key", - key_alias="test_alias", - team_id="test_team", - team_alias="test_team_alias", - user_id="test_user", - end_user_id="test_end_user", - request_route="/chat/completions", - ) - - try: - # Call the function - await prometheus_logger.async_post_call_failure_hook( - request_data=request_data, - original_exception=original_exception, - user_api_key_dict=user_api_key_dict, - ) - - # Assert failed requests metric was incremented with correct labels - prometheus_logger.litellm_proxy_failed_requests_metric.labels.assert_called_once_with( - end_user=None, - user="test_user", - user_email=None, - hashed_api_key="test_key", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - requested_model="gpt-5-mini", - exception_status="429", - exception_class="Openai.RateLimitError", - rate_limit_category="vendor_rate_limit", - rate_limit_type=None, - route=user_api_key_dict.request_route, - model_id=None, - client_ip=None, - user_agent=None, - ) - finally: - litellm.prometheus_emit_rate_limit_labels = original_emit - prometheus_logger._cached_metric_labels.clear() - prometheus_logger.litellm_proxy_failed_requests_metric.labels().inc.assert_called_once() - - # Assert total requests metric was incremented with correct labels - prometheus_logger.litellm_proxy_total_requests_metric.labels.assert_called_once_with( - end_user=None, - hashed_api_key="test_key", - api_key_alias="test_alias", - requested_model="gpt-5-mini", - team="test_team", - team_alias="test_team_alias", - org_id=None, - org_alias=None, - user="test_user", - status_code="429", - user_email=None, - route=user_api_key_dict.request_route, - model_id=None, - client_ip=None, - user_agent=None, - ) - prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once() - - -@pytest.mark.asyncio -async def test_async_post_call_success_hook(prometheus_logger): - """ - Test for the async_post_call_success_hook method - - litellm_proxy_total_requests_metric is NOT incremented here to avoid double-counting. - It is incremented in async_log_success_event instead. - """ - # Mock the prometheus metric - prometheus_logger.litellm_proxy_total_requests_metric = MagicMock() - - # Create test data - data = {"model": "gpt-5-mini"} - - user_api_key_dict = UserAPIKeyAuth( - api_key="test_key", - key_alias="test_alias", - team_id="test_team", - team_alias="test_team_alias", - user_id="test_user", - end_user_id="test_end_user", - request_route="/chat/completions", - ) - - response = {"choices": [{"message": {"content": "test response"}}]} - - # Call the function - await prometheus_logger.async_post_call_success_hook( - data=data, user_api_key_dict=user_api_key_dict, response=response - ) - - # Assert total requests metric was NOT incremented (moved to async_log_success_event) - prometheus_logger.litellm_proxy_total_requests_metric.labels.assert_not_called() - - -def test_set_llm_deployment_success_metrics(prometheus_logger): - # Mock all the metrics used in the method - prometheus_logger.litellm_remaining_requests_metric = MagicMock() - prometheus_logger.litellm_remaining_tokens_metric = MagicMock() - prometheus_logger.litellm_deployment_success_responses = MagicMock() - prometheus_logger.litellm_deployment_total_requests = MagicMock() - prometheus_logger.litellm_deployment_latency_per_output_token = MagicMock() - prometheus_logger.set_deployment_healthy = MagicMock() - prometheus_logger.litellm_overhead_latency_metric = MagicMock() - - standard_logging_payload = create_standard_logging_payload() - - standard_logging_payload["hidden_params"]["additional_headers"] = { - "x_ratelimit_remaining_requests": 123, - "x_ratelimit_remaining_tokens": 4321, - } - standard_logging_payload["model_group"] = "my_custom_model_group" - standard_logging_payload["hidden_params"]["litellm_overhead_time_ms"] = 100 - - # Create test data - request_kwargs = { - "model": "gpt-5-mini", - "litellm_params": { - "custom_llm_provider": "openai", - "metadata": {"model_info": {"id": "model-123"}}, - }, - "standard_logging_object": standard_logging_payload, - } - - enum_values = UserAPIKeyLabelValues( - requested_model=standard_logging_payload["model_group"], - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - **standard_logging_payload, - ) - - start_time = datetime.now() - end_time = start_time + timedelta(seconds=1) - output_tokens = 10 - - # Call the function - prometheus_logger.set_llm_deployment_success_metrics( - request_kwargs=request_kwargs, - start_time=start_time, - end_time=end_time, - output_tokens=output_tokens, - enum_values=enum_values, - ) - - # Verify remaining requests metric - prometheus_logger.litellm_remaining_requests_metric.labels.assert_called_once_with( - model_group="my_custom_model_group", # model_group / requested model from create_standard_logging_payload() - api_provider="openai", # llm provider - api_base="https://api.openai.com", # api base - litellm_model_name="gpt-5-mini", # actual model used - litellm model name - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - model_id="model-123", - ) - - prometheus_logger.litellm_remaining_requests_metric.labels().set.assert_called_once_with( - 123 - ) - - # Verify remaining tokens metric - prometheus_logger.litellm_remaining_tokens_metric.labels.assert_called_once_with( - api_base="https://api.openai.com", - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - api_provider="openai", - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - litellm_model_name="gpt-5-mini", - model_group="my_custom_model_group", - model_id="model-123", - ) - - prometheus_logger.litellm_remaining_tokens_metric.labels().set.assert_called_once_with( - 4321 - ) - - # Verify deployment healthy state - prometheus_logger.set_deployment_healthy.assert_called_once_with( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - ) - - # Verify success responses metric - prometheus_logger.litellm_deployment_success_responses.labels.assert_called_once_with( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - requested_model="my_custom_model_group", - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - client_ip=None, - user_agent=None, - ) - prometheus_logger.litellm_deployment_success_responses.labels().inc.assert_called_once() - - # Verify total requests metric - prometheus_logger.litellm_deployment_total_requests.labels.assert_called_once_with( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - requested_model="my_custom_model_group", - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - client_ip=None, - user_agent=None, - ) - prometheus_logger.litellm_deployment_total_requests.labels().inc.assert_called_once() - - # Verify latency per output token metric - prometheus_logger.litellm_deployment_latency_per_output_token.labels.assert_called_once_with( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - org_id=None, - org_alias=None, - ) - prometheus_logger.litellm_overhead_latency_metric.labels.assert_called_once_with( - api_base="https://api.openai.com", - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - api_provider="openai", - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - litellm_model_name="gpt-5-mini", - model_group="my_custom_model_group", - model_id="model-123", - ) - - # Calculate expected latency per token (1 second / 10 tokens = 0.1 seconds per token) - expected_latency_per_token = 0.1 - prometheus_logger.litellm_deployment_latency_per_output_token.labels().observe.assert_called_once_with( - expected_latency_per_token - ) - - -@pytest.mark.asyncio -async def test_log_success_fallback_event(prometheus_logger): - prometheus_logger.litellm_deployment_successful_fallbacks = MagicMock() - - original_model_group = "gpt-5-mini" - kwargs = { - "model": "gpt-5.5", - "metadata": { - "user_api_key_hash": "test_hash", - "user_api_key_alias": "test_alias", - "user_api_key_team_id": "test_team", - "user_api_key_team_alias": "test_team_alias", - }, - } - original_exception = litellm.RateLimitError( - message="Test error", llm_provider="openai", model="gpt-5-mini" - ) - - await prometheus_logger.log_success_fallback_event( - original_model_group=original_model_group, - kwargs=kwargs, - original_exception=original_exception, - ) - - prometheus_logger.litellm_deployment_successful_fallbacks.labels.assert_called_once_with( - requested_model=original_model_group, - fallback_model="gpt-5.5", - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - exception_status="429", - exception_class="Openai.RateLimitError", - model_id=None, - ) - prometheus_logger.litellm_deployment_successful_fallbacks.labels().inc.assert_called_once() - - -@pytest.mark.asyncio -async def test_log_failure_fallback_event(prometheus_logger): - prometheus_logger.litellm_deployment_failed_fallbacks = MagicMock() - - original_model_group = "gpt-5-mini" - kwargs = { - "model": "gpt-5.5", - "metadata": { - "user_api_key_hash": "test_hash", - "user_api_key_alias": "test_alias", - "user_api_key_team_id": "test_team", - "user_api_key_team_alias": "test_team_alias", - }, - } - original_exception = litellm.RateLimitError( - message="Test error", llm_provider="openai", model="gpt-5-mini" - ) - - await prometheus_logger.log_failure_fallback_event( - original_model_group=original_model_group, - kwargs=kwargs, - original_exception=original_exception, - ) - - prometheus_logger.litellm_deployment_failed_fallbacks.labels.assert_called_once_with( - requested_model=original_model_group, - fallback_model="gpt-5.5", - hashed_api_key="test_hash", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - exception_status="429", - exception_class="Openai.RateLimitError", - model_id=None, - ) - prometheus_logger.litellm_deployment_failed_fallbacks.labels().inc.assert_called_once() - - -def test_deployment_state_management(prometheus_logger): - prometheus_logger.litellm_deployment_state = MagicMock() - - test_params = { - "litellm_model_name": "gpt-5-mini", - "model_id": "model-123", - "api_base": "https://api.openai.com", - "api_provider": "openai", - } - - # Test set_deployment_healthy (state=0) - prometheus_logger.set_deployment_healthy(**test_params) - prometheus_logger.litellm_deployment_state.labels.assert_called_with( - litellm_model_name=test_params["litellm_model_name"], - model_id=test_params["model_id"], - api_base=test_params["api_base"], - api_provider=test_params["api_provider"], - ) - prometheus_logger.litellm_deployment_state.labels().set.assert_called_with(0) - - # Test set_deployment_partial_outage (state=1) - prometheus_logger.set_deployment_partial_outage(**test_params) - prometheus_logger.litellm_deployment_state.labels().set.assert_called_with(1) - - # Test set_deployment_complete_outage (state=2) - prometheus_logger.set_deployment_complete_outage(**test_params) - prometheus_logger.litellm_deployment_state.labels().set.assert_called_with(2) - - -def test_increment_deployment_cooled_down(prometheus_logger): - import inspect - - method_sig = inspect.signature(prometheus_logger.increment_deployment_cooled_down) - expected_label_count = len([p for p in method_sig.parameters.keys() if p != "self"]) - - mock_chain = MagicMock() - - def validating_labels(*label_values, **label_kwargs): - """Validate label count matches metric definition""" - total = len(label_values) + len(label_kwargs) - if total != expected_label_count: - raise ValueError( - f"Incorrect label count: expected {expected_label_count}, got {total}" - ) - return mock_chain - - prometheus_logger.litellm_deployment_cooled_down = MagicMock() - prometheus_logger.litellm_deployment_cooled_down.labels = MagicMock( - side_effect=validating_labels - ) - - prometheus_logger.increment_deployment_cooled_down( - litellm_model_name="gpt-5-mini", - model_id="model-123", - api_base="https://api.openai.com", - api_provider="openai", - exception_status="429", - ) - - prometheus_logger.litellm_deployment_cooled_down.labels.assert_called_once_with( - "gpt-5-mini", "model-123", "https://api.openai.com", "openai", "429" - ) - mock_chain.inc.assert_called_once() - - -@pytest.mark.parametrize("enable_end_user_cost_tracking_prometheus_only", [True, False]) -def test_prometheus_factory(monkeypatch, enable_end_user_cost_tracking_prometheus_only): - from litellm.integrations.prometheus import prometheus_label_factory - from litellm.types.integrations.prometheus import UserAPIKeyLabelValues - - monkeypatch.setattr( - "litellm.enable_end_user_cost_tracking_prometheus_only", - enable_end_user_cost_tracking_prometheus_only, - ) - - enum_values = UserAPIKeyLabelValues( - end_user="test_end_user", - hashed_api_key="test_hash", - api_key_alias="test_alias", - ) - supported_labels = ["end_user", "hashed_api_key", "api_key_alias"] - returned_dict = prometheus_label_factory( - supported_enum_labels=supported_labels, enum_values=enum_values - ) - - if enable_end_user_cost_tracking_prometheus_only is True: - assert returned_dict["end_user"] == "test_end_user" - else: - assert returned_dict["end_user"] == None - assert returned_dict["hashed_api_key"] == "test_hash" - assert returned_dict["api_key_alias"] == "test_alias" - - -def test_get_custom_labels_from_metadata(monkeypatch): - monkeypatch.setattr( - "litellm.custom_prometheus_metadata_labels", ["metadata.foo", "metadata.bar"] - ) - metadata = {"foo": "bar", "bar": "baz", "taz": "qux"} - assert get_custom_labels_from_metadata(metadata) == { - "metadata_foo": "bar", - "metadata_bar": "baz", - } - - -def test_get_custom_labels_from_metadata_tags(monkeypatch): - monkeypatch.setattr("litellm.custom_prometheus_metadata_labels", []) - metadata = {"foo": "bar", "bar": "baz", "taz": "qux"} - assert get_custom_labels_from_metadata(metadata) == {} - - -def test_get_custom_labels_from_top_level_metadata(monkeypatch): - """ - Test that get_custom_labels_from_metadata can extract fields from top-level metadata, - such as requester_ip_address, not just from nested dictionaries like requester_metadata. - """ - monkeypatch.setattr( - "litellm.custom_prometheus_metadata_labels", - ["requester_ip_address", "user_api_key_alias"], - ) - # Simulate metadata structure with top-level fields - metadata = { - "requester_ip_address": "10.48.203.20", # Top-level field - "user_api_key_alias": "TestAlias", # Top-level field - "requester_metadata": { - "nested_field": "nested_value" - }, # Nested dict (excluded) - "user_api_key_auth_metadata": { - "another_nested": "value" - }, # Nested dict (excluded) - } - result = get_custom_labels_from_metadata(metadata) - assert result == { - "requester_ip_address": "10.48.203.20", - "user_api_key_alias": "TestAlias", - } - - -def test_get_custom_labels_from_top_level_and_nested_metadata(monkeypatch): - """ - Test that get_custom_labels_from_metadata can extract fields from both top-level - and nested metadata (requester_metadata, user_api_key_auth_metadata). - """ - monkeypatch.setattr( - "litellm.custom_prometheus_metadata_labels", - [ - "requester_ip_address", # Top-level - "metadata.foo", # From requester_metadata - "metadata.bar", # From user_api_key_auth_metadata - ], - ) - # Simulate combined_metadata structure as it would appear after merging - # This is what gets passed to get_custom_labels_from_metadata - combined_metadata = { - "requester_ip_address": "10.48.203.20", # Top-level field - "foo": "bar_value", # From requester_metadata (spread) - "bar": "baz_value", # From user_api_key_auth_metadata (spread) - } - result = get_custom_labels_from_metadata(combined_metadata) - assert result == { - "requester_ip_address": "10.48.203.20", - "metadata_foo": "bar_value", - "metadata_bar": "baz_value", - } - - -async def test_async_log_success_event_with_top_level_metadata( - prometheus_logger, monkeypatch -): - """ - Test that async_log_success_event correctly extracts custom labels from top-level metadata - fields like requester_ip_address, not just from nested dictionaries. - """ - # Configure custom metadata labels to extract requester_ip_address - monkeypatch.setattr( - "litellm.custom_prometheus_metadata_labels", ["requester_ip_address"] - ) - - # Create standard logging payload with requester_ip_address at top-level metadata - standard_logging_object = create_standard_logging_payload() - standard_logging_object["metadata"]["requester_ip_address"] = "10.48.203.20" - standard_logging_object["metadata"]["requester_metadata"] = {} # Empty nested dict - standard_logging_object["metadata"][ - "user_api_key_auth_metadata" - ] = {} # Empty nested dict - - kwargs = { - "model": "gpt-5-mini", - "stream": True, - "litellm_params": { - "metadata": { - "user_api_key": "test_key", - "user_api_key_user_id": "test_user", - "user_api_key_team_id": "test_team", - "user_api_key_end_user_id": "test_end_user", - } - }, - "start_time": datetime.now(), - "completion_start_time": datetime.now(), - "api_call_start_time": datetime.now(), - "end_time": datetime.now() + timedelta(seconds=1), - "standard_logging_object": standard_logging_object, - } - response_obj = MagicMock() - - # Mock the prometheus client methods - # Create mock chain that accepts any labels (including custom labels like requester_ip_address) - def create_mock_metric(): - mock_metric = MagicMock() - mock_labels = MagicMock() - mock_metric.labels = MagicMock(return_value=mock_labels) - mock_labels.inc = MagicMock() - mock_labels.observe = MagicMock() - mock_labels.set = MagicMock() - return mock_metric - - prometheus_logger.litellm_requests_metric = create_mock_metric() - prometheus_logger.litellm_spend_metric = create_mock_metric() - prometheus_logger.litellm_tokens_metric = create_mock_metric() - prometheus_logger.litellm_input_tokens_metric = create_mock_metric() - prometheus_logger.litellm_output_tokens_metric = create_mock_metric() - prometheus_logger.litellm_remaining_team_budget_metric = create_mock_metric() - prometheus_logger.litellm_remaining_api_key_budget_metric = create_mock_metric() - prometheus_logger.litellm_remaining_user_budget_metric = create_mock_metric() - prometheus_logger.litellm_user_max_budget_metric = create_mock_metric() - prometheus_logger.litellm_user_budget_remaining_hours_metric = create_mock_metric() - prometheus_logger.litellm_remaining_api_key_requests_for_model = ( - create_mock_metric() - ) - prometheus_logger.litellm_remaining_api_key_tokens_for_model = create_mock_metric() - prometheus_logger.litellm_llm_api_time_to_first_token_metric = create_mock_metric() - prometheus_logger.litellm_llm_api_latency_metric = create_mock_metric() - prometheus_logger.litellm_request_total_latency_metric = create_mock_metric() - # Cache metrics - prometheus_logger.litellm_cache_hits_metric = create_mock_metric() - prometheus_logger.litellm_cache_misses_metric = create_mock_metric() - prometheus_logger.litellm_cached_tokens_metric = create_mock_metric() - # Deployment metrics - prometheus_logger.litellm_deployment_state = create_mock_metric() - prometheus_logger.litellm_deployment_success_responses = create_mock_metric() - prometheus_logger.litellm_deployment_total_requests = create_mock_metric() - prometheus_logger.litellm_deployment_latency_per_output_token = create_mock_metric() - prometheus_logger.litellm_remaining_requests_metric = create_mock_metric() - prometheus_logger.litellm_remaining_tokens_metric = create_mock_metric() - prometheus_logger.litellm_overhead_latency_metric = create_mock_metric() - prometheus_logger.litellm_proxy_total_requests_metric = create_mock_metric() - - await prometheus_logger.async_log_success_event( - kwargs, response_obj, kwargs["start_time"], kwargs["end_time"] - ) - - # Verify that the metrics were called with labels - # The custom labels (like requester_ip_address) should be extracted and included in the label factory - # Since we're using mocks that accept any labels, we just verify that labels() was called - # This confirms that the custom label extraction logic ran without errors - assert prometheus_logger.litellm_requests_metric.labels.called - assert prometheus_logger.litellm_spend_metric.labels.called - - # Verify that the labels() method was called with some arguments (either positional or keyword) - # This ensures the custom label extraction happened and didn't cause a "Incorrect label names" error - call_args = prometheus_logger.litellm_requests_metric.labels.call_args - assert call_args is not None - # The test passes if labels() was called successfully, which means custom labels were handled correctly - - -def test_get_custom_labels_from_tags(monkeypatch): - from litellm.integrations.prometheus import get_custom_labels_from_tags - - monkeypatch.setattr( - "litellm.custom_prometheus_tags", ["prod", "test-env", "batch.job"] - ) - tags = ["prod", "debug", "batch.job"] - result = get_custom_labels_from_tags(tags) - assert result == { - "tag_prod": "true", - "tag_test_env": "false", # not in request tags - "tag_batch_job": "true", # dot replaced with underscore - } - - -def test_get_custom_labels_from_tags_empty_config(monkeypatch): - from litellm.integrations.prometheus import get_custom_labels_from_tags - - monkeypatch.setattr("litellm.custom_prometheus_tags", []) - tags = ["prod", "debug"] - result = get_custom_labels_from_tags(tags) - assert result == {} - - -def test_get_custom_labels_from_tags_no_tags(monkeypatch): - from litellm.integrations.prometheus import get_custom_labels_from_tags - - monkeypatch.setattr("litellm.custom_prometheus_tags", ["prod", "test"]) - tags = [] - result = get_custom_labels_from_tags(tags) - assert result == { - "tag_prod": "false", - "tag_test": "false", - } - - -def test_get_custom_labels_from_tags_wildcard_patterns(monkeypatch): - """Test wildcard pattern matching for custom labels from tags.""" - from litellm.integrations.prometheus import get_custom_labels_from_tags - - # Configure tags with wildcard patterns - monkeypatch.setattr( - "litellm.custom_prometheus_tags", - [ - "User-Agent: curl/*", - "User-Agent: python-requests/*", - "Environment: prod*", - "Service: api-gateway*", - "exact-match", - ], - ) - - # Test tags that should match the wildcard patterns - tags = [ - "User-Agent: curl/7.68.0", - "User-Agent: python-requests/2.28.1", - "Environment: production", - "Service: api-gateway-v2", - "exact-match", - "other-tag", - ] - - result = get_custom_labels_from_tags(tags) - - expected = { - "tag_User_Agent__curl__": "true", # matches "User-Agent: curl/*" - "tag_User_Agent__python_requests__": "true", # matches "User-Agent: python-requests/*" - "tag_Environment__prod_": "true", # matches "Environment: prod*" - "tag_Service__api_gateway_": "true", # matches "Service: api-gateway*" - "tag_exact_match": "true", # exact match - } - - assert result == expected - - -def test_get_custom_labels_from_tags_wildcard_no_matches(monkeypatch): - """Test wildcard patterns that don't match any tags.""" - from litellm.integrations.prometheus import get_custom_labels_from_tags - - # Configure tags with wildcard patterns - monkeypatch.setattr( - "litellm.custom_prometheus_tags", - ["User-Agent: firefox/*", "Environment: dev*", "Service: web-app*"], - ) - - # Test tags that should NOT match the wildcard patterns - tags = [ - "User-Agent: curl/7.68.0", # doesn't match "User-Agent: firefox/*" - "Environment: production", # doesn't match "Environment: dev*" - "Service: api-gateway-v2", # doesn't match "Service: web-app*" - "other-tag", - ] - - result = get_custom_labels_from_tags(tags) - - expected = { - "tag_User_Agent__firefox__": "false", # no match for "User-Agent: firefox/*" - "tag_Environment__dev_": "false", # no match for "Environment: dev*" - "tag_Service__web_app_": "false", # no match for "Service: web-app*" - } - - assert result == expected - - -def test_tag_matches_wildcard_configured_pattern(): - """Test the helper function for wildcard pattern matching.""" - from litellm.integrations.prometheus import _tag_matches_wildcard_configured_pattern - - # Test cases that should match - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["User-Agent: curl/7.68.0", "prod", "other"], - configured_tag="User-Agent: curl/*", - ) - is True - ) - - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["User-Agent: python-requests/2.28.1", "test"], - configured_tag="User-Agent: python-requests/*", - ) - is True - ) - - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["Environment: production", "debug"], - configured_tag="Environment: prod*", - ) - is True - ) - - # Test exact match (no wildcard) - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["prod", "test"], configured_tag="prod" - ) - is True - ) - - # Test cases that should NOT match - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["User-Agent: firefox/98.0", "prod"], - configured_tag="User-Agent: curl/*", - ) - is False - ) - - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["Environment: development", "test"], - configured_tag="Environment: prod*", - ) - is False - ) - - assert ( - _tag_matches_wildcard_configured_pattern( - tags=["staging", "test"], configured_tag="prod" - ) - is False - ) - - # Test with empty tags - assert ( - _tag_matches_wildcard_configured_pattern( - tags=[], configured_tag="User-Agent: curl/*" - ) - is False - ) - - -@pytest.mark.asyncio(scope="session") -async def test_initialize_remaining_budget_metrics(prometheus_logger): - """ - Test that _initialize_remaining_budget_metrics correctly sets budget metrics for all teams - """ - litellm.prometheus_initialize_budget_metrics = True - # Mock the prisma client and get_paginated_teams function - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams" - ) as mock_get_teams, - ): - # Create mock team data with proper datetime objects for budget_reset_at - future_reset = datetime.now() + timedelta(hours=24) # Reset 24 hours from now - mock_teams = [ - MagicMock( - team_id="team1", - team_alias="alias1", - max_budget=100, - spend=30, - budget_reset_at=future_reset, - ), - MagicMock( - team_id="team2", - team_alias="alias2", - max_budget=200, - spend=50, - budget_reset_at=future_reset, - ), - MagicMock( - team_id="team3", - team_alias=None, - max_budget=300, - spend=100, - budget_reset_at=future_reset, - ), - ] - - # Mock get_paginated_teams to return our test data - mock_get_teams.return_value = (mock_teams, len(mock_teams)) - - # Mock the Prometheus metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_team_budget_remaining_hours_metric = MagicMock() - - # Call the function - await prometheus_logger._initialize_remaining_budget_metrics() - - # Verify the remaining budget metric was set correctly for each team - expected_budget_calls = [ - call.labels("team1", "alias1").set(70), # 100 - 30 - call.labels("team2", "alias2").set(150), # 200 - 50 - call.labels("team3", "").set(200), # 300 - 100 - ] - - prometheus_logger.litellm_remaining_team_budget_metric.assert_has_calls( - expected_budget_calls, any_order=True - ) - - # Get all the calls made to the hours metric - hours_calls = ( - prometheus_logger.litellm_team_budget_remaining_hours_metric.mock_calls - ) - - # Verify the structure and approximate values of the hours calls - assert len(hours_calls) == 6 # 3 teams * 2 calls each (labels + set) - - # Helper function to extract hours value from call - def get_hours_from_call(call_obj): - if "set" in str(call_obj): - return call_obj[1][0] # Extract the hours value - return None - - # Verify each team's hours are approximately 24 (within reasonable bounds) - hours_values = [ - get_hours_from_call(call) - for call in hours_calls - if get_hours_from_call(call) is not None - ] - for hours in hours_values: - assert ( - 23.9 <= hours <= 24.0 - ), f"Hours value {hours} not within expected range" - - # Verify the labels were called with correct team information - label_calls = [ - call.labels(team="team1", team_alias="alias1"), - call.labels(team="team2", team_alias="alias2"), - call.labels(team="team3", team_alias=""), - ] - prometheus_logger.litellm_team_budget_remaining_hours_metric.assert_has_calls( - label_calls, any_order=True - ) - - -@pytest.mark.asyncio -async def test_initialize_remaining_budget_metrics_exception_handling( - prometheus_logger, -): - """ - Test that _initialize_remaining_budget_metrics properly handles exceptions - """ - litellm.prometheus_initialize_budget_metrics = True - # Mock the prisma client and get_paginated_teams function to raise an exception - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams" - ) as mock_get_teams, - patch( - "litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper" - ) as mock_list_keys, - ): - # Make get_paginated_teams raise an exception - mock_get_teams.side_effect = Exception("Database error") - mock_list_keys.side_effect = Exception("Key listing error") - - # Mock prisma_client structure to raise an exception for user budget metrics - # The code accesses prisma_client.db.litellm_usertable.find_many and count - mock_usertable = MagicMock() - mock_usertable.find_many = MagicMock( - side_effect=Exception("User database error") - ) - mock_usertable.count = MagicMock(side_effect=Exception("User count error")) - - # Mock litellm_teamtable to raise an exception for team count metrics - mock_teamtable = MagicMock() - mock_teamtable.count = MagicMock(side_effect=Exception("Team count error")) - - # Mock litellm_organizationtable to raise an exception for org budget metrics - mock_orgtable = MagicMock() - mock_orgtable.find_many = MagicMock(side_effect=Exception("Org database error")) - mock_orgtable.count = MagicMock(side_effect=Exception("Org count error")) - - mock_db = MagicMock() - mock_db.litellm_usertable = mock_usertable - mock_db.litellm_teamtable = mock_teamtable - mock_db.litellm_organizationtable = mock_orgtable - mock_prisma.db = mock_db - - # Mock the Prometheus metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock() - prometheus_logger.litellm_remaining_user_budget_metric = MagicMock() - prometheus_logger.litellm_remaining_org_budget_metric = MagicMock() - prometheus_logger.litellm_total_users_metric = MagicMock() - prometheus_logger.litellm_teams_count_metric = MagicMock() - - # Mock the logger to capture the error - with patch("litellm._logging.verbose_logger.exception") as mock_logger: - # Call the function - await prometheus_logger._initialize_remaining_budget_metrics() - - # Verify all five errors were logged (teams, keys, users, orgs, and user/team count) - assert mock_logger.call_count == 5 - assert ( - "Error initializing teams budget metrics" - in mock_logger.call_args_list[0][0][0] - ) - assert ( - "Error initializing keys budget metrics" - in mock_logger.call_args_list[1][0][0] - ) - assert ( - "Error initializing users budget metrics" - in mock_logger.call_args_list[2][0][0] - ) - assert ( - "Error initializing orgs budget metrics" - in mock_logger.call_args_list[3][0][0] - ) - assert ( - "Error initializing user/team count metrics" - in mock_logger.call_args_list[4][0][0] - ) - - # Verify the metrics were never called - prometheus_logger.litellm_remaining_team_budget_metric.assert_not_called() - prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called() - prometheus_logger.litellm_remaining_user_budget_metric.assert_not_called() - prometheus_logger.litellm_remaining_org_budget_metric.assert_not_called() - prometheus_logger.litellm_total_users_metric.assert_not_called() - prometheus_logger.litellm_teams_count_metric.assert_not_called() - - -@pytest.mark.asyncio(scope="session") -async def test_initialize_api_key_budget_metrics(prometheus_logger): - """ - Test that _initialize_api_key_budget_metrics correctly sets budget metrics for all API keys - """ - litellm.prometheus_initialize_budget_metrics = True - # Mock the prisma client and _list_key_helper function - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper" - ) as mock_list_keys, - ): - # Create mock key data with proper datetime objects for budget_reset_at - future_reset = datetime.now() + timedelta(hours=24) # Reset 24 hours from now - key1 = UserAPIKeyAuth( - api_key="key1_hash", - key_alias="alias1", - team_id="team1", - max_budget=100, - spend=30, - budget_reset_at=future_reset, - ) - key1.token = "key1_hash" - key2 = UserAPIKeyAuth( - api_key="key2_hash", - key_alias="alias2", - team_id="team2", - max_budget=200, - spend=50, - budget_reset_at=future_reset, - ) - key2.token = "key2_hash" - - key3 = UserAPIKeyAuth( - api_key="key3_hash", - key_alias=None, - team_id="team3", - max_budget=300, - spend=100, - budget_reset_at=future_reset, - ) - key3.token = "key3_hash" - - mock_keys = [ - key1, - key2, - key3, - ] - - # Mock _list_key_helper to return our test data - mock_list_keys.return_value = {"keys": mock_keys, "total_count": len(mock_keys)} - - # Mock the Prometheus metrics - prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock() - prometheus_logger.litellm_api_key_budget_remaining_hours_metric = MagicMock() - prometheus_logger.litellm_api_key_max_budget_metric = MagicMock() - - # Call the function - await prometheus_logger._initialize_api_key_budget_metrics() - - # Verify the remaining budget metric was set correctly for each key - expected_budget_calls = [ - call.labels("key1_hash", "alias1").set(70), # 100 - 30 - call.labels("key2_hash", "alias2").set(150), # 200 - 50 - call.labels("key3_hash", "").set(200), # 300 - 100 - ] - - prometheus_logger.litellm_remaining_api_key_budget_metric.assert_has_calls( - expected_budget_calls, any_order=True - ) - - # Get all the calls made to the hours metric - hours_calls = ( - prometheus_logger.litellm_api_key_budget_remaining_hours_metric.mock_calls - ) - - # Verify the structure and approximate values of the hours calls - assert len(hours_calls) == 6 # 3 keys * 2 calls each (labels + set) - - # Helper function to extract hours value from call - def get_hours_from_call(call_obj): - if "set" in str(call_obj): - return call_obj[1][0] # Extract the hours value - return None - - # Verify each key's hours are approximately 24 (within reasonable bounds) - hours_values = [ - get_hours_from_call(call) - for call in hours_calls - if get_hours_from_call(call) is not None - ] - for hours in hours_values: - assert ( - 23.9 <= hours <= 24.0 - ), f"Hours value {hours} not within expected range" - - # Verify max budget metric was set correctly for each key - expected_max_budget_calls = [ - call.labels("key1_hash", "alias1").set(100), - call.labels("key2_hash", "alias2").set(200), - call.labels("key3_hash", "").set(300), - ] - prometheus_logger.litellm_api_key_max_budget_metric.assert_has_calls( - expected_max_budget_calls, any_order=True - ) - - -def test_set_team_budget_metrics_multiple_teams(prometheus_logger): - """ - Test that _set_team_budget_metrics correctly handles multiple teams with different budgets and reset times - """ - # Create test teams with different budgets and reset times - teams = [ - MagicMock( - team_id="team1", - team_alias="alias1", - spend=50.0, - max_budget=100.0, - budget_reset_at=datetime(2024, 12, 31, tzinfo=timezone.utc), - ), - MagicMock( - team_id="team2", - team_alias="alias2", - spend=75.0, - max_budget=150.0, - budget_reset_at=datetime(2024, 6, 30, tzinfo=timezone.utc), - ), - MagicMock( - team_id="team3", - team_alias="alias3", - spend=25.0, - max_budget=200.0, - budget_reset_at=datetime(2024, 3, 31, tzinfo=timezone.utc), - ), - ] - - # Mock the metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_team_max_budget_metric = MagicMock() - prometheus_logger.litellm_team_budget_remaining_hours_metric = MagicMock() - - # Set metrics for each team - for team in teams: - prometheus_logger._set_team_budget_metrics(team) - - # Verify remaining budget metric calls - expected_remaining_budget_calls = [ - call.labels(team="team1", team_alias="alias1").set(50.0), # 100 - 50 - call.labels(team="team2", team_alias="alias2").set(75.0), # 150 - 75 - call.labels(team="team3", team_alias="alias3").set(175.0), # 200 - 25 - ] - prometheus_logger.litellm_remaining_team_budget_metric.assert_has_calls( - expected_remaining_budget_calls, any_order=True - ) - - # Verify max budget metric calls - expected_max_budget_calls = [ - call.labels("team1", "alias1").set(100.0), - call.labels("team2", "alias2").set(150.0), - call.labels("team3", "alias3").set(200.0), - ] - prometheus_logger.litellm_team_max_budget_metric.assert_has_calls( - expected_max_budget_calls, any_order=True - ) - - # Verify budget reset metric calls - # Note: The exact hours will depend on the current time, so we'll just verify the structure - assert ( - prometheus_logger.litellm_team_budget_remaining_hours_metric.labels.call_count - == 3 - ) - assert ( - prometheus_logger.litellm_team_budget_remaining_hours_metric.labels().set.call_count - == 3 - ) - - -def test_set_team_budget_metrics_null_values(prometheus_logger): - """ - Test that _set_team_budget_metrics correctly handles null/None values - """ - # Create test team with null values - team = MagicMock( - team_id="team_null", - team_alias=None, # Test null alias - spend=None, # Test null spend - max_budget=None, # Test null max_budget - budget_reset_at=None, # Test null reset time - ) - - # Mock the metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_team_max_budget_metric = MagicMock() - prometheus_logger.litellm_team_budget_remaining_hours_metric = MagicMock() - - # Set metrics for the team - prometheus_logger._set_team_budget_metrics(team) - - # Verify remaining budget metric is set to infinity when max_budget is None - prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called_once_with( - team="team_null", team_alias="" - ) - prometheus_logger.litellm_remaining_team_budget_metric.labels().set.assert_called_once_with( - float("inf") - ) - - # Verify max budget metric is not set when max_budget is None - prometheus_logger.litellm_team_max_budget_metric.assert_not_called() - - # Verify reset metric is not set when budget_reset_at is None - prometheus_logger.litellm_team_budget_remaining_hours_metric.assert_not_called() - - -def test_set_team_budget_metrics_with_custom_labels(prometheus_logger, monkeypatch): - """ - Test that _set_team_budget_metrics correctly handles custom prometheus labels - """ - # Set custom prometheus labels - custom_labels = ["metadata.organization", "metadata.environment"] - monkeypatch.setattr("litellm.custom_prometheus_metadata_labels", custom_labels) - # Logger caches each metric's label set at construction time (fixture - # runs before this monkeypatch), so invalidate so the cached label set - # picks up the freshly-configured custom metadata labels. - prometheus_logger._cached_metric_labels.clear() - - # Create test team with custom metadata - team = MagicMock( - team_id="team1", - team_alias="alias1", - spend=50.0, - max_budget=100.0, - budget_reset_at=datetime(2024, 12, 31, tzinfo=timezone.utc), - ) - - # Mock the metrics - prometheus_logger.litellm_remaining_team_budget_metric = MagicMock() - prometheus_logger.litellm_team_max_budget_metric = MagicMock() - prometheus_logger.litellm_team_budget_remaining_hours_metric = MagicMock() - - # Set metrics for the team - prometheus_logger._set_team_budget_metrics(team) - - # Verify remaining budget metric includes custom labels - prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called_once_with( - team="team1", - team_alias="alias1", - metadata_organization=None, - metadata_environment=None, - ) - prometheus_logger.litellm_remaining_team_budget_metric.labels().set.assert_called_once_with( - 50.0 - ) # 100 - 50 - - # Verify max budget metric includes custom labels - prometheus_logger.litellm_team_max_budget_metric.labels.assert_called_once_with( - team="team1", - team_alias="alias1", - metadata_organization=None, - metadata_environment=None, - ) - prometheus_logger.litellm_team_max_budget_metric.labels().set.assert_called_once_with( - 100.0 - ) - - -def test_prometheus_label_factory_with_custom_tags(monkeypatch): - """ - Test that prometheus_label_factory correctly handles custom tags - """ - from litellm.integrations.prometheus import ( - prometheus_label_factory, - ) - from litellm.types.integrations.prometheus import UserAPIKeyLabelValues - - # Set custom tags configuration - monkeypatch.setattr("litellm.custom_prometheus_tags", ["prod", "test-env"]) - - # Create enum_values with tags - enum_values = UserAPIKeyLabelValues( - hashed_api_key="key123", - team="team1", - tags=["prod", "debug"], # Only "prod" is in our custom_prometheus_tags - ) - - # Test with supported labels including custom tags - supported_labels = ["hashed_api_key", "team", "tag_prod", "tag_test_env"] - - result = prometheus_label_factory( - supported_enum_labels=supported_labels, - enum_values=enum_values, - ) - - expected = { - "hashed_api_key": "key123", - "team": "team1", - "tag_prod": "true", # present in tags - "tag_test_env": "false", # not present in tags - } - - assert result == expected - - -def test_prometheus_label_factory_with_no_custom_tags(monkeypatch): - """ - Test that prometheus_label_factory works when no custom tags are configured - """ - from litellm.integrations.prometheus import ( - prometheus_label_factory, - ) - from litellm.types.integrations.prometheus import UserAPIKeyLabelValues - - # Set empty custom tags configuration - monkeypatch.setattr("litellm.custom_prometheus_tags", []) - - # Create enum_values with tags - enum_values = UserAPIKeyLabelValues( - hashed_api_key="key123", - team="team1", - tags=["prod", "debug"], - ) - - # Test with basic supported labels (no custom tags) - supported_labels = ["hashed_api_key", "team"] - - result = prometheus_label_factory( - supported_enum_labels=supported_labels, - enum_values=enum_values, - ) - - expected = { - "hashed_api_key": "key123", - "team": "team1", - } - - assert result == expected - - -def test_get_exception_class_name(prometheus_logger): - """ - Test that _get_exception_class_name correctly formats the exception class name - """ - # Test case 1: Exception with llm_provider - rate_limit_error = litellm.RateLimitError( - message="Rate limit exceeded", llm_provider="openai", model="gpt-5-mini" - ) - assert ( - prometheus_logger._get_exception_class_name(rate_limit_error) - == "Openai.RateLimitError" - ) - - # Test case 2: Exception with empty llm_provider - auth_error = litellm.AuthenticationError( - message="Invalid API key", llm_provider="", model="gpt-5.5" - ) - assert ( - prometheus_logger._get_exception_class_name(auth_error) == "AuthenticationError" - ) - - # Test case 3: Exception with None llm_provider - context_window_error = litellm.ContextWindowExceededError( - message="Context length exceeded", llm_provider=None, model="gpt-5.5" - ) - assert ( - prometheus_logger._get_exception_class_name(context_window_error) - == "ContextWindowExceededError" - ) - - -def test_set_llm_deployment_success_metrics_with_label_filtering(): - """ - Test that set_llm_deployment_success_metrics correctly uses prometheus_label_factory - and respects label filtering configuration to prevent "Incorrect label names" errors. - """ - from litellm.types.integrations.prometheus import PrometheusMetricsConfig - - # Create a prometheus logger with label filtering configuration - config = [ - PrometheusMetricsConfig( - group="test_group", - metrics=[ - "litellm_overhead_latency_metric", - "litellm_remaining_requests_metric", - "litellm_remaining_tokens_metric", - "litellm_deployment_success_responses", - "litellm_deployment_total_requests", - ], - include_labels=[ - "litellm_model_name", - "api_provider", - "hashed_api_key", - ], # Limited labels - ) - ] - - # Mock litellm.prometheus_metrics_config - with patch("litellm.prometheus_metrics_config", config): - # Clear registry before creating new logger - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - prometheus_logger = PrometheusLogger() - - # Mock all the metrics used in the method - prometheus_logger.litellm_overhead_latency_metric = MagicMock() - prometheus_logger.litellm_remaining_requests_metric = MagicMock() - prometheus_logger.litellm_remaining_tokens_metric = MagicMock() - prometheus_logger.litellm_deployment_success_responses = MagicMock() - prometheus_logger.litellm_deployment_total_requests = MagicMock() - prometheus_logger.set_deployment_healthy = MagicMock() - - # Create standard logging payload - standard_logging_payload = create_standard_logging_payload() - standard_logging_payload["hidden_params"]["additional_headers"] = { - "x_ratelimit_remaining_requests": 123, - "x_ratelimit_remaining_tokens": 4321, - } - standard_logging_payload["hidden_params"]["litellm_overhead_time_ms"] = 100 - - # Create test data - request_kwargs = { - "model": "gpt-5-mini", - "litellm_params": { - "custom_llm_provider": "openai", - "metadata": {"model_info": {"id": "model-123"}}, - }, - "standard_logging_object": standard_logging_payload, - } - - enum_values = UserAPIKeyLabelValues( - litellm_model_name=standard_logging_payload["model"], - api_provider=standard_logging_payload["custom_llm_provider"], - hashed_api_key=standard_logging_payload["metadata"]["user_api_key_hash"], - api_key_alias=standard_logging_payload["metadata"]["user_api_key_alias"], - team=standard_logging_payload["metadata"]["user_api_key_team_id"], - team_alias=standard_logging_payload["metadata"]["user_api_key_team_alias"], - requested_model=standard_logging_payload["model_group"], - model=standard_logging_payload["model"], - model_id=standard_logging_payload["model_id"], - api_base=standard_logging_payload["api_base"], - ) - - start_time = datetime.now() - end_time = start_time + timedelta(seconds=1) - output_tokens = 10 - - # Call the function - this should not raise "Incorrect label names" error - prometheus_logger.set_llm_deployment_success_metrics( - request_kwargs=request_kwargs, - start_time=start_time, - end_time=end_time, - output_tokens=output_tokens, - enum_values=enum_values, - ) - - # Verify that metrics were called with filtered labels (only the configured ones) - # The exact labels depend on what get_labels_for_metric returns for each metric - - # Verify overhead latency metric was called with filtered labels - prometheus_logger.litellm_overhead_latency_metric.labels.assert_called_once() - overhead_labels = ( - prometheus_logger.litellm_overhead_latency_metric.labels.call_args[1] - ) - - # Should only contain the filtered labels that are supported for this metric - expected_filtered_labels = { - "litellm_model_name", - "api_provider", - "hashed_api_key", - } - actual_labels = set(k for k in overhead_labels.keys() if k is not None) - - # Verify that only expected labels are present (subset of configured labels) - assert actual_labels <= expected_filtered_labels - - # Verify remaining requests metric was called with filtered labels - prometheus_logger.litellm_remaining_requests_metric.labels.assert_called_once() - requests_labels = ( - prometheus_logger.litellm_remaining_requests_metric.labels.call_args[1] - ) - actual_labels = set(k for k in requests_labels.keys() if k is not None) - assert actual_labels <= expected_filtered_labels - - # Verify remaining tokens metric was called with filtered labels - prometheus_logger.litellm_remaining_tokens_metric.labels.assert_called_once() - tokens_labels = ( - prometheus_logger.litellm_remaining_tokens_metric.labels.call_args[1] - ) - actual_labels = set(k for k in tokens_labels.keys() if k is not None) - assert actual_labels <= expected_filtered_labels - - # Verify deployment success responses metric was called with filtered labels - prometheus_logger.litellm_deployment_success_responses.labels.assert_called_once() - success_labels = ( - prometheus_logger.litellm_deployment_success_responses.labels.call_args[1] - ) - actual_labels = set(k for k in success_labels.keys() if k is not None) - assert actual_labels <= expected_filtered_labels - - # Verify deployment total requests metric was called with filtered labels - prometheus_logger.litellm_deployment_total_requests.labels.assert_called_once() - total_labels = ( - prometheus_logger.litellm_deployment_total_requests.labels.call_args[1] - ) - actual_labels = set(total_labels.keys()) - assert actual_labels.issubset(expected_filtered_labels.union({None})) - - # Verify all metrics were actually called (no exceptions were raised) - prometheus_logger.litellm_overhead_latency_metric.labels().observe.assert_called_once() - prometheus_logger.litellm_remaining_requests_metric.labels().set.assert_called_once_with( - 123 - ) - prometheus_logger.litellm_remaining_tokens_metric.labels().set.assert_called_once_with( - 4321 - ) - prometheus_logger.litellm_deployment_success_responses.labels().inc.assert_called_once() - prometheus_logger.litellm_deployment_total_requests.labels().inc.assert_called_once() - - -@pytest.mark.asyncio -async def test_prometheus_token_metrics_with_prometheus_config(): - """ - Test that validates the renamed token metrics are incremented correctly with a prometheus config. - - This test ensures that after the metric renaming (git diff): - - litellm_total_tokens -> litellm_total_tokens_metric - - litellm_input_tokens -> litellm_input_tokens_metric - - litellm_output_tokens -> litellm_output_tokens_metric - - All three metrics should be properly incremented when making a successful completion request. - """ - - from litellm.types.integrations.prometheus import PrometheusMetricsConfig - - # Clear registry before test - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - # Set up prometheus configuration that includes the token metrics - config = [ - PrometheusMetricsConfig( - group="token_metrics_test", - metrics=[ - "litellm_total_tokens_metric", - "litellm_input_tokens_metric", - "litellm_output_tokens_metric", - "litellm_requests_metric", - ], - include_labels=[ - "model", - "hashed_api_key", - "api_key_alias", - "team", - "team_alias", - ], - ) - ] - - # Mock litellm.prometheus_metrics_config - with patch("litellm.prometheus_metrics_config", config): - # Create PrometheusLogger with the configuration - prometheus_logger = PrometheusLogger() - - # Test data with specific token counts - standard_logging_payload = create_standard_logging_payload() - standard_logging_payload["total_tokens"] = 1500 - standard_logging_payload["prompt_tokens"] = 900 - standard_logging_payload["completion_tokens"] = 600 - standard_logging_payload["response_cost"] = 0.075 - - kwargs = { - "model": "gpt-5-mini", - "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": "test_key_hash", - "user_api_key_user_id": "test_user", - "user_api_key_team_id": "test_team", - "user_api_key_alias": "test_alias", - "user_api_key_team_alias": "test_team_alias", - } - }, - "start_time": datetime.now() - timedelta(seconds=2), - "completion_start_time": datetime.now() - timedelta(seconds=1), - "api_call_start_time": datetime.now() - timedelta(seconds=1.5), - "end_time": datetime.now(), - "standard_logging_object": standard_logging_payload, - } - response_obj = MagicMock() - - # Make the completion call through the logger - await prometheus_logger.async_log_success_event( - kwargs, response_obj, kwargs["start_time"], kwargs["end_time"] - ) - - await asyncio.sleep(2) - - print("final registry values", REGISTRY._collector_to_names) - - # Get metric collectors directly from registry - metric_collectors = {} - for collector, names in REGISTRY._collector_to_names.items(): - metric_name = names[0] # First name is the base metric name - metric_collectors[metric_name] = collector - - print("=== Final Metric Values (Direct Access) ===") - - # Expected values - expected_values = { - "litellm_total_tokens_metric": 1500.0, - "litellm_input_tokens_metric": 900.0, - "litellm_output_tokens_metric": 600.0, - "litellm_requests_metric": 1.0, - } - - expected_label_values = { - "api_key_alias": "test_alias", - "hashed_api_key": "test_hash", - "model": "gpt-5-mini", - "team": "test_team", - "team_alias": "test_team_alias", - } - - # Validate each metric directly - for metric_name, expected_value in expected_values.items(): - if metric_name in metric_collectors: - collector = metric_collectors[metric_name] - - # Get all samples for this metric - samples = list(collector.collect())[0].samples - - # Find the _total sample (the actual counter value) - total_sample = None - for sample in samples: - if sample.name.endswith("_total"): - total_sample = sample - break - - if total_sample: - actual_value = total_sample.value - actual_labels = total_sample.labels - - print( - f"✓ {metric_name}: expected={expected_value}, actual={actual_value}" - ) - print(f" Labels: {actual_labels}") - - # Validate the value - assert ( - actual_value == expected_value - ), f"Expected {expected_value}, got {actual_value} for {metric_name}" - - # Validate the labels - for ( - label_key, - expected_label_value, - ) in expected_label_values.items(): - actual_label_value = actual_labels.get(label_key) - assert ( - actual_label_value == expected_label_value - ), f"Expected label {label_key}={expected_label_value}, got {actual_label_value}" - - print(f" ✓ {metric_name} VALIDATED") - else: - raise AssertionError(f"No _total sample found for {metric_name}") - else: - raise AssertionError(f"Metric {metric_name} not found in registry") - - print("✓ All token metrics validated successfully!") diff --git a/tests/enterprise/litellm_enterprise/integrations/test_custom_guardrail.py b/tests/enterprise/litellm_enterprise/integrations/test_custom_guardrail.py deleted file mode 100644 index 8a29e5c1ce..0000000000 --- a/tests/enterprise/litellm_enterprise/integrations/test_custom_guardrail.py +++ /dev/null @@ -1,256 +0,0 @@ -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system-path -from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.types.guardrails import GuardrailEventHooks, Mode - - -def test_custom_guardrail_with_mode_default_list(monkeypatch): - """Test Mode with default as a list of modes (e.g. default: ["pre_call", "post_call"])""" - monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", True) - cg = CustomGuardrail( - guardrail_name="test_guardrail", - supported_event_hooks=[ - GuardrailEventHooks.pre_call, - GuardrailEventHooks.post_call, - GuardrailEventHooks.logging_only, - ], - event_hook=Mode( - tags={"test_tag": "logging_only"}, - default=["pre_call", "post_call"], - ), - default_on=True, - ) - - # No tag match → default fires for pre_call - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.pre_call, - ) - is True - ) - - # No tag match → default fires for post_call - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.post_call, - ) - is True - ) - - # No tag match → logging_only NOT in default list, should not fire - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.logging_only, - ) - is False - ) - - # Tag matches → only logging_only should fire - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.logging_only, - ) - is True - ) - - # Tag matches → pre_call should NOT fire (tag says logging_only) - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.pre_call, - ) - is False - ) - - # Tag matches → post_call should NOT fire (tag says logging_only) - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.post_call, - ) - is False - ) - - -def test_custom_guardrail_with_mode_no_default(monkeypatch): - """Test Mode with no default — guardrail only fires when tag matches""" - monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", True) - cg = CustomGuardrail( - guardrail_name="test_guardrail", - supported_event_hooks=[ - GuardrailEventHooks.pre_call, - GuardrailEventHooks.logging_only, - ], - event_hook=Mode( - tags={"test_tag": "logging_only"}, - ), - default_on=True, - ) - - # No tag, no default → nothing fires - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.pre_call, - ) - is False - ) - - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.logging_only, - ) - is False - ) - - # Tag matches → only logging_only fires - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.logging_only, - ) - is True - ) - - -def test_custom_guardrail_with_mode_tag_value_list(monkeypatch): - """Test Mode with tag value as a list of modes (e.g. tags: {"tag": ["pre_call", "post_call"]})""" - monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", True) - cg = CustomGuardrail( - guardrail_name="test_guardrail", - supported_event_hooks=[ - GuardrailEventHooks.pre_call, - GuardrailEventHooks.post_call, - GuardrailEventHooks.logging_only, - ], - event_hook=Mode( - tags={"test_tag": ["pre_call", "post_call"]}, - default="logging_only", - ), - default_on=True, - ) - - # Tag matches → pre_call should fire (in tag's list) - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.pre_call, - ) - is True - ) - - # Tag matches → post_call should fire (in tag's list) - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.post_call, - ) - is True - ) - - # Tag matches → logging_only should NOT fire (not in tag's list) - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.logging_only, - ) - is False - ) - - # No tag match → default fires for logging_only - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.logging_only, - ) - is True - ) - - # No tag match → pre_call should NOT fire (not in default) - assert ( - cg.should_run_guardrail( - data={"messages": [{"role": "user", "content": "test"}]}, - event_type=GuardrailEventHooks.pre_call, - ) - is False - ) - - -def test_custom_guardrail_with_mode(monkeypatch): - monkeypatch.setattr( - "litellm.proxy.proxy_server.premium_user", True - ) # Set premium_user to True - cg = CustomGuardrail( - guardrail_name="test_guardrail", - supported_event_hooks=[ - GuardrailEventHooks.pre_call, - GuardrailEventHooks.logging_only, - ], - event_hook=Mode( - tags={"test_tag": "pre_call"}, - default="logging_only", - ), - default_on=True, - ) - - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test_message"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.pre_call, - ) - is True - ) - - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test_message"}], - }, - event_type=GuardrailEventHooks.pre_call, - ) - is False - ) - - assert ( - cg.should_run_guardrail( - data={ - "messages": [{"role": "user", "content": "test_message"}], - "litellm_metadata": {"tags": ["test_tag"]}, - }, - event_type=GuardrailEventHooks.logging_only, - ) - is False - ) diff --git a/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py b/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py deleted file mode 100644 index ebea96e215..0000000000 --- a/tests/enterprise/litellm_enterprise/integrations/test_prometheus.py +++ /dev/null @@ -1,1063 +0,0 @@ -""" -Mock prometheus unit tests, these don't rely on LLM API calls -""" - -import json -import os -import sys - -import pytest -from fastapi.testclient import TestClient - -sys.path.insert( - 0, os.path.abspath("../../..") -) # Adds the parent directory to the system path - -from unittest.mock import patch - -import pytest_asyncio -from apscheduler.schedulers.asyncio import AsyncIOScheduler - -# Add prometheus_client import for registry cleanup -from prometheus_client import REGISTRY - -import litellm -from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES - -try: - from litellm.integrations.prometheus import ( - PrometheusLogger, - prometheus_label_factory, - ) -except Exception: - PrometheusLogger = None - prometheus_label_factory = None -from litellm.types.integrations.prometheus import ( - PrometheusMetricLabels, - PrometheusMetricsConfig, - UserAPIKeyLabelValues, -) - - -@pytest.fixture -def prometheus_logger() -> PrometheusLogger: - """ - Fixture that creates a clean PrometheusLogger instance by clearing the registry first. - This prevents "Duplicated timeseries in CollectorRegistry" errors. - """ - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - return PrometheusLogger() - - -def clear_prometheus_registry(): - """Helper function to clear the Prometheus registry""" - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - -def test_initialize_budget_metrics_cron_job(): - # Clear registry before test - clear_prometheus_registry() - - # Create a scheduler - scheduler = AsyncIOScheduler() - - # Create and register a PrometheusLogger - prometheus_logger = PrometheusLogger() - litellm.callbacks = [prometheus_logger] - - # Initialize the cron job - PrometheusLogger.initialize_budget_metrics_cron_job(scheduler) - - # Verify that a job was added to the scheduler - jobs = scheduler.get_jobs() - assert len(jobs) == 1 - - # Verify job properties - job = jobs[0] - assert ( - job.trigger.interval.total_seconds() / 60 - == PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES - ) - assert job.func.__name__ == "initialize_remaining_budget_metrics" - - -def test_end_user_not_tracked_for_all_prometheus_metrics(): - """ - Test that end_user is not tracked for all Prometheus metrics by default. - - This test ensures that: - 1. By default, end_user is filtered out from all Prometheus metrics - 2. Future metrics that include end_user in their label definitions will also be filtered - 3. The filtering happens through the prometheus_label_factory function - """ - # Reset any previous settings - original_setting = getattr( - litellm, "enable_end_user_cost_tracking_prometheus_only", None - ) - litellm.enable_end_user_cost_tracking_prometheus_only = None # Default behavior - - try: - # Test data with end_user present - test_end_user_id = "test_user_123" - enum_values = UserAPIKeyLabelValues( - end_user=test_end_user_id, - hashed_api_key="test_key", - api_key_alias="test_alias", - team="test_team", - team_alias="test_team_alias", - user="test_user", - requested_model="gpt-5.5", - model="gpt-5.5", - litellm_model_name="gpt-5.5", - ) - - # Get all defined Prometheus metrics that include end_user in their labels - metrics_with_end_user = [] - for metric_name in PrometheusMetricLabels.__dict__: - if not metric_name.startswith("_") and metric_name != "get_labels": - labels = getattr(PrometheusMetricLabels, metric_name) - if isinstance(labels, list) and "end_user" in labels: - metrics_with_end_user.append(metric_name) - - # Ensure we found some metrics with end_user (sanity check) - assert ( - len(metrics_with_end_user) > 0 - ), "No metrics with end_user found - test setup issue" - - # Test each metric that includes end_user in its label definition - for metric_name in metrics_with_end_user: - supported_labels = PrometheusMetricLabels.get_labels(metric_name) - - # Verify that end_user is in the supported labels (before filtering) - assert ( - "end_user" in supported_labels - ), f"end_user should be in {metric_name} labels" - - # Call prometheus_label_factory to get filtered labels - filtered_labels = prometheus_label_factory( - supported_enum_labels=supported_labels, enum_values=enum_values - ) - print("filtered labels logged on prometheus=", filtered_labels) - - # Verify that end_user is None in the filtered labels (filtered out) - assert filtered_labels.get("end_user") is None, ( - f"end_user should be None for metric {metric_name} when " - f"enable_end_user_cost_tracking_prometheus_only is not True. " - f"Got: {filtered_labels.get('end_user')}" - ) - - # Test that when enable_end_user_cost_tracking_prometheus_only is True, end_user is tracked - litellm.enable_end_user_cost_tracking_prometheus_only = True - - # Test one metric to verify end_user is now included - test_metric = metrics_with_end_user[0] - supported_labels = PrometheusMetricLabels.get_labels(test_metric) - filtered_labels = prometheus_label_factory( - supported_enum_labels=supported_labels, enum_values=enum_values - ) - - # Now end_user should be present - assert filtered_labels.get("end_user") == test_end_user_id, ( - f"end_user should be present for metric {test_metric} when " - f"enable_end_user_cost_tracking_prometheus_only is True" - ) - - finally: - # Restore original setting - litellm.enable_end_user_cost_tracking_prometheus_only = original_setting - - -def test_future_metrics_with_end_user_are_filtered(): - """ - Test that ensures future metrics that include end_user will also be filtered. - This simulates adding a new metric with end_user in its labels. - """ - # Reset setting - original_setting = getattr( - litellm, "enable_end_user_cost_tracking_prometheus_only", None - ) - litellm.enable_end_user_cost_tracking_prometheus_only = None - - try: - # Simulate a new metric that includes end_user - simulated_new_metric_labels = [ - "end_user", - "hashed_api_key", - "api_key_alias", - "model", - "team", - "new_label", # Some new label that might be added in the future - ] - - test_end_user_id = "future_test_user" - enum_values = UserAPIKeyLabelValues( - end_user=test_end_user_id, - hashed_api_key="test_key", - api_key_alias="test_alias", - team="test_team", - model="gpt-5.5", - ) - - # Test the filtering - filtered_labels = prometheus_label_factory( - supported_enum_labels=simulated_new_metric_labels, enum_values=enum_values - ) - print("filtered labels logged on prometheus=", filtered_labels) - - # Verify end_user is filtered out even for this "new" metric - assert ( - filtered_labels.get("end_user") is None - ), "end_user should be filtered out for future metrics by default" - - # Verify other labels are present - assert filtered_labels.get("hashed_api_key") == "test_key" - assert filtered_labels.get("team") == "test_team" - - finally: - # Restore original setting - litellm.enable_end_user_cost_tracking_prometheus_only = original_setting - - -def test_prometheus_config_parsing(): - """Test that prometheus metrics configuration is parsed correctly""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration - test_config = [ - { - "group": "service_metrics", - "metrics": [ - "litellm_deployment_failure_responses", - "litellm_deployment_total_requests", - "litellm_proxy_failed_requests_metric", - "litellm_proxy_total_requests_metric", - ], - "include_labels": [ - "requested_model", - "team", - ], - } - ] - - # Set configuration - litellm.prometheus_metrics_config = test_config - - # Create PrometheusLogger instance - logger = PrometheusLogger() - - # Parse configuration - label_filters = logger._parse_prometheus_config() - - # Verify label filters exist for each metric - expected_labels = [ - "requested_model", - "team", - ] - - expected_metrics = [ - "litellm_deployment_failure_responses", - "litellm_deployment_total_requests", - "litellm_proxy_failed_requests_metric", - "litellm_proxy_total_requests_metric", - ] - - for metric in expected_metrics: - assert metric in label_filters - assert label_filters[metric] == expected_labels - - -def test_get_metric_labels(): - """Test that metric label filtering works correctly""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration - test_config = [ - { - "group": "service_metrics", - "metrics": ["litellm_deployment_failure_responses"], - "include_labels": ["litellm_model_name", "api_provider"], - } - ] - - litellm.prometheus_metrics_config = test_config - - logger = PrometheusLogger() - - # Get filtered labels - labels = logger.get_labels_for_metric("litellm_deployment_failure_responses") - - # Verify only configured labels are returned - assert "litellm_model_name" in labels - assert "api_provider" in labels - # These should be filtered out even if they're in the default labels - assert ( - len([l for l in labels if l not in ["litellm_model_name", "api_provider"]]) == 0 - ) - - -def test_no_prometheus_config(): - """Test behavior when no prometheus config is set""" - # Clear registry before test - clear_prometheus_registry() - - # Clear any existing config - litellm.prometheus_metrics_config = None - - logger = PrometheusLogger() - - # Should return default labels when no config is set - labels = logger.get_labels_for_metric("litellm_deployment_failure_responses") - # Should return some labels (the default ones) - assert isinstance(labels, list) - # Should have more than 0 labels (the default ones) - assert len(labels) > 0 - - -def test_prometheus_metrics_config_type(): - """Test that PrometheusMetricsConfig type validation works""" - # Valid configuration - valid_config = PrometheusMetricsConfig( - group="service_metrics", - metrics=["litellm_deployment_failure_responses"], - include_labels=["litellm_model_name"], - ) - - assert valid_config.group == "service_metrics" - assert valid_config.metrics == ["litellm_deployment_failure_responses"] - assert valid_config.include_labels == ["litellm_model_name"] - - # Test with None include_labels (should be allowed) - config_no_labels = PrometheusMetricsConfig( - group="service_metrics", - metrics=["litellm_deployment_failure_responses"], - include_labels=None, - ) - - assert config_no_labels.include_labels is None - print("PrometheusMetricsConfig type validation passed!") - - -def test_basic_functionality(): - """Test basic functionality without creating multiple instances""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration - test_config = [ - { - "group": "service_metrics", - "metrics": [ - "litellm_deployment_failure_responses", - "litellm_deployment_total_requests", - ], - "include_labels": ["litellm_model_name", "api_provider"], - } - ] - - # Set configuration - litellm.prometheus_metrics_config = test_config - - # Test that the configuration is properly set - assert litellm.prometheus_metrics_config is not None - assert len(litellm.prometheus_metrics_config) == 1 - assert litellm.prometheus_metrics_config[0]["group"] == "service_metrics" - assert ( - "litellm_deployment_failure_responses" - in litellm.prometheus_metrics_config[0]["metrics"] - ) - - print("Basic prometheus configuration test passed!") - - -# ============================================================================== -# VALIDATION TESTS - Test the new validation logic for metrics and labels -# ============================================================================== - - -def test_invalid_metric_name_validation(): - """Test that invalid metric names are caught and raise ValueError""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration with invalid metric name - test_config = [ - { - "group": "service_metrics", - "metrics": [ - "invalid_metric_name_that_does_not_exist", - "litellm_deployment_total_requests", # valid metric - ], - "include_labels": ["litellm_model_name"], - } - ] - - litellm.prometheus_metrics_config = test_config - - # Creating PrometheusLogger should raise ValueError due to invalid metric - with pytest.raises(ValueError) as exc_info: - PrometheusLogger() - - # Verify error message contains information about invalid metric - assert "invalid_metric_name_that_does_not_exist" in str(exc_info.value) - assert "Configuration validation failed" in str(exc_info.value) - - -def test_invalid_labels_validation(): - """Test that invalid labels for metrics are caught and raise ValueError""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration with invalid labels - test_config = [ - { - "group": "service_metrics", - "metrics": ["litellm_deployment_total_requests"], - "include_labels": [ - "litellm_model_name", # valid label - "invalid_label_name", # invalid label - "another_invalid_label", # another invalid label - ], - } - ] - - litellm.prometheus_metrics_config = test_config - - # Creating PrometheusLogger should raise ValueError due to invalid labels - with pytest.raises(ValueError) as exc_info: - PrometheusLogger() - - # Verify error message contains information about invalid labels - assert "invalid_label_name" in str(exc_info.value) - assert "Configuration validation failed" in str(exc_info.value) - - -def test_valid_configuration_passes_validation(): - """Test that valid configuration passes validation without errors""" - # Clear registry before test - clear_prometheus_registry() - - # Set up test configuration with all valid metrics and labels - test_config = [ - { - "group": "service_metrics", - "metrics": [ - "litellm_deployment_total_requests", - "litellm_deployment_failure_responses", - ], - "include_labels": [ - "litellm_model_name", - "api_provider", - "requested_model", - ], - } - ] - - litellm.prometheus_metrics_config = test_config - - # This should not raise any exceptions - try: - logger = PrometheusLogger() - # Verify the logger was created successfully - assert logger is not None - assert hasattr(logger, "enabled_metrics") - assert "litellm_deployment_total_requests" in logger.enabled_metrics - assert "litellm_deployment_failure_responses" in logger.enabled_metrics - except Exception as e: - pytest.fail(f"Valid configuration should not raise exception: {e}") - - -# ============================================================================== -# END VALIDATION TESTS -# ============================================================================== - - -# ============================================================================== -# SEMANTIC VALIDATION TESTS - Detect logical errors in metric increments -# ============================================================================== - - -class MockCounter: - """Mock counter for testing metric increments""" - - def __init__(self, name): - self.name = name - self.labels_calls = [] - self.inc_calls = [] - - def labels(self, *args, **kwargs): - self.labels_calls.append(kwargs) - return self - - def inc(self, value=1): - self.inc_calls.append(value) - - -class MockHistogram: - """Mock histogram for testing metric observations""" - - def __init__(self, name): - self.name = name - self.labels_calls = [] - self.observe_calls = [] - - def labels(self, *args, **kwargs): - self.labels_calls.append(kwargs) - return self - - def observe(self, value): - self.observe_calls.append(value) - - -@pytest.fixture -def mock_prometheus_logger(): - """Create a PrometheusLogger with mocked metrics to test increment logic""" - from unittest.mock import patch - - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - with patch("litellm.proxy.proxy_server.premium_user", True): - logger = PrometheusLogger() - - # Replace metrics with mocks to capture increment calls - logger.litellm_proxy_total_requests_metric = MockCounter( - "litellm_proxy_total_requests_metric" - ) - logger.litellm_tokens_metric = MockCounter("litellm_total_tokens") - logger.litellm_input_tokens_metric = MockCounter("litellm_input_tokens") - logger.litellm_output_tokens_metric = MockCounter("litellm_output_tokens") - logger.litellm_spend_metric = MockCounter("litellm_spend_metric") - logger.litellm_requests_metric = MockCounter("litellm_requests_metric") - - return logger - - -@pytest.mark.asyncio -async def test_request_counter_semantic_validation(mock_prometheus_logger): - """ - CRITICAL TEST: Validates that request counters are incremented by 1, not by token count. - This test specifically catches the bug where litellm_proxy_total_requests_metric - is incorrectly incremented by total_tokens instead of 1. - - The metric is now ONLY incremented in async_log_success_event (for both streaming - and non-streaming) to prevent double-counting. - """ - from datetime import datetime, timedelta - from unittest.mock import MagicMock - - from litellm.proxy._types import UserAPIKeyAuth - - # Test data with large token count that should NOT affect request counter - kwargs = { - "model": "gpt-5-mini", - "litellm_params": {"metadata": {}}, - "start_time": datetime.now() - timedelta(seconds=1), - "end_time": datetime.now(), - "api_call_start_time": datetime.now() - timedelta(seconds=0.5), - "standard_logging_object": { - "total_tokens": 999, # Large number - this should NOT be used for request counter - "prompt_tokens": 600, - "completion_tokens": 399, - "response_cost": 0.005, - "model_group": "gpt-5-mini", - "model_id": "test-model-id", - "api_base": "https://api.openai.com/v1", - "custom_llm_provider": "openai", - "stream": False, - "request_tags": [], - "metadata": { - "user_api_key_user_id": "test-user", - "user_api_key_hash": "test-hash", - "user_api_key_alias": "test-alias", - "user_api_key_team_id": "test-team", - "user_api_key_team_alias": "test-team-alias", - "user_api_key_user_email": "test@example.com", - }, - "hidden_params": { - "additional_headers": {}, - }, - }, - } - - # Call the success event - should increment for both streaming and non-streaming - await mock_prometheus_logger.async_log_success_event( - kwargs, None, kwargs["start_time"], kwargs["end_time"] - ) - - # CRITICAL ASSERTION: Request counter should be incremented by 1 - total_requests_metric = mock_prometheus_logger.litellm_proxy_total_requests_metric - assert ( - len(total_requests_metric.inc_calls) == 1 - ), "Request metric should be incremented once in async_log_success_event" - - # Call the post-call logging hook - should NOT increment (to prevent double-counting) - await mock_prometheus_logger.async_post_call_success_hook( - data={}, - user_api_key_dict=UserAPIKeyAuth( - end_user="test-user", - hashed_api_key="test-hash", - api_key_alias="test-alias", - team="test-team", - model="gpt-5.5", - ), - response=MagicMock(), - ) - - # CRITICAL ASSERTION: Request counter should still be 1 (not incremented again) - total_requests_metric = mock_prometheus_logger.litellm_proxy_total_requests_metric - assert ( - len(total_requests_metric.inc_calls) == 1 - ), "Request metric should not be incremented again in async_post_call_success_hook" - - # Check that ALL request counter increments are by 1 (not by token count) - for inc_value in total_requests_metric.inc_calls: - assert inc_value == 1, ( - f"SEMANTIC BUG DETECTED: Request counter incremented by {inc_value} instead of 1. " - f"This indicates the bug where request counters are incremented by token counts." - ) - - # Verify token counters ARE incremented by token counts (this should work correctly) - tokens_metric = mock_prometheus_logger.litellm_tokens_metric - assert ( - 999 in tokens_metric.inc_calls - ), "Token metric should be incremented by total_tokens (999)" - - -@pytest.mark.asyncio -async def test_multiple_requests_counter_semantics(mock_prometheus_logger): - """ - Test that demonstrates the scaling issue: with multiple requests, - request counters should scale by number of requests, not total tokens. - """ - from datetime import datetime, timedelta - - num_requests = 3 - tokens_per_request = 500 # High token count to make the bug obvious - - for i in range(num_requests): - kwargs = { - "model": "gpt-5-mini", - "litellm_params": {"metadata": {}}, - "start_time": datetime.now() - timedelta(seconds=1), - "end_time": datetime.now(), - "api_call_start_time": datetime.now() - timedelta(seconds=0.5), - "standard_logging_object": { - "total_tokens": tokens_per_request, - "prompt_tokens": tokens_per_request // 2, - "completion_tokens": tokens_per_request // 2, - "response_cost": 0.001, - "model_group": "gpt-5-mini", - "model_id": "test-model-id", - "api_base": "https://api.openai.com/v1", - "custom_llm_provider": "openai", - "stream": False, - "request_tags": [], - "metadata": { - "user_api_key_user_id": "test-user", - "user_api_key_hash": "test-hash", - "user_api_key_alias": "test-alias", - "user_api_key_team_id": "test-team", - "user_api_key_team_alias": "test-team-alias", - "user_api_key_user_email": "test@example.com", - }, - "hidden_params": { - "additional_headers": {}, - }, - }, - } - - await mock_prometheus_logger.async_log_success_event( - kwargs, None, kwargs["start_time"], kwargs["end_time"] - ) - - # Calculate total increments - total_request_increments = sum( - mock_prometheus_logger.litellm_proxy_total_requests_metric.inc_calls - ) - total_token_increments = sum(mock_prometheus_logger.litellm_tokens_metric.inc_calls) - - # CRITICAL ASSERTION: Request increments should equal number of requests - expected_total_tokens = num_requests * tokens_per_request # 3 * 500 = 1500 - - # With the bug, total_request_increments would be 1500 instead of 3 - assert total_request_increments == num_requests, ( - f"SEMANTIC BUG: Request counter total increments = {total_request_increments}, " - f"expected {num_requests}. This suggests request counters are being incremented " - f"by token counts instead of request counts." - ) - - # Token counter should correctly equal total tokens - assert ( - total_token_increments == expected_total_tokens - ), f"Token counter should sum to {expected_total_tokens}, got {total_token_increments}" - - -@pytest.mark.asyncio -async def test_streaming_request_counter_semantics(mock_prometheus_logger): - """ - Test that streaming requests are also counted correctly (by 1, not by token count) - """ - from datetime import datetime, timedelta - - kwargs = { - "model": "gpt-5-mini", - "litellm_params": {"metadata": {}}, - "start_time": datetime.now() - timedelta(seconds=1), - "end_time": datetime.now(), - "api_call_start_time": datetime.now() - timedelta(seconds=0.5), - "standard_logging_object": { - "total_tokens": 750, # High token count for streaming - "prompt_tokens": 300, - "completion_tokens": 450, - "response_cost": 0.003, - "model_group": "gpt-5-mini", - "model_id": "test-model-id", - "api_base": "https://api.openai.com/v1", - "custom_llm_provider": "openai", - "stream": True, # This is a streaming request - "request_tags": [], - "metadata": { - "user_api_key_user_id": "test-user", - "user_api_key_hash": "test-hash", - "user_api_key_alias": "test-alias", - "user_api_key_team_id": "test-team", - "user_api_key_team_alias": "test-team-alias", - "user_api_key_user_email": "test@example.com", - }, - "hidden_params": { - "additional_headers": {}, - }, - }, - } - - await mock_prometheus_logger.async_log_success_event( - kwargs, None, kwargs["start_time"], kwargs["end_time"] - ) - - # Streaming requests should also be counted as 1 request, not 750 - for ( - inc_value - ) in mock_prometheus_logger.litellm_proxy_total_requests_metric.inc_calls: - assert ( - inc_value == 1 - ), f"SEMANTIC BUG: Streaming request counter incremented by {inc_value} instead of 1" - - -def test_metric_increment_invariants(): - """ - Test invariants that should always hold for different metric types - """ - # Invariant 1: Request counters should never be incremented by large values - suspicious_request_increments = [ - 100, - 500, - 1000, - 1500, - ] # These look like token counts - for increment in suspicious_request_increments: - # If we see request counters incremented by these values, it's likely a bug - assert ( - increment > 10 - ), f"Request increment of {increment} is suspiciously large - likely a semantic bug" - - # Invariant 2: Token counters should never be incremented by 1 (unless it's a 1-token response) - # This would indicate the reverse bug (using request count for token counter) - - # Invariant 3: Cost increments should be small positive floats - reasonable_costs = [0.001, 0.01, 0.1, 1.0] - for cost in reasonable_costs: - assert 0 < cost < 100, f"Cost {cost} should be in reasonable range" - - -def test_token_counter_semantics(): - """ - Test that token counters should be incremented by actual token values, not by 1 - """ - # These are correct patterns for token counters - correct_token_increments = [50, 100, 250, 500, 1000, 2000] - - for tokens in correct_token_increments: - # Token counters should be incremented by actual token counts - assert tokens > 1, f"Token increment of {tokens} is reasonable" - - # These would be incorrect for token counters (suggests using request count for tokens) - incorrect_token_increments = [1] # Unless it's actually a 1-token response - - # This test documents the expected behavior - token counters should use token values - - -@pytest.mark.asyncio -async def test_spend_counter_semantics(mock_prometheus_logger): - """ - Test that spend counters are incremented by cost amounts, not by 1 or token counts - """ - from datetime import datetime, timedelta - - kwargs = { - "model": "gpt-5-mini", - "litellm_params": {"metadata": {}}, - "start_time": datetime.now() - timedelta(seconds=1), - "end_time": datetime.now(), - "api_call_start_time": datetime.now() - timedelta(seconds=0.5), - "standard_logging_object": { - "total_tokens": 100, - "prompt_tokens": 60, - "completion_tokens": 40, - "response_cost": 0.0015, # This should be used for spend metrics - "model_group": "gpt-5-mini", - "model_id": "test-model-id", - "api_base": "https://api.openai.com/v1", - "custom_llm_provider": "openai", - "stream": False, - "request_tags": [], - "metadata": { - "user_api_key_user_id": "test-user", - "user_api_key_hash": "test-hash", - "user_api_key_alias": "test-alias", - "user_api_key_team_id": "test-team", - "user_api_key_team_alias": "test-team-alias", - "user_api_key_user_email": "test@example.com", - }, - "hidden_params": { - "additional_headers": {}, - }, - }, - } - - await mock_prometheus_logger.async_log_success_event( - kwargs, None, kwargs["start_time"], kwargs["end_time"] - ) - - # Verify spend counter is incremented by cost amount - spend_metric = mock_prometheus_logger.litellm_spend_metric - assert len(spend_metric.inc_calls) > 0, "Spend metric should be incremented" - assert ( - 0.0015 in spend_metric.inc_calls - ), "Spend metric should be incremented by response_cost (0.0015)" - - -# ============================================================================== -# END SEMANTIC VALIDATION TESTS -# ============================================================================== - - -# ============================================================================== -# CALLBACK FAILURE METRICS TESTS -# ============================================================================== - - -def test_callback_failure_metric_increments(prometheus_logger): - """ - Test that the callback logging failure metric can be incremented. - - This tests the litellm_callback_logging_failures_metric counter. - """ - # Get initial value - initial_value = 0 - try: - initial_value = ( - prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="S3Logger" - )._value.get() - ) - except Exception: - initial_value = 0 - - # Increment the metric - prometheus_logger.increment_callback_logging_failure(callback_name="S3Logger") - - # Verify it incremented by 1 - current_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="S3Logger" - )._value.get() - - assert ( - current_value == initial_value + 1 - ), f"Expected callback failure metric to increment by 1, got {current_value - initial_value}" - - # Increment again for different callback - prometheus_logger.increment_callback_logging_failure(callback_name="LangFuseLogger") - - langfuse_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="LangFuseLogger" - )._value.get() - - assert langfuse_value == 1, "LangFuseLogger metric should be 1" - - # S3Logger should still be initial + 1 - s3_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="S3Logger" - )._value.get() - assert s3_value == initial_value + 1, "S3Logger metric should not change" - - print( - f"✓ Callback failure metric test passed: S3Logger={s3_value}, LangFuseLogger={langfuse_value}" - ) - - -def test_callback_failure_metric_different_callbacks(prometheus_logger): - """ - Test that different callbacks are tracked separately with their own labels. - """ - callbacks_to_test = [ - "S3Logger", - "LangFuseLogger", - "DataDogLogger", - "CustomCallback", - ] - - for callback_name in callbacks_to_test: - # Get initial value - initial = 0 - try: - initial = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name=callback_name - )._value.get() - except Exception: - initial = 0 - - # Increment - prometheus_logger.increment_callback_logging_failure( - callback_name=callback_name - ) - - # Verify incremented - current = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name=callback_name - )._value.get() - - assert current == initial + 1, f"{callback_name} should increment by 1" - - print( - f"✓ Multiple callback tracking test passed for {len(callbacks_to_test)} callbacks" - ) - - -@pytest.mark.asyncio -async def test_langfuse_callback_failure_metric(prometheus_logger): - """ - Test that Langfuse callback failures are properly tracked in Prometheus metrics. - - This test verifies that when Langfuse logging fails, the - litellm_callback_logging_failures_metric is incremented with callback_name="langfuse". - """ - from unittest.mock import MagicMock, patch - - from litellm.integrations.langfuse.langfuse_prompt_management import ( - LangfusePromptManagement, - ) - - # Get initial value - initial_value = 0 - try: - initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="langfuse" - )._value.get() - except Exception: - initial_value = 0 - - # Create Langfuse logger with mocked initialization - with patch("litellm.integrations.langfuse.langfuse_prompt_management.langfuse_client_init"): - langfuse_logger = LangfusePromptManagement() - - # Mock the log_event_on_langfuse to raise an exception - with patch( - "litellm.integrations.langfuse.langfuse_prompt_management.LangFuseHandler.get_langfuse_logger_for_request" - ) as mock_get_logger: - mock_logger = MagicMock() - mock_logger.log_event_on_langfuse.side_effect = Exception("Langfuse API error") - mock_get_logger.return_value = mock_logger - - # Mock handle_callback_failure to track calls - with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment: - # Inject prometheus logger into the langfuse logger - langfuse_logger.handle_callback_failure = lambda callback_name: mock_increment( - callback_name=callback_name - ) - - # Call async_log_success_event - should catch exception and increment metric - await langfuse_logger.async_log_success_event( - kwargs={}, - response_obj={}, - start_time=None, - end_time=None, - ) - - # Verify that increment was called with correct callback name - mock_increment.assert_called_once_with(callback_name="langfuse") - - print("✓ Langfuse callback failure metric test passed") - - -@pytest.mark.asyncio -async def test_langfuse_otel_callback_failure_metric(prometheus_logger): - """ - Test that Langfuse OTEL callback failures are properly tracked in Prometheus metrics. - - This test verifies that when Langfuse OTEL logging fails, the - litellm_callback_logging_failures_metric is incremented with callback_name="langfuse_otel". - """ - from unittest.mock import MagicMock, patch - - from litellm.integrations.langfuse.langfuse_otel import LangfuseOtelLogger - - # Get initial value - initial_value = 0 - try: - initial_value = prometheus_logger.litellm_callback_logging_failures_metric.labels( - callback_name="langfuse_otel" - )._value.get() - except Exception: - initial_value = 0 - - # Create Langfuse OTEL logger with mocked initialization - with patch("litellm.integrations.opentelemetry.OpenTelemetry.__init__", return_value=None): - langfuse_otel_logger = LangfuseOtelLogger(callback_name="langfuse_otel") - langfuse_otel_logger.callback_name = "langfuse_otel" - - # Mock handle_callback_failure to track calls - with patch.object(prometheus_logger, "increment_callback_logging_failure") as mock_increment: - # Inject prometheus logger into the langfuse otel logger - langfuse_otel_logger.handle_callback_failure = lambda callback_name: mock_increment( - callback_name=callback_name - ) - - # Test that the OpenTelemetry base class set_attributes exception handler works - # This is where langfuse_otel failures are caught and tracked - with patch.object(langfuse_otel_logger, "set_attributes") as mock_set_attributes: - # Simulate the exception handling in set_attributes - def set_attributes_with_error(*args, **kwargs): - # This simulates what happens in the real set_attributes method - try: - raise Exception("Attribute error") - except Exception as e: - langfuse_otel_logger.handle_callback_failure(callback_name=langfuse_otel_logger.callback_name) - - mock_set_attributes.side_effect = set_attributes_with_error - - # Call set_attributes - try: - langfuse_otel_logger.set_attributes( - span=MagicMock(), - kwargs={}, - response_obj={} - ) - except Exception: - pass - - # Verify that increment was called with correct callback name - mock_increment.assert_called_with(callback_name="langfuse_otel") - - print("✓ Langfuse OTEL callback failure metric test passed") - - -# ============================================================================== -# END CALLBACK FAILURE METRICS TESTS -# ============================================================================== diff --git a/tests/enterprise/litellm_enterprise/integrations/test_prometheus_unit_tests.py b/tests/enterprise/litellm_enterprise/integrations/test_prometheus_unit_tests.py deleted file mode 100644 index 55c4cbae82..0000000000 --- a/tests/enterprise/litellm_enterprise/integrations/test_prometheus_unit_tests.py +++ /dev/null @@ -1,329 +0,0 @@ -from unittest.mock import patch - -import pytest_asyncio -from prometheus_client import REGISTRY - -try: - from litellm.integrations.prometheus import PrometheusLogger -except Exception: - PrometheusLogger = None - -import asyncio -import sys - -from dotenv import load_dotenv - -load_dotenv() -import os - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system-path -from unittest.mock import MagicMock, patch - -import pytest - -import litellm -from litellm import Router -from litellm.caching.caching import DualCache -from litellm.router_strategy.budget_limiter import RouterBudgetLimiting -from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback -from litellm.types.router import ModelInfo -from litellm.types.utils import BudgetConfig, GenericBudgetConfigType - - -def compare_metrics(func): - def get_metrics(): - metrics = {} - for metric in REGISTRY.collect(): - for sample in metric.samples: - metrics[sample.name] = sample.value - return metrics - - async def wrapper(*args, **kwargs): - initial_metrics = get_metrics() - await func(*args, **kwargs) - await asyncio.sleep(2) - updated_metrics = get_metrics() - - return { - metric: updated_metrics.get(metric, 0) - initial_metrics.get(metric, 0) - for metric in set(initial_metrics) | set(updated_metrics) - } - - return wrapper - - -@pytest.fixture(scope="function") -def prometheus_logger(): - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - with patch("litellm.proxy.proxy_server.premium_user", True): - logger = PrometheusLogger() - - # Add the missing async_logging_hook method - async def async_logging_hook(kwargs, result, call_type): - return kwargs, result - - logger.async_logging_hook = async_logging_hook - return logger - - -@pytest.mark.asyncio -async def test_async_prometheus_success_logging_with_callbacks(prometheus_logger): - litellm.callbacks = [prometheus_logger] - - @compare_metrics - async def op(): - await litellm.acompletion( - model="claude-haiku-4-5-20251001", - messages=[{"role": "user", "content": "what llm are u"}], - max_tokens=10, - mock_response="hi", - temperature=0.2, - ) - - diff = await op() - await asyncio.sleep(2) - assert diff["litellm_requests_metric_total"] == 1.0 - - -@pytest.mark.asyncio -async def test_async_prometheus_budget_logging_with_callbacks(prometheus_logger): - litellm.callbacks = [prometheus_logger] - - @compare_metrics - async def op(): - provider_budget_config: GenericBudgetConfigType = { - "openai": BudgetConfig(time_period="1d", budget_limit=50), - } - - router = litellm.Router( - model_list=[ - { - "model_name": "gpt-5-mini", - "litellm_params": { - "model": "openai/gpt-5-mini", - "api_key": "mock-key", - }, - } - ], - provider_budget_config=provider_budget_config, - ) - - await router.acompletion( - model="gpt-5-mini", - messages=[{"role": "user", "content": "llm?"}], - mock_response="openai", - metadata={ - "user_api_key_team_id": "team-1", - "user_api_key_team_alias": "test-team", - "user_api_key": "test-key", - "user_api_key_alias": "test-key-alias", - }, - ) - - diff = await op() - - # TODO: Should implement `litellm_provider_remaining_budget_metric` in prometheus.py - assert diff.get("litellm_provider_remaining_budget_metric", 50.0) == 50.0 - - -@pytest.mark.asyncio -async def test_prometheus_metric_tracking(): - """ - Test that the Prometheus metric for provider budget is tracked correctly - """ - try: - from unittest.mock import MagicMock - - from litellm.integrations.prometheus import PrometheusLogger - except Exception: - PrometheusLogger = None - if PrometheusLogger is None: - pytest.skip("PrometheusLogger is not installed") - - # Create a mock PrometheusLogger - mock_prometheus = MagicMock(spec=PrometheusLogger) - - # Setup provider budget limiting - provider_budget = RouterBudgetLimiting( - dual_cache=DualCache(), - provider_budget_config={ - "openai": BudgetConfig(budget_duration="1d", max_budget=100) - }, - ) - - litellm._async_success_callback = [mock_prometheus] - - provider_budget_config: GenericBudgetConfigType = { - "openai": BudgetConfig(budget_duration="1d", max_budget=0.000000000001), - "azure": BudgetConfig(budget_duration="1d", max_budget=100), - } - - router = Router( - model_list=[ - { - "model_name": "gpt-5-mini", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/gpt-4.1-mini", - "api_key": os.getenv("AZURE_AI_API_KEY"), - "api_version": os.getenv("AZURE_AI_API_VERSION"), - "api_base": os.getenv("AZURE_AI_API_BASE"), - }, - "model_info": {"id": "azure-model-id"}, - }, - { - "model_name": "gpt-5-mini", # openai model name - "litellm_params": { - "model": "openai/gpt-5-mini", - }, - "model_info": {"id": "openai-model-id"}, - }, - ], - provider_budget_config=provider_budget_config, - redis_host=os.getenv("REDIS_HOST"), - redis_port=int(os.getenv("REDIS_PORT", 6379)), - redis_password=os.getenv("REDIS_PASSWORD"), - ) - - try: - response = await router.acompletion( - messages=[{"role": "user", "content": "Hello, how are you?"}], - model="openai/gpt-5-mini", - mock_response="hi", - ) - print(response) - except Exception as e: - print("error", e) - - await asyncio.sleep(2.5) - - # Verify the mock was called correctly - mock_prometheus.track_provider_remaining_budget.assert_called() - - -class CustomPrometheusLogger(PrometheusLogger): - def __init__(self): - super().__init__() - self.deployment_complete_outages = [] - self.deployment_cooled_downs = [] - - def set_deployment_complete_outage( - self, - litellm_model_name: str, - model_id: str, - api_base: str, - api_provider: str, - ): - self.deployment_complete_outages.append( - [litellm_model_name, model_id, api_base, api_provider] - ) - - def increment_deployment_cooled_down( - self, - litellm_model_name: str, - model_id: str, - api_base: str, - api_provider: str, - exception_status: str, - ): - self.deployment_cooled_downs.append( - [litellm_model_name, model_id, api_base, api_provider, exception_status] - ) - - -@pytest.mark.asyncio -async def test_router_cooldown_event_callback(): - # Clear Prometheus registry to avoid duplicate metric registration - from prometheus_client import REGISTRY - - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - """ - Test the router_cooldown_event_callback function - - Ensures that the router_cooldown_event_callback function correctly logs the cooldown event to the PrometheusLogger - """ - # Mock Router instance - mock_router = MagicMock() - mock_deployment = { - "litellm_params": {"model": "gpt-5-mini"}, - "model_name": "gpt-5-mini", - "model_info": ModelInfo(id="test-model-id"), - } - mock_router.get_deployment.return_value = mock_deployment - - # Create a real PrometheusLogger instance - prometheus_logger = CustomPrometheusLogger() - litellm.callbacks = [prometheus_logger] - - await router_cooldown_event_callback( - litellm_router_instance=mock_router, - deployment_id="test-deployment", - exception_status="429", - cooldown_time=60.0, - ) - - await asyncio.sleep(0.5) - - # Assert that the router's get_deployment method was called - mock_router.get_deployment.assert_called_once_with(model_id="test-deployment") - - print( - "prometheus_logger.deployment_complete_outages", - prometheus_logger.deployment_complete_outages, - ) - print( - "prometheus_logger.deployment_cooled_downs", - prometheus_logger.deployment_cooled_downs, - ) - - # Assert that PrometheusLogger methods were called - assert len(prometheus_logger.deployment_complete_outages) == 1 - assert len(prometheus_logger.deployment_cooled_downs) == 1 - - assert prometheus_logger.deployment_complete_outages[0] == [ - "gpt-5-mini", - "test-model-id", - "https://api.openai.com", - "openai", - ] - assert prometheus_logger.deployment_cooled_downs[0] == [ - "gpt-5-mini", - "test-model-id", - "https://api.openai.com", - "openai", - "429", - ] - - -@pytest.mark.asyncio -async def test_router_cooldown_event_callback_no_prometheus(): - """ - Test the router_cooldown_event_callback function - - Ensures that the router_cooldown_event_callback function does not raise an error when no PrometheusLogger is found - """ - # Mock Router instance - mock_router = MagicMock() - mock_deployment = { - "litellm_params": {"model": "gpt-5-mini"}, - "model_name": "gpt-5-mini", - "model_info": ModelInfo(id="test-model-id"), - } - mock_router.get_deployment.return_value = mock_deployment - - await router_cooldown_event_callback( - litellm_router_instance=mock_router, - deployment_id="test-deployment", - exception_status="429", - cooldown_time=60.0, - ) - - # Assert that the router's get_deployment method was called - mock_router.get_deployment.assert_called_once_with(model_id="test-deployment") diff --git a/tests/enterprise/litellm_enterprise/proxy/auth/test_route_checks.py b/tests/enterprise/litellm_enterprise/proxy/auth/test_route_checks.py deleted file mode 100644 index c147c7aae9..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/auth/test_route_checks.py +++ /dev/null @@ -1,375 +0,0 @@ -import os -import sys -from unittest.mock import MagicMock, patch - -sys.path.insert( - 0, os.path.abspath("../../../..") -) # Adds the parent directory to the system path - -import pytest -from fastapi import HTTPException - -# Import the enterprise route checks -from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks - - -@patch("litellm.proxy.proxy_server.premium_user", True) -class TestEnterpriseRouteChecks: - - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_management_disabled( - self, mock_is_management_route, mock_is_management_disabled - ): - """Test that should_call_route raises HTTPException when management routes are disabled and route is a management route""" - - # Mock the methods to return True (route is management route and management is disabled) - mock_is_management_route.return_value = True - mock_is_management_disabled.return_value = True - - # Test that calling should_call_route raises HTTPException with 403 status - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route("/config/update") - - # Verify the exception has correct status and message - assert exc_info.value.status_code == 403 - assert "Management routes are disabled for this instance." in str( - exc_info.value.detail - ) - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_llm_api_disabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that should_call_route raises HTTPException when LLM API routes are disabled and route is an LLM API route""" - - # Mock the methods - not a management route but is an LLM API route that's disabled - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = True - - # Test that calling should_call_route raises HTTPException with 403 status - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route("/v1/chat/completions") - - # Verify the exception has correct status and message - assert exc_info.value.status_code == 403 - assert "LLM API routes are disabled for this instance." in str( - exc_info.value.detail - ) - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_both_disabled_management_takes_priority( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that management route check takes priority when both are disabled""" - - # Mock the methods - route is both management and LLM API, and both are disabled - mock_is_management_route.return_value = True - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = True - mock_is_llm_api_disabled.return_value = True - - # Test that calling should_call_route raises HTTPException with management route message - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route("/config/update") - - # Verify the exception has correct status and management route message (not LLM API message) - assert exc_info.value.status_code == 403 - assert "Management routes are disabled for this instance." in str( - exc_info.value.detail - ) - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_enabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that should_call_route succeeds when routes are enabled""" - - # Test case 1: Management route enabled - mock_is_management_route.return_value = True - mock_is_llm_api_route.return_value = False - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = False - - # Should not raise exception - EnterpriseRouteChecks.should_call_route("/config/update") - - # Test case 2: LLM API route enabled - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = False - - # Should not raise exception - EnterpriseRouteChecks.should_call_route("/v1/chat/completions") - - # Test case 3: Neither management nor LLM API route - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = False - mock_is_management_disabled.return_value = ( - True # These can be True since route doesn't match - ) - mock_is_llm_api_disabled.return_value = True - - # Should not raise exception - EnterpriseRouteChecks.should_call_route("/health") - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_management_disabled_llm_enabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that LLM API routes work when only management routes are disabled""" - - # Mock the methods - LLM API route but management disabled, LLM API enabled - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = True - mock_is_llm_api_disabled.return_value = False - - # Should not raise exception since LLM API routes are enabled - EnterpriseRouteChecks.should_call_route("/v1/chat/completions") - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_should_call_route_llm_disabled_management_enabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that management routes work when only LLM API routes are disabled""" - - # Mock the methods - management route but LLM API disabled, management enabled - mock_is_management_route.return_value = True - mock_is_llm_api_route.return_value = False - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = True - - # Should not raise exception since management routes are enabled - EnterpriseRouteChecks.should_call_route("/config/update") - - -@patch("litellm.proxy.proxy_server.premium_user", True) -class TestEnterpriseRouteChecksModelListExemption: - """Test that /models and /v1/models are exempt from DISABLE_LLM_API_ENDPOINTS""" - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_models_route_allowed_when_llm_api_disabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that /models is allowed even when LLM API routes are disabled""" - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = True - - # Should not raise exception for /models - EnterpriseRouteChecks.should_call_route("/models") - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_v1_models_route_allowed_when_llm_api_disabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that /v1/models is allowed even when LLM API routes are disabled""" - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = True - - # Should not raise exception for /v1/models - EnterpriseRouteChecks.should_call_route("/v1/models") - - @patch.object(EnterpriseRouteChecks, "is_llm_api_route_disabled") - @patch.object(EnterpriseRouteChecks, "is_management_routes_disabled") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_llm_api_route") - @patch("litellm.proxy.auth.route_checks.RouteChecks.is_management_route") - def test_chat_completions_still_blocked_when_llm_api_disabled( - self, - mock_is_management_route, - mock_is_llm_api_route, - mock_is_management_disabled, - mock_is_llm_api_disabled, - ): - """Test that non-exempt LLM routes like /v1/chat/completions are still blocked""" - mock_is_management_route.return_value = False - mock_is_llm_api_route.return_value = True - mock_is_management_disabled.return_value = False - mock_is_llm_api_disabled.return_value = True - - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route("/v1/chat/completions") - - assert exc_info.value.status_code == 403 - assert "LLM API routes are disabled for this instance." in str( - exc_info.value.detail - ) - - -@patch("litellm.proxy.proxy_server.premium_user", True) -class TestEnterpriseRouteChecksMcpManagement: - """Regression tests: MCP management routes (/v1/mcp/server*) must remain - reachable when DISABLE_LLM_API_ENDPOINTS is set on admin nodes, but must be - blocked when DISABLE_ADMIN_ENDPOINTS is set. Uses the real is_llm_api_route - / is_management_route classifiers (not mocks).""" - - @pytest.mark.parametrize( - "route", - [ - "/v1/mcp/server", - "/v1/mcp/server/abc-123", - "/v1/mcp/server/abc-123/approve", - ], - ) - def test_mcp_management_allowed_when_llm_api_disabled(self, route): - with patch.dict(os.environ, {"DISABLE_LLM_API_ENDPOINTS": "true"}, clear=False): - os.environ.pop("DISABLE_ADMIN_ENDPOINTS", None) - # Should not raise — MCP management is a management route, not llm_api. - EnterpriseRouteChecks.should_call_route(route) - - @pytest.mark.parametrize( - "route", - [ - "/v1/mcp/server", - "/v1/mcp/server/abc-123", - ], - ) - def test_mcp_management_blocked_when_admin_disabled(self, route): - with patch.dict(os.environ, {"DISABLE_ADMIN_ENDPOINTS": "true"}, clear=False): - os.environ.pop("DISABLE_LLM_API_ENDPOINTS", None) - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route(route) - - assert exc_info.value.status_code == 403 - assert "Management routes are disabled for this instance." in str( - exc_info.value.detail - ) - - @pytest.mark.parametrize( - "route", - [ - "/mcp/tools/call", - "/mcp-rest/tools/call", - ], - ) - def test_mcp_inference_still_blocked_when_llm_api_disabled(self, route): - with patch.dict(os.environ, {"DISABLE_LLM_API_ENDPOINTS": "true"}, clear=False): - os.environ.pop("DISABLE_ADMIN_ENDPOINTS", None) - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.should_call_route(route) - - assert exc_info.value.status_code == 403 - assert "LLM API routes are disabled for this instance." in str( - exc_info.value.detail - ) - - -class TestEnterpriseRouteChecksErrorMessages: - """Test that error messages correctly identify which feature requires Enterprise license""" - - @patch("litellm.secret_managers.main.get_secret_bool") - @patch("litellm.proxy.proxy_server.premium_user", False) - def test_disable_llm_api_endpoints_error_message(self, mock_get_secret_bool): - """ - Test that when DISABLE_LLM_API_ENDPOINTS is set without Enterprise license, - the error message correctly mentions 'LLM API ENDPOINTS' - """ - with patch.dict(os.environ, {"DISABLE_LLM_API_ENDPOINTS": "true"}): - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.is_llm_api_route_disabled() - - assert exc_info.value.status_code == 500 - assert "DISABLING LLM API ENDPOINTS is an Enterprise feature" in str( - exc_info.value.detail - ) - - @patch("litellm.secret_managers.main.get_secret_bool") - @patch("litellm.proxy.proxy_server.premium_user", False) - def test_disable_admin_endpoints_error_message(self, mock_get_secret_bool): - """ - Test that when DISABLE_ADMIN_ENDPOINTS is set without Enterprise license, - the error message correctly mentions 'ADMIN ENDPOINTS' (not 'LLM API ENDPOINTS') - - This is a regression test for a bug where the error message incorrectly said - 'DISABLING LLM API ENDPOINTS' when the actual issue was DISABLE_ADMIN_ENDPOINTS. - """ - with patch.dict(os.environ, {"DISABLE_ADMIN_ENDPOINTS": "true"}): - with pytest.raises(HTTPException) as exc_info: - EnterpriseRouteChecks.is_management_routes_disabled() - - assert exc_info.value.status_code == 500 - assert "DISABLING ADMIN ENDPOINTS is an Enterprise feature" in str( - exc_info.value.detail - ) - # Ensure it does NOT mention LLM API ENDPOINTS (the old buggy message) - assert "LLM API ENDPOINTS" not in str(exc_info.value.detail) - - @patch("litellm.secret_managers.main.get_secret_bool") - @patch("litellm.proxy.proxy_server.premium_user", True) - def test_disable_llm_api_endpoints_with_premium_user(self, mock_get_secret_bool): - """ - Test that premium users can use DISABLE_LLM_API_ENDPOINTS without error - """ - mock_get_secret_bool.return_value = True - with patch.dict(os.environ, {"DISABLE_LLM_API_ENDPOINTS": "true"}): - # Should not raise exception for premium users - result = EnterpriseRouteChecks.is_llm_api_route_disabled() - assert result is True - - @patch("litellm.secret_managers.main.get_secret_bool") - @patch("litellm.proxy.proxy_server.premium_user", True) - def test_disable_admin_endpoints_with_premium_user(self, mock_get_secret_bool): - """ - Test that premium users can use DISABLE_ADMIN_ENDPOINTS without error - """ - mock_get_secret_bool.return_value = True - with patch.dict(os.environ, {"DISABLE_ADMIN_ENDPOINTS": "true"}): - # Should not raise exception for premium users - result = EnterpriseRouteChecks.is_management_routes_disabled() - assert result is True diff --git a/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py b/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py deleted file mode 100644 index a45df5df00..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/auth/test_user_api_key_auth.py +++ /dev/null @@ -1,90 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi import Request -from litellm_enterprise.proxy.auth.user_api_key_auth import enterprise_custom_auth - - -@pytest.mark.asyncio -async def test_enterprise_custom_auth_none_user_auth(): - # Test when user_custom_auth is None - request = MagicMock(spec=Request) - result = await enterprise_custom_auth(request, "test-api-key", None) - assert result is None - - -@pytest.mark.asyncio -async def test_enterprise_custom_auth_mode_on(): - # Test when mode is "on" - mock_user_auth = AsyncMock(return_value={"user_id": "test-user"}) - request = MagicMock(spec=Request) - - with patch( - "litellm_enterprise.proxy.proxy_server.custom_auth_settings", {"mode": "on"} - ): - result = await enterprise_custom_auth(request, "test-api-key", mock_user_auth) - assert result == {"user_id": "test-user"} - mock_user_auth.assert_called_once_with(request, "test-api-key") - - -@pytest.mark.asyncio -async def test_enterprise_custom_auth_mode_auto_with_error(): - # Test when mode is "auto" and user_auth raises an exception - mock_user_auth = AsyncMock(side_effect=Exception("Auth failed")) - request = MagicMock(spec=Request) - - with patch( - "litellm_enterprise.proxy.proxy_server.custom_auth_settings", {"mode": "auto"} - ): - result = await enterprise_custom_auth(request, "test-api-key", mock_user_auth) - assert result is None - mock_user_auth.assert_called_once_with(request, "test-api-key") - - -@pytest.mark.asyncio -async def test_enterprise_custom_auth_returns_string(): - from litellm.proxy._types import hash_token - - # Test when enterprise_custom_auth returns a string (LiteLLM virtual key) - mock_user_auth = AsyncMock(return_value="sk-test-key") - request = MagicMock(spec=Request) - - with patch( - "litellm.proxy.auth.user_api_key_auth.enterprise_custom_auth", mock_user_auth - ), patch("litellm.proxy.proxy_server.master_key", "sk-1234"), patch( - "litellm.proxy.proxy_server.prisma_client", MagicMock() - ): - # Verify the key is correctly handled in _user_api_key_auth_builder - with patch( - "litellm.proxy.auth.user_api_key_auth.get_key_object" - ) as mock_get_key_object: - mock_get_key_object.return_value = MagicMock( - token="sk-test-key", - user_role="internal_user", - team_id=None, - user_id="test-user", - ) - - # Call _user_api_key_auth_builder with the returned key - from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder - - try: - auth_obj = await _user_api_key_auth_builder( - request=request, - api_key="my-custom-key", - azure_api_key_header="", - anthropic_api_key_header=None, - google_ai_studio_api_key_header=None, - azure_apim_header=None, - request_data={}, - custom_litellm_key_header=None, - ) - except Exception as e: - print("error:", e) - - # Verify get_key_object was called with the correct key - mock_get_key_object.assert_called_once() - # The key should be hashed before being passed to get_key_object - assert mock_get_key_object.call_args[1]["hashed_token"] == hash_token( - "sk-test-key" - ) diff --git a/tests/enterprise/litellm_enterprise/proxy/guardrails/conftest.py b/tests/enterprise/litellm_enterprise/proxy/guardrails/conftest.py deleted file mode 100644 index 4dd5c3d88c..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/guardrails/conftest.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Shared fixtures for guardrail apply_guardrail tests.""" - -from contextlib import contextmanager -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - - -@contextmanager -def _mock_proxy_logging(): - """Patch the proxy-server globals that apply_guardrail imports at call time.""" - mock_proxy_logging = MagicMock() - mock_proxy_logging.post_call_success_hook = AsyncMock(return_value=None) - mock_proxy_logging.post_call_failure_hook = AsyncMock(return_value=None) - mock_logging_obj = MagicMock() - mock_logging_obj.async_success_handler = AsyncMock(return_value=None) - mock_logging_obj.async_failure_handler = AsyncMock(return_value=None) - mock_logging_obj.success_handler = MagicMock(return_value=None) - mock_logging_obj.failure_handler = MagicMock(return_value=None) - mock_logging_obj.model_call_details = {} - - with ( - patch( - "litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing" - ) as mock_proc_cls, - patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging), - patch("litellm.proxy.proxy_server.general_settings", {}), - patch("litellm.proxy.proxy_server.proxy_config", MagicMock()), - patch("litellm.proxy.proxy_server.version", "0.0.0"), - ): - mock_proc = MagicMock() - mock_proc.common_processing_pre_call_logic = AsyncMock( - return_value=({}, mock_logging_obj) - ) - mock_proc_cls.return_value = mock_proc - yield mock_proxy_logging - - -@pytest.fixture -def mock_proxy_logging_ctx(): - """Return the proxy-logging context manager factory for use as `with ctx():`.""" - return _mock_proxy_logging diff --git a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py deleted file mode 100644 index e5074c4421..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_apply_guardrail_endpoint.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Test the /guardrails/apply_guardrail endpoint -""" - -import os -import sys -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -sys.path.insert(0, os.path.abspath("../../../../..")) - -from fastapi import HTTPException - -from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.proxy._types import UserAPIKeyAuth -from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse - - -@pytest.mark.asyncio -async def test_apply_guardrail_endpoint_returns_correct_response( - mock_proxy_logging_ctx, -): - """Test that apply_guardrail endpoint returns ApplyGuardrailResponse object""" - from litellm.proxy.guardrails.guardrail_endpoints import apply_guardrail - - # Mock the guardrail registry - with ( - patch( - "litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY" - ) as mock_registry, - mock_proxy_logging_ctx(), - ): - # Create a mock guardrail - mock_guardrail = Mock(spec=CustomGuardrail) - # Apply guardrail returns GenericGuardrailAPIInputs (dict with texts key) - mock_guardrail.apply_guardrail = AsyncMock( - return_value={"texts": ["Redacted text: [REDACTED] and [REDACTED]"]} - ) - - # Configure the registry to return our mock guardrail - mock_registry.get_initialized_guardrail_callback.return_value = mock_guardrail - - # Create the request - request = ApplyGuardrailRequest( - guardrail_name="test-guardrail", - text="Test text with PII", - language="en", - entities=["EMAIL_ADDRESS", "PERSON"], - ) - - # Create a mock user API key - user_api_key_dict = UserAPIKeyAuth(api_key="test-key") - - # Call the endpoint - response = await apply_guardrail( - fastapi_request=Mock(), - request=request, - user_api_key_dict=user_api_key_dict, - ) - - # Verify the response is of the correct type - assert isinstance(response, ApplyGuardrailResponse) - assert response.response_text == "Redacted text: [REDACTED] and [REDACTED]" - - # Verify the guardrail was called with correct parameters - mock_guardrail.apply_guardrail.assert_called_once_with( - inputs={"texts": ["Test text with PII"]}, - request_data={}, - input_type="request", - ) - - -@pytest.mark.asyncio -async def test_apply_guardrail_endpoint_guardrail_not_found(mock_proxy_logging_ctx): - """Test that apply_guardrail endpoint raises exception when guardrail not found""" - from litellm.proxy._types import ProxyException - from litellm.proxy.guardrails.guardrail_endpoints import apply_guardrail - - # Mock the guardrail registry to return None - with ( - patch( - "litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY" - ) as mock_registry, - mock_proxy_logging_ctx(), - ): - mock_registry.get_initialized_guardrail_callback.return_value = None - - # Create the request - request = ApplyGuardrailRequest( - guardrail_name="non-existent-guardrail", text="Test text", language="en" - ) - - # Create a mock user API key - user_api_key_dict = UserAPIKeyAuth(api_key="test-key") - - # Verify exception is raised - with pytest.raises(ProxyException) as exc_info: - await apply_guardrail( - fastapi_request=Mock(), - request=request, - user_api_key_dict=user_api_key_dict, - ) - - assert "non-existent-guardrail" in exc_info.value.message - assert "not found" in exc_info.value.message - - -@pytest.mark.asyncio -async def test_apply_guardrail_endpoint_with_presidio_guardrail(mock_proxy_logging_ctx): - """Test apply_guardrail endpoint with a Presidio-like guardrail""" - from litellm.proxy.guardrails.guardrail_endpoints import apply_guardrail - - # Mock the guardrail registry - with ( - patch( - "litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY" - ) as mock_registry, - mock_proxy_logging_ctx(), - ): - # Create a mock guardrail that simulates Presidio behavior - mock_guardrail = Mock(spec=CustomGuardrail) - # Simulate masking PII entities - returns GenericGuardrailAPIInputs (dict with texts key) - mock_guardrail.apply_guardrail = AsyncMock( - return_value={ - "texts": ["My name is [PERSON] and my email is [EMAIL_ADDRESS]"] - } - ) - - # Configure the registry to return our mock guardrail - mock_registry.get_initialized_guardrail_callback.return_value = mock_guardrail - - # Create the request - request = ApplyGuardrailRequest( - guardrail_name="pii-detection-guard", - text="My name is John Doe and my email is john@example.com", - language="en", - entities=["EMAIL_ADDRESS", "PERSON"], - ) - - # Create a mock user API key - user_api_key_dict = UserAPIKeyAuth(api_key="test-key") - - # Call the endpoint - response = await apply_guardrail( - fastapi_request=Mock(), - request=request, - user_api_key_dict=user_api_key_dict, - ) - - # Verify the response is of the correct type - assert isinstance(response, ApplyGuardrailResponse) - assert ( - response.response_text - == "My name is [PERSON] and my email is [EMAIL_ADDRESS]" - ) - assert "john@example.com" not in response.response_text - assert "John Doe" not in response.response_text - - -@pytest.mark.asyncio -async def test_apply_guardrail_endpoint_without_optional_params(mock_proxy_logging_ctx): - """Test apply_guardrail endpoint without optional language and entities parameters""" - from litellm.proxy.guardrails.guardrail_endpoints import apply_guardrail - - # Mock the guardrail registry - with ( - patch( - "litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY" - ) as mock_registry, - mock_proxy_logging_ctx(), - ): - # Create a mock guardrail - mock_guardrail = Mock(spec=CustomGuardrail) - # Returns GenericGuardrailAPIInputs (dict with texts key) - mock_guardrail.apply_guardrail = AsyncMock( - return_value={"texts": ["Processed text"]} - ) - - # Configure the registry to return our mock guardrail - mock_registry.get_initialized_guardrail_callback.return_value = mock_guardrail - - # Create the request without optional parameters - request = ApplyGuardrailRequest( - guardrail_name="test-guardrail", text="Test text" - ) - - # Create a mock user API key - user_api_key_dict = UserAPIKeyAuth(api_key="test-key") - - # Call the endpoint - response = await apply_guardrail( - fastapi_request=Mock(), - request=request, - user_api_key_dict=user_api_key_dict, - ) - - # Verify the response is of the correct type - assert isinstance(response, ApplyGuardrailResponse) - assert response.response_text == "Processed text" - - # Verify the guardrail was called with correct parameters - mock_guardrail.apply_guardrail.assert_called_once_with( - inputs={"texts": ["Test text"]}, request_data={}, input_type="request" - ) diff --git a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py deleted file mode 100644 index d1caf39854..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -Test the Bedrock guardrail apply_guardrail functionality -""" - -import os -import sys -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -sys.path.insert(0, os.path.abspath("../../../../..")) - - -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import BedrockGuardrail -from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_success(): - """Test that Bedrock guardrail apply_guardrail method works correctly""" - # Create a BedrockGuardrail instance - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - ) - - # Mock the make_bedrock_api_request method - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api_request: - # Mock a successful response from Bedrock - mock_response = { - "action": "ALLOWED", - "output": [{"text": "This is a test message with some content"}], - } - mock_api_request.return_value = mock_response - - # Test the apply_guardrail method with new signature - guardrailed_inputs = await guardrail.apply_guardrail( - inputs={"texts": ["This is a test message with some content"]}, - request_data={}, - input_type="request", - ) - result = guardrailed_inputs.get("texts", []) - - # Verify the result - assert result == ["This is a test message with some content"] - mock_api_request.assert_called_once() - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_blocked(): - """Test that apply_guardrail lets HTTPException propagate as-is for blocked content. - - Regression test for issue #20045: when disable_exception_on_block=False (default), - make_bedrock_api_request raises HTTPException for BLOCKED content. apply_guardrail - must NOT wrap it in a generic Exception, otherwise the proxy loses the HTTP 400 - status and fails to block the LLM call. - """ - from fastapi import HTTPException - - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - ) - - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api_request: - mock_api_request.side_effect = HTTPException( - status_code=400, - detail={ - "error": "Violated guardrail policy", - "bedrock_guardrail_response": "Content blocked", - }, - ) - - # Test the apply_guardrail method propagates HTTPException (AWS error) to the client - with pytest.raises(HTTPException) as exc_info: - await guardrail.apply_guardrail( - inputs={"texts": ["This is blocked content"]}, - request_data={}, - input_type="request", - ) - - assert exc_info.value.status_code == 400 - assert "Violated guardrail policy" in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_with_masking(): - """Test that Bedrock guardrail apply_guardrail method handles content masking""" - # Create a BedrockGuardrail instance - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - ) - - # Mock the make_bedrock_api_request method - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api_request: - # Mock a response with masked content - mock_response = { - "action": "ALLOWED", - "outputs": [{"text": "This is a test message with [REDACTED] content"}], - } - mock_api_request.return_value = mock_response - - # Test the apply_guardrail method with new signature - guardrailed_inputs = await guardrail.apply_guardrail( - inputs={"texts": ["This is a test message with sensitive content"]}, - request_data={}, - input_type="request", - ) - result = guardrailed_inputs.get("texts", []) - - # Verify the result contains the masked content - assert result == ["This is a test message with [REDACTED] content"] - mock_api_request.assert_called_once() - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_api_failure(): - """Test that Bedrock guardrail apply_guardrail method handles API failures""" - # Create a BedrockGuardrail instance - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - ) - - # Mock the make_bedrock_api_request method to raise an exception - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api_request: - mock_api_request.side_effect = Exception("API connection failed") - - # Test the apply_guardrail method should raise an exception - with pytest.raises(Exception) as exc_info: - await guardrail.apply_guardrail( - inputs={"texts": ["This is a test message"]}, - request_data={}, - input_type="request", - ) - - # The error message should contain the original exception - assert "API connection failed" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_endpoint_integration(mock_proxy_logging_ctx): - """Test the full endpoint integration with Bedrock guardrail""" - from litellm.proxy.guardrails.guardrail_endpoints import apply_guardrail - - # Create a real BedrockGuardrail instance - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - ) - - # Mock the guardrail registry - with ( - patch( - "litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY" - ) as mock_registry, - mock_proxy_logging_ctx(), - ): - # Mock the make_bedrock_api_request method - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api_request: - # Mock a successful response from Bedrock - mock_response = { - "action": "ALLOWED", - "outputs": [{"text": "This is a test message with processed content"}], - } - mock_api_request.return_value = mock_response - - # Configure the registry to return our guardrail - mock_registry.get_initialized_guardrail_callback.return_value = guardrail - - # Create the request - request = ApplyGuardrailRequest( - guardrail_name="test-bedrock-guard", - text="This is a test message with some content", - language="en", - ) - - # Create a mock user API key - user_api_key_dict = UserAPIKeyAuth(api_key="test-key") - - # Call the endpoint - response = await apply_guardrail( - fastapi_request=Mock(), - request=request, - user_api_key_dict=user_api_key_dict, - ) - - # Verify the response - assert isinstance(response, ApplyGuardrailResponse) - assert ( - response.response_text - == "This is a test message with processed content" - ) - # Note: The endpoint now calls apply_guardrail which internally calls make_bedrock_api_request - # The call count check has been removed as it may be called multiple times through the chain - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled(): - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - experimental_use_latest_role_message_only=True, - ) - - request_messages = [ - {"role": "system", "content": "rules"}, - {"role": "user", "content": "first question"}, - {"role": "assistant", "content": "response"}, - {"role": "user", "content": "latest question"}, - ] - - request_data = {"messages": request_messages} - - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api: - mock_api.return_value = { - "action": "ALLOWED", - "output": [{"text": "latest question"}], - } - - guardrailed_inputs = await guardrail.apply_guardrail( - inputs={"texts": ["latest question"]}, - request_data=request_data, - input_type="request", - ) - result = guardrailed_inputs.get("texts", []) - - assert mock_api.called - _, kwargs = mock_api.call_args - assert kwargs["messages"] == [request_messages[-1]] - assert result == ["latest question"] - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled_blocked(): - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - experimental_use_latest_role_message_only=True, - ) - - request_messages = [ - {"role": "user", "content": "first"}, - {"role": "user", "content": "blocked"}, - ] - - request_data = {"messages": request_messages} - - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api: - # Mock the method to raise an HTTPException as it would for blocked content - from fastapi import HTTPException - - mock_api.side_effect = HTTPException( - status_code=400, - detail={ - "error": "Violated guardrail policy", - "bedrock_guardrail_response": "policy", - }, - ) - - with pytest.raises(HTTPException, match="policy") as exc_info: - await guardrail.apply_guardrail( - inputs={"texts": ["blocked"]}, - request_data=request_data, - input_type="request", - ) - - assert mock_api.called - _, kwargs = mock_api.call_args - assert kwargs["messages"] == [request_messages[-1]] - # HTTPException from guardrail is propagated so the client gets the AWS message - assert exc_info.value.status_code == 400 - assert "policy" in str(exc_info.value.detail) - - -def test_bedrock_guardrail_filters_latest_user_message_when_enabled(): - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - experimental_use_latest_role_message_only=True, - ) - - messages = [ - {"role": "system", "content": "rules"}, - {"role": "user", "content": "first question"}, - {"role": "assistant", "content": "response"}, - {"role": "user", "content": "latest question"}, - ] - - filter_result = guardrail._prepare_guardrail_messages_for_role(messages=messages) - - assert filter_result.payload_messages is not None - assert len(filter_result.payload_messages) == 1 - assert filter_result.payload_messages[0]["content"] == "latest question" - assert filter_result.target_indices == [3] - - masked_messages = guardrail._merge_filtered_messages( - original_messages=filter_result.original_messages, - updated_target_messages=[{"role": "user", "content": "[MASKED]"}], - target_indices=filter_result.target_indices, - ) - assert masked_messages[3]["content"] == "[MASKED]" - - -@pytest.mark.asyncio -async def test_bedrock_apply_guardrail_blocked_with_disable_exception_on_block(): - """ - Regression test for issue #20045: when disable_exception_on_block=True, - make_bedrock_api_request raises GuardrailInterventionNormalStringError. - apply_guardrail must let it propagate as-is so the proxy can handle it - properly instead of wrapping it in a generic Exception. - """ - from litellm.exceptions import GuardrailInterventionNormalStringError - - guardrail = BedrockGuardrail( - guardrail_name="test-bedrock-guard", - guardrailIdentifier="test-guard-id", - guardrailVersion="DRAFT", - disable_exception_on_block=True, - ) - - with patch.object( - guardrail, "make_bedrock_api_request", new_callable=AsyncMock - ) as mock_api: - mock_api.side_effect = GuardrailInterventionNormalStringError( - message="Sorry, your question in its current format is unable to be answered." - ) - - with pytest.raises(GuardrailInterventionNormalStringError) as exc_info: - await guardrail.apply_guardrail( - inputs={"texts": ["harmful prompt content"]}, - request_data={}, - input_type="request", - ) - - assert "unable to be answered" in str(exc_info.value.message) diff --git a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py b/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py deleted file mode 100644 index 7b82d1eabd..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py +++ /dev/null @@ -1,2307 +0,0 @@ -import base64 -import json -from typing import cast -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi import HTTPException -from litellm_enterprise.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles - -from litellm.caching import DualCache -from litellm.proxy._types import CallTypes -from litellm.proxy.openai_files_endpoints.common_utils import ( - _is_base64_encoded_unified_file_id, -) - - -def test_get_file_ids_from_messages(): - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is in this recording?"}, - { - "type": "file", - "file": { - "file_id": "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCxmYzdmMmVhNS0wZjUwLTQ5ZjYtODljMS03ZTZhNTRiMTIxMzg", - }, - }, - ], - }, - ] - file_ids = proxy_managed_files.get_file_ids_from_messages(messages) - assert file_ids == [ - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCxmYzdmMmVhNS0wZjUwLTQ5ZjYtODljMS03ZTZhNTRiMTIxMzg" - ] - - -def test_get_file_ids_from_messages_skips_bedrock_content_blocks_without_type(): - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - messages = [ - { - "role": "user", - "content": [ - {"text": "What is Apptio?"}, - { - "toolResult": { - "toolUseId": "tooluse_123", - "status": "success", - "content": [ - { - "searchResult": { - "source": "source", - "title": "title", - "content": [{"text": "snippet"}], - "citations": {"enabled": True}, - } - } - ], - } - }, - {"type": "file", "file": {"file_id": "file-keep"}}, - ], - } - ] - file_ids = proxy_managed_files.get_file_ids_from_messages(messages) - assert file_ids == ["file-keep"] - - -@pytest.mark.asyncio -async def test_async_pre_call_hook_batch_retrieve(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - return_value = MagicMock() - return_value.created_by = "123" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = return_value - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - data = { - "user_api_key_dict": UserAPIKeyAuth( - user_id="123", parent_otel_span=MagicMock() - ), - "data": { - "batch_id": "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1nZW5lcmFsLWF6dXJlLWRlcGxveW1lbnQ7bGxtX2JhdGNoX2lkOmJhdGNoX2EzMjJiNmJhLWFjN2UtNDg4OC05MjljLTFhZDM0NDJmMDZlZA", - }, - "call_type": "aretrieve_batch", - "cache": MagicMock(), - } - response = await proxy_managed_files.async_pre_call_hook(**data) - assert response["batch_id"] == "batch_a322b6ba-ac7e-4888-929c-1ad3442f06ed" - assert response["model"] == "my-general-azure-deployment" - - -@pytest.mark.asyncio -async def test_async_pre_call_deployment_hook_resolves_model_id_from_litellm_metadata(): - """ - For batch operations the router stores model_info under - kwargs["litellm_metadata"]["model_info"] (not top-level kwargs["model_info"]). - async_pre_call_deployment_hook must check both locations so the managed - file ID is resolved to the provider-specific file ID. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - managed_file_id = "managed-file-abc" - model_id = "deployment-xyz" - provider_file_id = "gs://bucket/path/to/file.jsonl" - - # model_info is nested under litellm_metadata (batch path) - kwargs = { - "input_file_id": managed_file_id, - "model_file_id_mapping": { - managed_file_id: {model_id: provider_file_id}, - }, - "litellm_metadata": { - "model_info": {"id": model_id}, - }, - } - - result = await proxy_managed_files.async_pre_call_deployment_hook( - kwargs=kwargs, call_type=CallTypes.acreate_batch - ) - - assert ( - result["input_file_id"] == provider_file_id - ), f"Expected provider file ID '{provider_file_id}', got '{result['input_file_id']}'" - - -@pytest.mark.asyncio -async def test_async_pre_call_deployment_hook_prefers_top_level_model_info(): - """ - When model_info exists at top-level kwargs, async_pre_call_deployment_hook - should use it without falling back to litellm_metadata. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - managed_file_id = "managed-file-abc" - top_level_model_id = "deployment-top" - nested_model_id = "deployment-nested" - top_level_provider_file = "file-top-123" - nested_provider_file = "file-nested-456" - - kwargs = { - "input_file_id": managed_file_id, - "model_file_id_mapping": { - managed_file_id: { - top_level_model_id: top_level_provider_file, - nested_model_id: nested_provider_file, - }, - }, - "model_info": {"id": top_level_model_id}, - "litellm_metadata": { - "model_info": {"id": nested_model_id}, - }, - } - - result = await proxy_managed_files.async_pre_call_deployment_hook( - kwargs=kwargs, call_type=CallTypes.acreate_batch - ) - - assert ( - result["input_file_id"] == top_level_provider_file - ), "Should prefer top-level model_info over litellm_metadata" - - -@pytest.mark.asyncio -async def test_async_pre_call_deployment_hook_no_model_info_leaves_file_id_unchanged(): - """ - When model_info is absent from both top-level and litellm_metadata, - the managed file ID should remain unchanged. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - managed_file_id = "managed-file-abc" - - kwargs = { - "input_file_id": managed_file_id, - "model_file_id_mapping": { - managed_file_id: {"some-model": "provider-file-xyz"}, - }, - } - - result = await proxy_managed_files.async_pre_call_deployment_hook( - kwargs=kwargs, call_type=CallTypes.acreate_batch - ) - - assert ( - result["input_file_id"] == managed_file_id - ), "File ID should remain unchanged when model_info is not available" - - -# def test_list_managed_files(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Create some test files -# file1 = proxy_managed_files.create_file( -# file=("test1.txt", b"test content 1", "text/plain"), -# purpose="assistants" -# ) -# file2 = proxy_managed_files.create_file( -# file=("test2.pdf", b"test content 2", "application/pdf"), -# purpose="assistants" -# ) - -# # List all files -# files = proxy_managed_files.list_files() - -# # Verify response -# assert len(files) == 2 -# assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files) -# assert any(f.filename == "test1.txt" for f in files) -# assert any(f.filename == "test2.pdf" for f in files) -# assert all(f.purpose == "assistants" for f in files) - -# def test_retrieve_managed_file(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Create a test file -# test_content = b"test content for retrieve" -# created_file = proxy_managed_files.create_file( -# file=("test.txt", test_content, "text/plain"), -# purpose="assistants" -# ) - -# # Retrieve the file -# retrieved_file = proxy_managed_files.retrieve_file(created_file.id) - -# # Verify response -# assert retrieved_file.id == created_file.id -# assert retrieved_file.filename == "test.txt" -# assert retrieved_file.purpose == "assistants" -# assert retrieved_file.bytes == len(test_content) -# assert retrieved_file.status == "uploaded" - -# def test_delete_managed_file(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Create a test file -# created_file = proxy_managed_files.create_file( -# file=("test.txt", b"test content", "text/plain"), -# purpose="assistants" -# ) - -# # Delete the file -# deleted_file = proxy_managed_files.delete_file(created_file.id) - -# # Verify deletion -# assert deleted_file.id == created_file.id -# assert deleted_file.deleted == True - -# # Verify file is no longer retrievable -# with pytest.raises(Exception): -# proxy_managed_files.retrieve_file(created_file.id) - -# # Verify file is not in list -# files = proxy_managed_files.list_files() -# assert created_file.id not in [f.id for f in files] - -# def test_retrieve_nonexistent_file(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Try to retrieve a non-existent file -# with pytest.raises(Exception): -# proxy_managed_files.retrieve_file("nonexistent-file-id") - -# def test_delete_nonexistent_file(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Try to delete a non-existent file -# with pytest.raises(Exception): -# proxy_managed_files.delete_file("nonexistent-file-id") - -# def test_list_files_with_purpose_filter(): -# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache()) - -# # Create files with different purposes -# file1 = proxy_managed_files.create_file( -# file=("test1.txt", b"test content 1", "text/plain"), -# purpose="assistants" -# ) -# file2 = proxy_managed_files.create_file( -# file=("test2.pdf", b"test content 2", "application/pdf"), -# purpose="batch" -# ) - -# # List files with purpose filter -# assistant_files = proxy_managed_files.list_files(purpose="assistants") -# batch_files = proxy_managed_files.list_files(purpose="batch") - -# # Verify filtering -# assert len(assistant_files) == 1 -# assert len(batch_files) == 1 -# assert assistant_files[0].id == file1.id -# assert batch_files[0].id == file2.id - - -@pytest.mark.asyncio -async def test_async_post_call_success_hook_for_unified_finetuning_job(): - from litellm.types.utils import LiteLLMFineTuningJob - - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9vY3RldC1zdHJlYW07dW5pZmllZF9pZCxiZTQ0ZDVlYi1mNDU3LTRiNzktOWM4My01N2QxMTMxYWM0YzY7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00LjEtb3BlbmFpO2xsbV9vdXRwdXRfZmlsZV9pZCxmaWxlLURKMnQ0OWZlQ2NTQk5vNG9oekZ6NGc7bGxtX291dHB1dF9maWxlX21vZGVsX2lkLGRiNjY5ODcwNzdkZTdmYzZjNzAzY2Y1MDczMGU2MmNkOWQ3YTU1N2NlNjVmMDUzNTFkYTM4YTA3ZjBlZDEyNzQ" - provider_ft_job = LiteLLMFineTuningJob( - object="fine_tuning.job", - id="ftjob-0kEBV5b4sPrFcMnuzmYSzU1G", - model="gpt-3.5-turbo-0613", - created_at=1692779769, - finished_at=None, - fine_tuned_model=None, - organization_id="org-dUVLhaAQ37YCGwVC2QVY8sdB", - result_files=[], - status="validating_files", - validation_file=None, - training_file="file-azQuKMLAmiFdEjxpCcbI11zF", - hyperparameters={"n_epochs": 8}, - trained_tokens=None, - seed=0, - ) - provider_ft_job._hidden_params = { - "unified_file_id": unified_file_id, - "model_id": "gpt-3.5-turbo-0613", - } - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=AsyncMock() - ) - data = { - "user_api_key_dict": {"parent_otel_span": MagicMock()}, - } - - response = await proxy_managed_files.async_post_call_success_hook( - data=data, - user_api_key_dict=MagicMock(), - response=provider_ft_job, - ) - - assert isinstance(response, LiteLLMFineTuningJob) - assert _is_base64_encoded_unified_file_id(response.id) - - -@pytest.mark.asyncio -async def test_async_pre_call_hook_for_unified_finetuning_job(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - return_value = MagicMock() - return_value.created_by = "123" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = return_value - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - data = { - "user_api_key_dict": UserAPIKeyAuth( - user_id="123", parent_otel_span=MagicMock() - ), - "data": { - "fine_tuning_job_id": "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDo0OTIxODU4MWY3OGViZTllZjE4NDE0ZmE0ZjdmYjlmYTc0YzA5NWVkMTEyY2E4NDBkZDU2ZGZmZTliZDMwZGQxO2dlbmVyaWNfcmVzcG9uc2VfaWQ6ZnRqb2ItalRCeXM3YlZzYnlaRE93TDlHbHBZcVhS", - }, - "call_type": "acancel_fine_tuning_job", - "cache": MagicMock(), - } - - response = await proxy_managed_files.async_pre_call_hook(**data) - assert response["fine_tuning_job_id"] == "ftjob-jTBys7bVsbyZDOwL9GlpYqXR" - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "call_type", ["afile_content", "afile_delete", "afile_retrieve"] -) -async def test_can_user_call_unified_file_id(call_type): - """ - Test that on file retrieve, delete, and content we check if the user has access to the file - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - return_value = MagicMock() - return_value.created_by = "123" - prisma_client.db.litellm_managedfiletable.find_first.return_value = return_value - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - MagicMock(), prisma_client=prisma_client - ) - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9vY3RldC1zdHJlYW07dW5pZmllZF9pZCxmMTNlNDAzZS01YWM3LTRhZjktOGQzNS0wNDgwZDMxOTgyYTg7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00by1taW5pLW9wZW5haTtsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1Ib3UxZDFXc3c1SDNKcjFMYllpZDJiO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxmODBiNWU2NzQ1NzdkNjkyMjM4YmVhNTIxZDdiMGI5ZGYyY2FmMTEwMTU2YmU5YzBjM2NjMmNkNTBjOTM1ZDI0" - - with pytest.raises(HTTPException): - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="456", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"file_id": unified_file_id}, - call_type=call_type, - ) - - -@pytest.mark.asyncio -async def test_router_acreate_batch_only_selects_from_file_id_mapping(monkeypatch): - """ - Test that router.acreate_batch only selects model_id from the file_id_mapping - """ - import litellm - - prisma_client = AsyncMock() - return_value = MagicMock() - return_value.created_by = "123" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = return_value - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - monkeypatch.setattr( - litellm, - "callbacks", - [proxy_managed_files], - ) - - router = litellm.Router( - model_list=[ - { - "model_name": "gpt-5-mini", - "litellm_params": {"model": "gpt-5-mini"}, - "model_info": {"id": "1234"}, - }, - { - "model_name": "gpt-5-mini", - "litellm_params": {"model": "gpt-5-mini"}, - "model_info": {"id": "5678"}, - }, - ], - ) - - file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9vY3RldC1zdHJlYW07dW5pZmllZF9pZCw2YmQ4ZjhhYS02NmEzLTRmY2MtOTIxZS1lMTYwYzIzZWZjNzU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1MTENVRkI1MnVUTWE5aE5ZanRldzlWO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxmMzJlNWQ0OC05YWZmLTQ5YjMtOWE1Ny0zYzJhN2JjN2NjMmE" - - model_file_id_mapping = {file_id: {"5678": "file-LLCUFB52uTMa9hNYjtew9V"}} - - with patch.object( - litellm, "acreate_batch", return_value=AsyncMock() - ) as mock_acreate_batch: - for _ in range(1000): - await router.acreate_batch( - model="gpt-5-mini", - input_file_id=file_id, - model_file_id_mapping=model_file_id_mapping, - ) - - mock_acreate_batch.assert_called() - assert "5678" in json.dumps(mock_acreate_batch.call_args.kwargs) - - -@pytest.mark.asyncio -async def test_output_file_id_for_batch_retrieve(): - """ - Test that the output file id is the same as the input file id - """ - from typing import cast - - from openai.types.batch import BatchRequestCounts - - from litellm.types.utils import LiteLLMBatch - - batch = LiteLLMBatch( - id="bGl0ZWxsbV9wcm94eTttb2RlbF9pZDoxMjM0NTY3OTtsbG1fYmF0Y2hfaWQ6YmF0Y2hfNjg1YzVlNWQ2Mzk4ODE5MGI4NWJkYjIxNDdiYTEzMWQ", - completion_window="24h", - created_at=1750883933, - endpoint="/v1/chat/completions", - input_file_id="file-8ci8gux8s7oES7GydYvnMG", - object="batch", - status="completed", - cancelled_at=None, - cancelling_at=None, - completed_at=1750883939, - error_file_id=None, - errors=None, - expired_at=None, - expires_at=1750970333, - failed_at=None, - finalizing_at=1750883938, - in_progress_at=1750883934, - metadata={"description": "nightly eval job"}, - output_file_id="file-3BZYhmdJQ3V2oZPAtQsEax", - request_counts=BatchRequestCounts(completed=1, failed=0, total=1), - usage=None, - ) - - batch._hidden_params = { - "litellm_call_id": "dcd789e0-c0ad-4244-9564-4e611448d650", - "api_base": "https://api.openai.com", - "model_id": "12345679", - "response_cost": 0.0, - "additional_headers": {}, - "litellm_model_name": "gpt-5.5", - "unified_batch_id": "litellm_proxy;model_id:12345679;llm_batch_id:batch_685c5e5d63988190b85bdb2147ba131d", - } - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=AsyncMock() - ) - - response = await proxy_managed_files.async_post_call_success_hook( - data={}, - user_api_key_dict=MagicMock(), - response=batch, - ) - - assert not cast(LiteLLMBatch, response).output_file_id.startswith("file-") - - -@pytest.mark.asyncio -async def test_output_file_id_preserves_target_model_names_when_model_name_missing(): - """ - Regression test: when provider response does not include _hidden_params.model_name - (e.g. Vertex batch retrieve), unified output_file_id should still include - target_model_names from the managed input file ID. - """ - from openai.types.batch import BatchRequestCounts - - from litellm.proxy._types import UserAPIKeyAuth - from litellm.types.llms.openai import OpenAIFileObject - from litellm.types.utils import LiteLLMBatch - - batch = LiteLLMBatch( - id="batch_123", - completion_window="24h", - created_at=1750883933, - endpoint="/v1/chat/completions", - input_file_id="file-input-provider-id", - object="batch", - status="completed", - output_file_id="file-provider-output-id", - request_counts=BatchRequestCounts(completed=1, failed=0, total=1), - usage=None, - ) - - # Build a valid managed input id string and base64 encode it. - managed_input_file_payload = ( - "litellm_proxy:application/octet-stream;" - "unified_id,test-uuid;" - "target_model_names,gemini-2.5-pro;" - "llm_output_file_id,file-input-1;" - "llm_output_file_model_id,model-id-1" - ) - managed_input_file_id = ( - base64.urlsafe_b64encode(managed_input_file_payload.encode()) - .decode() - .rstrip("=") - ) - - batch._hidden_params = { - "model_id": "model-id-1", - "unified_batch_id": "litellm_proxy;model_id:model-id-1;llm_batch_id:batch_123", - "unified_file_id": managed_input_file_id, - # Intentionally omit model_name to mimic Vertex issue. - } - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=AsyncMock() - ) - - provider_output_file = OpenAIFileObject( - id="file-provider-output-id", - object="file", - bytes=1, - created_at=1, - filename="predictions.jsonl", - purpose="batch_output", - ) - - with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_retrieve: - mock_retrieve.return_value = provider_output_file - response = await proxy_managed_files.async_post_call_success_hook( - data={}, - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - response=batch, - ) - - decoded_output_file_id = _is_base64_encoded_unified_file_id( - cast(LiteLLMBatch, response).output_file_id - ) - assert decoded_output_file_id - assert "target_model_names,gemini-2.5-pro" in cast(str, decoded_output_file_id) - - -@pytest.mark.asyncio -async def test_error_file_id_for_failed_batch(): - """ - Test that the error_file_id is properly managed when a batch fails - """ - from typing import cast - - from openai.types.batch import BatchRequestCounts - - from litellm.proxy._types import UserAPIKeyAuth - from litellm.types.llms.openai import OpenAIFileObject - from litellm.types.utils import LiteLLMBatch - - batch = LiteLLMBatch( - id="bGl0ZWxsbV9wcm94eTttb2RlbF9pZDoxMjM0NTY3OTtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz", - completion_window="24h", - created_at=1714508499, - endpoint="/v1/chat/completions", - input_file_id="file-abc123", - object="batch", - status="failed", - cancelled_at=None, - cancelling_at=None, - completed_at=None, - error_file_id="error-abc123", - errors=None, - expired_at=None, - expires_at=1714536634, - failed_at=None, - finalizing_at=None, - in_progress_at=None, - metadata=None, - output_file_id=None, - request_counts=BatchRequestCounts(completed=0, failed=0, total=0), - usage=None, - ) - - batch._hidden_params = { - "litellm_call_id": "test-call-id", - "api_base": "https://api.openai.com", - "model_id": "test-model-id", - "model_name": "gpt-5.5", - "response_cost": 0.0, - "additional_headers": {}, - "litellm_model_name": "gpt-5.5", - "unified_batch_id": "litellm_proxy;model_id:test-model-id;llm_batch_id:batch_abc123", - } - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=AsyncMock() - ) - - # Create a proper OpenAIFileObject for the error file - error_file_object = OpenAIFileObject( - id="error-abc123", - object="file", - bytes=1234, - created_at=1714508500, - filename="error.jsonl", - purpose="batch_output", - status="processed", - ) - - # Mock the afile_retrieve to simulate retrieving error file metadata - with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_retrieve: - mock_retrieve.return_value = error_file_object - - user_api_key_dict = UserAPIKeyAuth( - user_id="test-user-123", parent_otel_span=MagicMock() - ) - - response = await proxy_managed_files.async_post_call_success_hook( - data={}, - user_api_key_dict=user_api_key_dict, - response=batch, - ) - - # Verify that error_file_id was transformed to a managed file ID - assert cast(LiteLLMBatch, response).error_file_id is not None - assert not cast(LiteLLMBatch, response).error_file_id.startswith("error-") - # Verify it's a base64 encoded managed file ID - assert _is_base64_encoded_unified_file_id( - cast(LiteLLMBatch, response).error_file_id - ) - - -@pytest.mark.asyncio -async def test_async_post_call_success_hook_twice_assert_no_unique_violation(): - import asyncio - - from openai.types.batch import BatchRequestCounts - - from litellm.proxy._types import UserAPIKeyAuth - from litellm.types.utils import LiteLLMBatch - - # Use AsyncMock instead of real database connection - prisma_client = AsyncMock() - - batch = LiteLLMBatch( - id="bGl0ZWxsbV9wcm94eTttb2RlbF9pZDoxMjM0NTY3OTtsbG1fYmF0Y2hfaWQ6YmF0Y2hfNjg1YzVlNWQ2Mzk4ODE5MGI4NWJkYjIxNDdiYTEzMWQ", - completion_window="24h", - created_at=1750883933, - endpoint="/v1/chat/completions", - input_file_id="file-8ci8gux8s7oES7GydYvnMG", - object="batch", - status="completed", - metadata={"description": "nightly eval job"}, - request_counts=BatchRequestCounts(completed=1, failed=0, total=1), - usage=None, - ) - - batch._hidden_params = { - "model_id": "12345679", - "response_cost": 0.0, - "litellm_model_name": "gpt-5.5", - "unified_batch_id": "litellm_proxy;model_id:12345679;llm_batch_id:batch_685c5e5d63988190b85bdb2147ba131d", - } - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # first retrieve batch - tasks = [] - first_create_task = asyncio.create_task - with patch("asyncio.create_task") as mock_create_task: - mock_create_task.side_effect = ( - lambda coro: tasks.append(first_create_task(coro)) or tasks[-1] - ) - - response = await proxy_managed_files.async_post_call_success_hook( - data={}, - user_api_key_dict=UserAPIKeyAuth(user_id="default_id"), - response=batch.copy(), - ) - - if tasks: - # make sure asyncio(db create) is finished - await asyncio.sleep(0.02) - await asyncio.gather(*tasks, return_exceptions=True) - for task in tasks: - assert task.exception() is None, f"Error: {task.exception()}" - - assert isinstance(response, LiteLLMBatch) - assert _is_base64_encoded_unified_file_id(response.id) - - # second retrieve batch - tasks = [] - second_create_task = asyncio.create_task - with patch("asyncio.create_task") as mock_create_task: - mock_create_task.side_effect = ( - lambda coro: tasks.append(second_create_task(coro)) or tasks[-1] - ) - - await proxy_managed_files.async_post_call_success_hook( - data={}, - user_api_key_dict=UserAPIKeyAuth(user_id="default_id"), - response=batch.copy(), - ) - - if tasks: - await asyncio.sleep(0.01) - await asyncio.gather(*tasks, return_exceptions=True) - for task in tasks: - assert task.exception() is None, f"Error: {task.exception()}" - - -def test_update_responses_input_with_unified_file_id(): - """ - Test that update_responses_input_with_model_file_ids correctly decodes - unified file IDs and extracts llm_output_file_id from responses API input. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_input_with_model_file_ids, - ) - - # Create a base64-encoded unified file ID - # This decodes to: litellm_proxy:application/pdf;unified_id,6c0b5890-8914-48e0-b8f4-0ae5ed3c14a5;target_model_names,gpt-4o;llm_output_file_id,file-ECBPW7ML9g7XHdwGgUPZaM;llm_output_file_model_id,e26453f9e76e7993680d0068d98c1f4cc205bbad0967a33c664893568ca743c2 - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - - # Test input with unified file ID in content array - input_data = [ - { - "role": "user", - "content": [ - { - "type": "input_file", - "file_id": unified_file_id, - }, - { - "type": "input_text", - "text": "What is the first dragon in the book?", - }, - ], - } - ] - - # Update the input - updated_input = update_responses_input_with_model_file_ids(input=input_data) - - # Verify the file_id was updated to the provider-specific file ID - assert updated_input[0]["content"][0]["type"] == "input_file" - assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM" - assert updated_input[0]["content"][1]["type"] == "input_text" - assert ( - updated_input[0]["content"][1]["text"] - == "What is the first dragon in the book?" - ) - - -def test_update_responses_input_with_regular_file_id(): - """ - Test that update_responses_input_with_model_file_ids keeps regular - OpenAI file IDs unchanged. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_input_with_model_file_ids, - ) - - # Regular OpenAI file ID (not a unified file ID) - regular_file_id = "file-abc123xyz" - - input_data = [ - { - "role": "user", - "content": [ - { - "type": "input_file", - "file_id": regular_file_id, - }, - { - "type": "input_text", - "text": "What is this file?", - }, - ], - } - ] - - # Update the input - updated_input = update_responses_input_with_model_file_ids(input=input_data) - - # Verify the file_id was kept unchanged (regular OpenAI file ID) - assert updated_input[0]["content"][0]["type"] == "input_file" - assert updated_input[0]["content"][0]["file_id"] == regular_file_id - assert updated_input[0]["content"][1]["type"] == "input_text" - - -def test_update_responses_input_with_string_input(): - """ - Test that update_responses_input_with_model_file_ids returns string input unchanged. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_input_with_model_file_ids, - ) - - input_data = "What is AI?" - - updated_input = update_responses_input_with_model_file_ids(input=input_data) - - assert updated_input == input_data - assert isinstance(updated_input, str) - - -def test_update_responses_input_with_multiple_file_ids(): - """ - Test that update_responses_input_with_model_file_ids handles multiple file IDs - (both unified and regular) in the same input. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_input_with_model_file_ids, - ) - - # Unified file ID - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - # Regular OpenAI file ID - regular_file_id = "file-regular123" - - input_data = [ - { - "role": "user", - "content": [ - { - "type": "input_file", - "file_id": unified_file_id, - }, - { - "type": "input_text", - "text": "Compare these files", - }, - { - "type": "input_file", - "file_id": regular_file_id, - }, - ], - } - ] - - updated_input = update_responses_input_with_model_file_ids(input=input_data) - - # Verify unified file ID was updated - assert updated_input[0]["content"][0]["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM" - # Verify regular file ID was kept unchanged - assert updated_input[0]["content"][2]["file_id"] == regular_file_id - # Verify text content was preserved - assert updated_input[0]["content"][1]["text"] == "Compare these files" - - -def test_update_responses_input_with_model_file_id_mapping(): - """ - Test that update_responses_input_with_model_file_ids correctly uses - model_file_id_mapping to map managed file IDs to provider-specific file IDs. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_input_with_model_file_ids, - ) - - # Managed file ID (unified) - managed_file_id = "litellm_proxy_file_123" - - # Model file ID mapping - model_file_id_mapping = { - managed_file_id: { - "model_id_1": "openai_file_abc", - "model_id_2": "azure_file_xyz", - } - } - - input_data = [ - { - "role": "user", - "content": [ - { - "type": "input_file", - "file_id": managed_file_id, - }, - { - "type": "input_text", - "text": "Analyze this file", - }, - ], - } - ] - - # Update input with model_id_1 mapping - updated_input = update_responses_input_with_model_file_ids( - input=input_data, - model_id="model_id_1", - model_file_id_mapping=model_file_id_mapping, - ) - - # Verify the file_id was mapped to the correct provider-specific file ID - assert updated_input[0]["content"][0]["file_id"] == "openai_file_abc" - - # Test with different model_id - updated_input_2 = update_responses_input_with_model_file_ids( - input=input_data, - model_id="model_id_2", - model_file_id_mapping=model_file_id_mapping, - ) - - assert updated_input_2[0]["content"][0]["file_id"] == "azure_file_xyz" - - -def test_update_responses_tools_with_model_file_id_mapping(): - """ - Test that update_responses_tools_with_model_file_ids correctly maps - file IDs in code_interpreter tools with container.file_ids. - - This is a regression test for the issue where managed file IDs in - tools.container.file_ids were not being replaced with provider-specific - file IDs, causing "string too long" errors from OpenAI. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_tools_with_model_file_ids, - ) - - # Managed file IDs - managed_file_id_1 = "litellm_proxy_file_123" - managed_file_id_2 = "litellm_proxy_file_456" - - # Model file ID mapping - model_file_id_mapping = { - managed_file_id_1: { - "model_id_1": "openai_file_abc", - }, - managed_file_id_2: { - "model_id_1": "openai_file_def", - }, - } - - tools = [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": [managed_file_id_1, managed_file_id_2], - }, - } - ] - - # Update tools with model mapping - updated_tools = update_responses_tools_with_model_file_ids( - tools=tools, - model_id="model_id_1", - model_file_id_mapping=model_file_id_mapping, - ) - - # Verify the file IDs were mapped to provider-specific file IDs - assert updated_tools[0]["type"] == "code_interpreter" - assert updated_tools[0]["container"]["file_ids"] == [ - "openai_file_abc", - "openai_file_def", - ] - - -def test_update_responses_tools_without_mapping(): - """ - Test that update_responses_tools_with_model_file_ids keeps file IDs - unchanged when no mapping is provided. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_tools_with_model_file_ids, - ) - - regular_file_id = "file-abc123" - - tools = [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": [regular_file_id], - }, - } - ] - - # Update tools without mapping - updated_tools = update_responses_tools_with_model_file_ids( - tools=tools, - model_id=None, - model_file_id_mapping=None, - ) - - # Verify the file ID was kept unchanged - assert updated_tools[0]["container"]["file_ids"] == [regular_file_id] - - -def test_update_responses_tools_with_mixed_file_ids(): - """ - Test that update_responses_tools_with_model_file_ids correctly handles - a mix of managed and regular file IDs. - """ - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - update_responses_tools_with_model_file_ids, - ) - - managed_file_id = "litellm_proxy_file_123" - regular_file_id = "file-abc123" - - model_file_id_mapping = { - managed_file_id: { - "model_id_1": "openai_file_abc", - }, - } - - tools = [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": [managed_file_id, regular_file_id], - }, - } - ] - - # Update tools - updated_tools = update_responses_tools_with_model_file_ids( - tools=tools, - model_id="model_id_1", - model_file_id_mapping=model_file_id_mapping, - ) - - # Verify managed file ID was mapped and regular file ID was kept - assert updated_tools[0]["container"]["file_ids"] == [ - "openai_file_abc", - regular_file_id, - ] - - -def test_get_file_ids_from_responses_tools(): - """ - Test that get_file_ids_from_responses_tools correctly extracts - file IDs from the tools parameter. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - tools = [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": ["file-123", "file-456"], - }, - } - ] - - file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) - - assert file_ids == ["file-123", "file-456"] - - -def test_get_file_ids_from_responses_tools_multiple_tools(): - """ - Test that get_file_ids_from_responses_tools handles multiple tools. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - tools = [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": ["file-123"], - }, - }, - { - "type": "file_search", - }, - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": ["file-456", "file-789"], - }, - }, - ] - - file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) - - # Should extract file IDs only from code_interpreter tools - assert file_ids == ["file-123", "file-456", "file-789"] - - -def test_get_file_ids_from_responses_tools_empty(): - """ - Test that get_file_ids_from_responses_tools handles empty or None tools. - """ - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=MagicMock() - ) - - # Test with None - file_ids = proxy_managed_files.get_file_ids_from_responses_tools(None) - assert file_ids == [] - - # Test with empty list - file_ids = proxy_managed_files.get_file_ids_from_responses_tools([]) - assert file_ids == [] - - # Test with tools without file_ids - tools = [{"type": "file_search"}] - file_ids = proxy_managed_files.get_file_ids_from_responses_tools(tools) - assert file_ids == [] - - -@pytest.mark.asyncio -async def test_check_file_ids_access_with_unified_file_ids(): - """ - Test that check_file_ids_access validates user access to managed file IDs. - """ - from litellm.proxy._types import UserAPIKeyAuth - - # Create a unified file ID - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - regular_file_id = "file-abc123" - - # Mock the access check to return True - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock can_user_call_unified_file_id to return True - proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - - user_api_key_dict = UserAPIKeyAuth( - user_id="test_user_123", - parent_otel_span=MagicMock(), - ) - - # Should not raise an exception for accessible files - await proxy_managed_files.check_file_ids_access( - [unified_file_id, regular_file_id], - user_api_key_dict, - ) - - # Verify can_user_call_unified_file_id was called for the unified file ID - proxy_managed_files.can_user_call_unified_file_id.assert_called_once_with( - unified_file_id, user_api_key_dict - ) - - -@pytest.mark.asyncio -async def test_check_file_ids_access_denied(): - """ - Test that check_file_ids_access raises HTTPException when user doesn't have access. - """ - from litellm.proxy._types import UserAPIKeyAuth - - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock can_user_call_unified_file_id to return False (access denied) - proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=False) - - user_api_key_dict = UserAPIKeyAuth( - user_id="test_user_123", - parent_otel_span=MagicMock(), - ) - - # Should raise HTTPException with 403 status code - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.check_file_ids_access( - [unified_file_id], - user_api_key_dict, - ) - - assert exc_info.value.status_code == 403 - assert "does not have access to the file" in exc_info.value.detail - - -@pytest.mark.asyncio -async def test_check_file_ids_access_with_regular_files_only(): - """ - Test that check_file_ids_access doesn't check access for regular (non-unified) file IDs. - """ - from litellm.proxy._types import UserAPIKeyAuth - - regular_file_id_1 = "file-abc123" - regular_file_id_2 = "file-xyz789" - - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock can_user_call_unified_file_id (should not be called for regular files) - proxy_managed_files.can_user_call_unified_file_id = AsyncMock() - - user_api_key_dict = UserAPIKeyAuth( - user_id="test_user_123", - parent_otel_span=MagicMock(), - ) - - # Should not raise exception and should not call can_user_call_unified_file_id - await proxy_managed_files.check_file_ids_access( - [regular_file_id_1, regular_file_id_2], - user_api_key_dict, - ) - - # Verify can_user_call_unified_file_id was NOT called - proxy_managed_files.can_user_call_unified_file_id.assert_not_called() - - -@pytest.mark.asyncio -async def test_completion_with_file_access_check(): - """ - Test that completion call type checks file access before processing. - """ - from litellm.proxy._types import UserAPIKeyAuth - - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - - prisma_client = AsyncMock() - prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) - - internal_usage_cache = MagicMock() - internal_usage_cache.async_get_cache = AsyncMock(return_value=None) - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock the get_model_file_id_mapping to return empty dict - proxy_managed_files.get_model_file_id_mapping = AsyncMock(return_value={}) - - # Mock access check to allow access - proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - - user_api_key_dict = UserAPIKeyAuth( - user_id="test_user_123", - parent_otel_span=MagicMock(), - ) - - data = { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this file?"}, - { - "type": "file", - "file": {"file_id": unified_file_id}, - }, - ], - } - ], - "model": "gpt-5.5", - } - - # Should not raise exception - result = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=user_api_key_dict, - cache=DualCache(), - data=data, - call_type="acompletion", - ) - - # Verify access check was called - proxy_managed_files.can_user_call_unified_file_id.assert_called_once() - - -@pytest.mark.asyncio -async def test_responses_with_file_access_check(): - """ - Test that responses API checks file access for files in both input and tools. - """ - from litellm.proxy._types import UserAPIKeyAuth - - unified_file_id_1 = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" - unified_file_id_2 = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsNzc3Nzc3Nzc7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1YWVo7bGxtX291dHB1dF9maWxlX21vZGVsX2lkLG1vZGVsXzEyMw" - - prisma_client = AsyncMock() - prisma_client.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) - - internal_usage_cache = MagicMock() - internal_usage_cache.async_get_cache = AsyncMock(return_value=None) - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock the get_model_file_id_mapping to return empty dict - proxy_managed_files.get_model_file_id_mapping = AsyncMock(return_value={}) - - # Mock access check to allow access - proxy_managed_files.can_user_call_unified_file_id = AsyncMock(return_value=True) - - user_api_key_dict = UserAPIKeyAuth( - user_id="test_user_123", - parent_otel_span=MagicMock(), - ) - - data = { - "input": [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "Analyze this"}, - {"type": "input_file", "file_id": unified_file_id_1}, - ], - } - ], - "tools": [ - { - "type": "code_interpreter", - "container": { - "type": "auto", - "file_ids": [unified_file_id_2], - }, - } - ], - "model": "gpt-5.5", - } - - # Should not raise exception - result = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=user_api_key_dict, - cache=DualCache(), - data=data, - call_type="aresponses", - ) - - # Verify access check was called for both file IDs - assert proxy_managed_files.can_user_call_unified_file_id.call_count == 2 - - -@pytest.mark.asyncio -async def test_store_unified_file_id_with_none_file_object(): - """ - Test that store_unified_file_id works when file_object is None - (e.g., for batch output files that are stored before file metadata is available). - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - prisma_client.db.litellm_managedfiletable.create = AsyncMock( - return_value=MagicMock() - ) - internal_usage_cache = MagicMock() - internal_usage_cache.async_set_cache = AsyncMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Store with file_object=None (simulating batch output file storage) - await proxy_managed_files.store_unified_file_id( - file_id="test-unified-file-id", - file_object=None, - litellm_parent_otel_span=None, - model_mappings={"model-123": "file-provider-xyz"}, - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - ) - - # Verify DB create was called with expected data (without file_object) - prisma_client.db.litellm_managedfiletable.create.assert_called_once() - call_args = prisma_client.db.litellm_managedfiletable.create.call_args - assert call_args.kwargs["data"]["unified_file_id"] == "test-unified-file-id" - assert "file_object" not in call_args.kwargs["data"] - - -@pytest.mark.asyncio -async def test_afile_delete_returns_provider_response_when_stored_file_object_none(): - """ - Test that afile_delete returns the provider's delete response when the - stored file_object is None (e.g., for batch output files). - """ - from litellm.types.llms.openai import OpenAIFileObject - - unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsdGVzdC1pZDt0YXJnZXRfbW9kZWxfbmFtZXMsZ3B0LTRvO2xsbV9vdXRwdXRfZmlsZV9pZCxmaWxlLXByb3ZpZGVyLXh5ejtsbG1fb3V0cHV0X2ZpbGVfbW9kZWxfaWQsbW9kZWwtMTIz" - - prisma_client = AsyncMock() - db_record = MagicMock() - db_record.model_mappings = '{"model-123": "file-provider-xyz"}' - prisma_client.db.litellm_managedfiletable.find_first = AsyncMock( - return_value=db_record - ) - prisma_client.db.litellm_managedfiletable.delete = AsyncMock() - - internal_usage_cache = MagicMock() - internal_usage_cache.async_get_cache = AsyncMock( - return_value={ - "unified_file_id": unified_file_id, - "model_mappings": {"model-123": "file-provider-xyz"}, - "flat_model_file_ids": ["file-provider-xyz"], - "file_object": None, - "created_by": "test-user", - "updated_by": "test-user", - } - ) - internal_usage_cache.async_set_cache = AsyncMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock the delete_unified_file_id to return None (simulating file_object=None) - proxy_managed_files.delete_unified_file_id = AsyncMock(return_value=None) - - # Mock router response - provider_delete_response = OpenAIFileObject( - id="file-provider-xyz", - object="file", - bytes=1234, - created_at=1234567890, - filename="test.jsonl", - purpose="batch", - ) - - mock_router = MagicMock() - mock_router.afile_delete = AsyncMock(return_value=provider_delete_response) - - result = await proxy_managed_files.afile_delete( - file_id=unified_file_id, - litellm_parent_otel_span=None, - llm_router=mock_router, - ) - - # Should return the provider response with the unified file ID - assert result is not None - assert result.id == unified_file_id - - -@pytest.mark.asyncio -async def test_afile_retrieve_fetches_from_provider_when_file_object_none(): - """ - Test that afile_retrieve fetches from the provider when the stored - file_object is None (e.g., for batch output files). - """ - from litellm.types.llms.openai import OpenAIFileObject - - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock get_unified_file_id to return a stored object with file_object=None - stored_file = MagicMock() - stored_file.file_object = None - stored_file.model_mappings = {"model-123": "file-provider-xyz"} - proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - - # Mock the router and provider response - provider_file_response = OpenAIFileObject( - id="file-provider-xyz", - object="file", - bytes=5678, - created_at=1234567890, - filename="output.jsonl", - purpose="batch_output", - ) - - mock_router = MagicMock() - mock_router.get_deployment_credentials_with_provider = MagicMock( - return_value={ - "api_key": "test-key", - "api_base": "https://api.openai.com", - } - ) - - with patch("litellm.afile_retrieve", new_callable=AsyncMock) as mock_afile_retrieve: - mock_afile_retrieve.return_value = provider_file_response - - unified_file_id = "test-unified-file-id" - result = await proxy_managed_files.afile_retrieve( - file_id=unified_file_id, - litellm_parent_otel_span=None, - llm_router=mock_router, - ) - - # Should return the provider response with the unified file ID - assert result is not None - assert result.id == unified_file_id - mock_afile_retrieve.assert_called_once() - - -@pytest.mark.asyncio -async def test_afile_retrieve_raises_error_when_no_router_and_file_object_none(): - """ - Test that afile_retrieve raises an appropriate error when file_object is None - and no llm_router is provided to fetch from the provider. - """ - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock get_unified_file_id to return a stored object with file_object=None - stored_file = MagicMock() - stored_file.file_object = None - stored_file.model_mappings = {"model-123": "file-provider-xyz"} - proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - - unified_file_id = "test-unified-file-id" - - with pytest.raises(Exception) as exc_info: - await proxy_managed_files.afile_retrieve( - file_id=unified_file_id, - litellm_parent_otel_span=None, - llm_router=None, - ) - - assert "llm_router is required" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_afile_retrieve_returns_stored_file_object_when_exists(): - """ - Test that afile_retrieve returns the stored file_object directly when it exists - (the normal case for user-uploaded files). - """ - from litellm.types.llms.openai import OpenAIFileObject - - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock get_unified_file_id to return a stored object WITH file_object - stored_file_object = OpenAIFileObject( - id="test-unified-file-id", - object="file", - bytes=1234, - created_at=1234567890, - filename="input.jsonl", - purpose="batch", - ) - stored_file = MagicMock() - stored_file.file_object = stored_file_object - proxy_managed_files.get_unified_file_id = AsyncMock(return_value=stored_file) - - result = await proxy_managed_files.afile_retrieve( - file_id="test-unified-file-id", - litellm_parent_otel_span=None, - llm_router=None, - ) - - # Should return the stored file object directly - assert result == stored_file_object - - -@pytest.mark.asyncio -async def test_afile_retrieve_raises_error_for_non_managed_file(): - """ - Test that afile_retrieve raises an error when the file_id is not found - in the managed files table. - """ - prisma_client = AsyncMock() - internal_usage_cache = MagicMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - internal_usage_cache=internal_usage_cache, - prisma_client=prisma_client, - ) - - # Mock get_unified_file_id to return None (file not found) - proxy_managed_files.get_unified_file_id = AsyncMock(return_value=None) - - with pytest.raises(Exception) as exc_info: - await proxy_managed_files.afile_retrieve( - file_id="non-existent-file-id", - litellm_parent_otel_span=None, - ) - - assert "not found" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_list_batches_from_managed_objects_table(): - from openai.types.batch import BatchRequestCounts - - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - batch_record_1 = MagicMock() - batch_record_1.unified_object_id = "unified-batch-id-1" - batch_record_1.file_object = json.dumps( - { - "id": "batch_abc123", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-input-1", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - - batch_record_2 = MagicMock() - batch_record_2.unified_object_id = "unified-batch-id-2" - batch_record_2.file_object = json.dumps( - { - "id": "batch_xyz789", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "in_progress", - "created_at": 1234567891, - "input_file_id": "file-input-2", - "request_counts": {"total": 5, "completed": 2, "failed": 0}, - } - ) - - prisma_client.db.litellm_managedobjecttable.find_many.return_value = [ - batch_record_1, - batch_record_2, - ] - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - result = await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - limit=10, - ) - - assert result["object"] == "list" - assert len(result["data"]) == 2 - assert result["data"][0].id == "unified-batch-id-1" - assert result["data"][1].id == "unified-batch-id-2" - assert result["first_id"] == "unified-batch-id-1" - assert result["last_id"] == "unified-batch-id-2" - - # Should filter by user_id (created_by) - prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( - where={"file_purpose": "batch", "created_by": "test-user"}, - take=10, - order={"created_at": "desc"}, - ) - - -@pytest.mark.asyncio -async def test_list_batches_from_managed_objects_table_empty_list(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - prisma_client.db.litellm_managedobjecttable.find_many.return_value = [] - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - result = await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - ) - - assert result["object"] == "list" - assert len(result["data"]) == 0 - assert result["first_id"] is None - assert result["last_id"] is None - assert result["has_more"] is False - - # Verify where clause includes created_by filter - # Default take is 20 when no limit is provided - prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( - where={"file_purpose": "batch", "created_by": "test-user"}, - take=20, - order={"created_at": "desc"}, - ) - - -def _create_unified_batch_id(model_id: str, batch_id: str) -> str: - import base64 - - unified_str = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}" - return base64.urlsafe_b64encode(unified_str.encode()).decode().rstrip("=") - - -@pytest.mark.asyncio -async def test_list_batches_from_managed_objects_table_provider_filter_raises_exception(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # Filtering by provider should raise Exception - with pytest.raises(Exception) as exc_info: - await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - limit=10, - provider="openai", - ) - - assert str(exc_info.value) == ( - "Filtering by 'provider' is not supported when using managed batches." - ) - - # Verify find_many was NOT called since exception is raised before database query - prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called() - - -@pytest.mark.asyncio -async def test_list_batches_from_managed_objects_table_target_model_name_filter_raises_exception(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # Filtering by provider should raise Exception - with pytest.raises(Exception) as exc_info: - await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="test-user"), - limit=10, - target_model_names="gpt-5.5,gpt-3.5", - ) - - assert str(exc_info.value) == ( - "Filtering by 'target_model_names' is not supported when using managed batches." - ) - - # Verify find_many was NOT called since exception is raised before database query - prisma_client.db.litellm_managedobjecttable.find_many.assert_not_called() - - -@pytest.mark.asyncio -async def test_list_batches_from_managed_objects_table_filters_by_created_by(): - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Create batch for user1 - batch_user1 = MagicMock() - batch_user1.unified_object_id = "unified-batch-user1" - batch_user1.file_object = json.dumps( - { - "id": "batch_user1_abc", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-input-user1", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - - # Create batch for user2 - batch_user2 = MagicMock() - batch_user2.unified_object_id = "unified-batch-user2" - batch_user2.file_object = json.dumps( - { - "id": "batch_user2_xyz", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567891, - "input_file_id": "file-input-user2", - "request_counts": {"total": 2, "completed": 2, "failed": 0}, - } - ) - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # Query with user1's API key - should only return user1's batch - prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user1] - result_user1 = await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="user1"), - limit=10, - ) - - assert len(result_user1["data"]) == 1 - assert result_user1["data"][0].id == "unified-batch-user1" - prisma_client.db.litellm_managedobjecttable.find_many.assert_called_with( - where={"file_purpose": "batch", "created_by": "user1"}, - take=10, - order={"created_at": "desc"}, - ) - - # Query with user2's API key - should only return user2's batch - prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user2] - result_user2 = await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="user2"), - limit=10, - ) - - assert len(result_user2["data"]) == 1 - assert result_user2["data"][0].id == "unified-batch-user2" - prisma_client.db.litellm_managedobjecttable.find_many.assert_called_with( - where={"file_purpose": "batch", "created_by": "user2"}, - take=10, - order={"created_at": "desc"}, - ) - - -@pytest.mark.asyncio -async def test_return_unified_file_id_includes_expires_at(): - from litellm.types.llms.openai import OpenAIFileObject - - # Create a mock file object with expires_at set - file_object = OpenAIFileObject( - id="file-abc123", - object="file", - bytes=1234, - created_at=1234567890, - filename="test.jsonl", - purpose="batch", - status="uploaded", - expires_at=1234657890, - ) - file_object._hidden_params = {"model_id": "test-model-id"} - - create_file_request = { - "file": ("test.jsonl", b"test content", "application/jsonl"), - "purpose": "batch", - } - - internal_usage_cache = MagicMock() - - result = await _PROXY_LiteLLMManagedFiles.return_unified_file_id( - file_objects=[file_object], - create_file_request=create_file_request, - internal_usage_cache=internal_usage_cache, - litellm_parent_otel_span=None, - target_model_names_list=["gpt-5.5"], - ) - - # Verify expires_at is passed through - assert result.expires_at == 1234657890 - assert result.purpose == "batch" - assert result.filename == "test.jsonl" - assert result.bytes == 1234 - assert result.created_at == 1234567890 - assert _is_base64_encoded_unified_file_id(result.id) - - -# ============================================================================ -# Permission Tests - Cross-User Batch Access -# ============================================================================ -# These tests verify that batches and files created by one user -# cannot be accessed, modified, or cancelled by a different user. -# Reference: https://github.com/BerriAI/litellm/pull/17401/files - - -@pytest.mark.asyncio -async def test_user_b_cannot_retrieve_user_a_batch(): - """ - Test that User B cannot retrieve a batch created by User A. - - This verifies batch isolation between users at the database/hook level. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - batch_record = MagicMock() - batch_record.created_by = "user_a_id" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # User B tries to retrieve User A's batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_b_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"batch_id": unified_batch_id}, - call_type="aretrieve_batch", - ) - - # Should raise 403 Permission Denied - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_user_b_cannot_cancel_user_a_batch(): - """ - Test that User B cannot cancel a batch created by User A. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - batch_record = MagicMock() - batch_record.created_by = "user_a_id" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # User B tries to cancel User A's batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_b_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"batch_id": unified_batch_id}, - call_type="acancel_batch", - ) - - # Should raise 403 Permission Denied - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_user_a_can_retrieve_own_batch(): - """ - Test that User A can successfully retrieve their own batch. - - This is a positive test case to ensure permission checks don't block - legitimate access. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - batch_record = MagicMock() - batch_record.created_by = "user_a_id" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # User A retrieves their own batch - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - - # Should not raise an exception - result = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"batch_id": unified_batch_id}, - call_type="aretrieve_batch", - ) - - # Should successfully return the decoded batch_id - assert "batch_id" in result - assert result["model"] == "my-model" - - -@pytest.mark.asyncio -async def test_user_b_cannot_retrieve_user_a_file(): - """ - Test that User B cannot retrieve a file created by User A. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - file_record = MagicMock() - file_record.created_by = "user_a_id" - prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - MagicMock(), prisma_client=prisma_client - ) - - # User B tries to retrieve User A's file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_b_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"file_id": unified_file_id}, - call_type="afile_retrieve", - ) - - # Should raise 403 Permission Denied - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_user_b_cannot_download_user_a_file_content(): - """ - Test that User B cannot download file content for User A's file. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - file_record = MagicMock() - file_record.created_by = "user_a_id" - prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - MagicMock(), prisma_client=prisma_client - ) - - # User B tries to download User A's file content - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_b_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"file_id": unified_file_id}, - call_type="afile_content", - ) - - # Should raise 403 Permission Denied - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_user_b_cannot_delete_user_a_file(): - """ - Test that User B cannot delete a file created by User A. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - file_record = MagicMock() - file_record.created_by = "user_a_id" - prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - MagicMock(), prisma_client=prisma_client - ) - - # User B tries to delete User A's file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - - with pytest.raises(HTTPException) as exc_info: - await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_b_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"file_id": unified_file_id}, - call_type="afile_delete", - ) - - # Should raise 403 Permission Denied - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_user_a_can_retrieve_own_file(): - """ - Test that User A can successfully retrieve their own file. - - Positive test case to ensure permission checks work correctly for the owner. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return User A as the creator - file_record = MagicMock() - file_record.created_by = "user_a_id" - file_record.model_mappings = '{"model-123": "file-abc123"}' - file_record.file_object = json.dumps( - { - "id": "file-abc123", - "object": "file", - "bytes": 1234, - "created_at": 1234567890, - "filename": "test.jsonl", - "purpose": "batch", - } - ) - prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - MagicMock(), prisma_client=prisma_client - ) - - # User A retrieves their own file - unified_file_id = ( - "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" - ) - - # Should not raise an exception - result = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"file_id": unified_file_id}, - call_type="afile_retrieve", - ) - - # Should successfully return the decoded file_id - assert "file_id" in result - - -@pytest.mark.asyncio -async def test_list_batches_only_returns_user_own_batches(): - """ - Test that list_user_batches only returns batches created by the requesting user. - - This ensures users cannot see other users' batches in list operations. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Create batches for User A - batch_user_a = MagicMock() - batch_user_a.unified_object_id = "batch-user-a" - batch_user_a.file_object = json.dumps( - { - "id": "batch_a", - "object": "batch", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "status": "completed", - "created_at": 1234567890, - "input_file_id": "file-a", - "request_counts": {"total": 1, "completed": 1, "failed": 0}, - } - ) - - # Mock database to only return User A's batches - prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user_a] - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - # User A requests their batches - result = await proxy_managed_files.list_user_batches( - user_api_key_dict=UserAPIKeyAuth(user_id="user_a_id"), - limit=10, - ) - - # Should only return User A's batches - assert len(result["data"]) == 1 - assert result["data"][0].id == "batch-user-a" - - # Verify the database query filtered by user_id - prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( - where={"file_purpose": "batch", "created_by": "user_a_id"}, - take=10, - order={"created_at": "desc"}, - ) - - -@pytest.mark.asyncio -async def test_same_user_different_keys_can_access_batch(): - """ - Test that different API keys for the same user can access the same batch. - - This verifies that permission checks are based on user_id, not API key, - allowing users to have multiple keys that can all access their resources. - """ - from litellm.proxy._types import UserAPIKeyAuth - - prisma_client = AsyncMock() - - # Mock database to return the user_id as creator - batch_record = MagicMock() - batch_record.created_by = "user_a_id" - prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record - - proxy_managed_files = _PROXY_LiteLLMManagedFiles( - DualCache(), prisma_client=prisma_client - ) - - unified_batch_id = ( - "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" - ) - - # First API key for User A retrieves the batch - result1 = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", api_key="key-1", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"batch_id": unified_batch_id}, - call_type="aretrieve_batch", - ) - - assert "batch_id" in result1 - - # Second API key for the same User A retrieves the batch - result2 = await proxy_managed_files.async_pre_call_hook( - user_api_key_dict=UserAPIKeyAuth( - user_id="user_a_id", api_key="key-2", parent_otel_span=MagicMock() - ), - cache=MagicMock(), - data={"batch_id": unified_batch_id}, - call_type="aretrieve_batch", - ) - - assert "batch_id" in result2 - # Both keys should get the same result - assert result1["batch_id"] == result2["batch_id"] diff --git a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py deleted file mode 100644 index 32685c5cbd..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_internal_user_endpoints.py +++ /dev/null @@ -1,176 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi import HTTPException -from fastapi.testclient import TestClient -from litellm_enterprise.proxy.management_endpoints.internal_user_endpoints import router - - -@pytest.fixture -def client(): - from fastapi import FastAPI - - app = FastAPI() - app.include_router(router) - return TestClient(app) - - -@pytest.fixture -def mock_user_api_key_auth(): - """Mock the user_api_key_auth dependency""" - with patch( - "litellm_enterprise.proxy.management_endpoints.internal_user_endpoints.user_api_key_auth" - ) as mock_auth: - mock_auth.return_value = {"user_id": "test_user", "api_key": "test_key"} - yield mock_auth - - -class TestAvailableEnterpriseUsers: - @pytest.mark.asyncio - async def test_available_users_with_max_users_set( - self, client, mock_user_api_key_auth - ): - """Test when max_users is set and user count is within limit""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - {"max_users": 10}, - ), - ): - # Mock database count - mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=5) - mock_prisma.db.litellm_teamtable.count = AsyncMock(return_value=2) - - # Override the dependency - client.app.dependency_overrides[mock_user_api_key_auth] = lambda: { - "user_id": "test_user" - } - - response = client.get("/user/available_users") - - assert response.status_code == 200 - data = response.json() - - assert data["total_users"] == 10 - assert data["total_users_used"] == 5 - assert data["total_users_remaining"] == 5 - assert data["total_teams"] is None - assert data["total_teams_used"] == 2 - assert data["total_teams_remaining"] is None - # Ensure no negative values - assert data["total_users_remaining"] >= 0 - - @pytest.mark.asyncio - async def test_available_users_without_max_users_set( - self, client, mock_user_api_key_auth - ): - """Test when max_users is not set (premium_user_data is None or doesn't contain max_users)""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - None, - ), - ): - # Mock database count - mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=3) - mock_prisma.db.litellm_teamtable.count = AsyncMock(return_value=1) - - # Override the dependency - client.app.dependency_overrides[mock_user_api_key_auth] = lambda: { - "user_id": "test_user" - } - - response = client.get("/user/available_users") - - assert response.status_code == 200 - data = response.json() - - assert data["total_users"] is None - assert data["total_users_used"] == 3 - assert data["total_users_remaining"] is None - assert data["total_teams"] is None - assert data["total_teams_used"] == 1 - assert data["total_teams_remaining"] is None - - @pytest.mark.asyncio - async def test_available_users_negative_remaining_bug( - self, client, mock_user_api_key_auth - ): - """Test the current bug where total_users_remaining can be negative""" - with ( - patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - patch( - "litellm.proxy.proxy_server.premium_user_data", - {"key": "value"}, - ), - ): - # Mock database count higher than max_users to trigger the bug - mock_prisma.db.litellm_usertable.count = AsyncMock(return_value=8) - mock_prisma.db.litellm_teamtable.count = AsyncMock(return_value=3) - - # Override the dependency - client.app.dependency_overrides[mock_user_api_key_auth] = lambda: { - "user_id": "test_user" - } - - response = client.get("/user/available_users") - - assert response.status_code == 200 - data = response.json() - - print(f"data: {data}") - - assert data["total_users"] == None - assert data["total_users_used"] == 8 - assert data["total_teams"] == None - assert data["total_teams_used"] == 3 - # This assertion will fail due to the current bug - remaining is -3 - # TODO: Fix the bug to ensure remaining is never negative - assert data["total_users_remaining"] == None # Current buggy behavior - assert data["total_teams_remaining"] == None # Current buggy behavior - # The following assertion would be the correct behavior: - # assert data["total_users_remaining"] >= 0 - - @pytest.mark.asyncio - async def test_available_users_no_database_connection( - self, client, mock_user_api_key_auth - ): - """Test when prisma_client is None (no database connection)""" - from litellm.proxy._types import CommonProxyErrors - - with ( - patch( - "litellm.proxy.proxy_server.prisma_client", - None, - ), - patch( - "litellm.proxy.proxy_server.premium_user", - True, - ), - ): - # Override the dependency - client.app.dependency_overrides[mock_user_api_key_auth] = lambda: { - "user_id": "test_user" - } - - response = client.get("/user/available_users") - - assert response.status_code == 500 - assert ( - CommonProxyErrors.db_not_connected_error.value - in response.json()["detail"]["error"] - ) diff --git a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py b/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py deleted file mode 100644 index c55b66b402..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/management_endpoints/test_project_endpoints_prisma.py +++ /dev/null @@ -1,866 +0,0 @@ -import os -import sys -import traceback -from litellm._uuid import uuid -from unittest import mock - -from dotenv import load_dotenv -from fastapi import Request - -load_dotenv() -import time - -sys.path.insert(0, os.path.abspath("../..")) -import logging - -import pytest - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.proxy.management_endpoints.team_endpoints import ( - new_team, -) -from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - new_project, - update_project, - delete_project, - project_info, -) -from litellm.proxy.proxy_server import ( - LitellmUserRoles, -) -from litellm.proxy.utils import PrismaClient, ProxyLogging - -verbose_proxy_logger.setLevel(level=logging.DEBUG) - - -from litellm.caching.caching import DualCache -from litellm.proxy._types import ( - NewProjectRequest, - UpdateProjectRequest, - DeleteProjectRequest, - NewTeamRequest, - UserAPIKeyAuth, -) - -proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) - - -@pytest.fixture -def prisma_client(): - from litellm.proxy.proxy_cli import append_query_params - - ### add connection pool + pool timeout args - params = {"connection_limit": 100, "pool_timeout": 60} - database_url = os.getenv("DATABASE_URL") - modified_url = append_query_params(database_url, params) - os.environ["DATABASE_URL"] = modified_url - - # Assuming PrismaClient is a class that needs to be instantiated - prisma_client = PrismaClient( - database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj - ) - - # Reset litellm.proxy.proxy_server.prisma_client to None - litellm.proxy.proxy_server.litellm_proxy_budget_name = ( - f"litellm-proxy-budget-{time.time()}" - ) - litellm.proxy.proxy_server.user_custom_key_generate = None - - # Enable premium_user for project management tests - setattr(litellm.proxy.proxy_server, "premium_user", True) - - return prisma_client - - -@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).") -@pytest.mark.asyncio -async def test_new_project(prisma_client): - """ - Test creating a new project with budget, models, and metadata. - """ - try: - print("prisma client=", prisma_client) - - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - - await litellm.proxy.proxy_server.prisma_client.connect() - - # Create a team first - _team_id = f"project-test-team_{uuid.uuid4()}" - await new_team( - NewTeamRequest( - team_id=_team_id, - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - # Create a project - project_data = NewProjectRequest( - project_alias="test-project", - description="Test project for unit testing", - team_id=_team_id, - metadata={"use_case_id": "TEST-001", "responsible_ai_id": "RAI-001"}, - models=["gpt-5.5", "gpt-5-mini"], - max_budget=100.0, - model_rpm_limit={"gpt-5.5": 100}, - model_tpm_limit={"gpt-5.5": 1000}, - ) - - response = await new_project( - data=project_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("New project response:", response) - - # Assertions - assert response.project_id is not None - assert response.project_alias == "test-project" - assert response.description == "Test project for unit testing" - assert response.team_id == _team_id - assert response.models == ["gpt-5.5", "gpt-5-mini"] - # model_rpm_limit and model_tpm_limit are stored in metadata - assert response.metadata["use_case_id"] == "TEST-001" - assert response.metadata["responsible_ai_id"] == "RAI-001" - assert response.metadata["model_rpm_limit"] == {"gpt-5.5": 100} - assert response.metadata["model_tpm_limit"] == {"gpt-5.5": 1000} - assert response.litellm_budget_table is not None - assert response.litellm_budget_table.max_budget == 100.0 - - except Exception as e: - print("Got Exception", e) - traceback.print_exc() - pytest.fail(f"Got exception {e}") - - -@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).") -@pytest.mark.asyncio -async def test_update_project(prisma_client): - """ - Test updating an existing project's budget, models, and metadata. - """ - try: - print("prisma client=", prisma_client) - - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - - await litellm.proxy.proxy_server.prisma_client.connect() - - # Create a team first - _team_id = f"project-test-team_{uuid.uuid4()}" - await new_team( - NewTeamRequest( - team_id=_team_id, - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - # Create a project - project_data = NewProjectRequest( - project_alias="test-project-update", - description="Original description", - team_id=_team_id, - metadata={ - "use_case_id": "TEST-002", - }, - models=["gpt-5.5"], - max_budget=50.0, - ) - - create_response = await new_project( - data=project_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Created project:", create_response) - project_id = create_response.project_id - - # Update the project - update_data = UpdateProjectRequest( - project_id=project_id, - project_alias="test-project-updated", - description="Updated description", - metadata={ - "use_case_id": "TEST-002-UPDATED", - "additional_field": "new_value", - }, - models=["gpt-5.5", "gpt-5-mini", "claude-3"], - max_budget=200.0, - model_rpm_limit={"gpt-5.5": 200, "claude-3": 50}, - model_tpm_limit={"gpt-5.5": 2000, "claude-3": 500}, - ) - - update_response = await update_project( - data=update_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Updated project response:", update_response) - - # Assertions - assert update_response.project_id == project_id - assert update_response.project_alias == "test-project-updated" - assert update_response.description == "Updated description" - assert update_response.models == ["gpt-5.5", "gpt-5-mini", "claude-3"] - # model_rpm_limit and model_tpm_limit are stored in metadata - assert update_response.metadata["use_case_id"] == "TEST-002-UPDATED" - assert update_response.metadata["additional_field"] == "new_value" - assert update_response.metadata["model_rpm_limit"] == { - "gpt-5.5": 200, - "claude-3": 50, - } - assert update_response.metadata["model_tpm_limit"] == { - "gpt-5.5": 2000, - "claude-3": 500, - } - assert update_response.litellm_budget_table is not None - assert update_response.litellm_budget_table.max_budget == 200.0 - - except Exception as e: - print("Got Exception", e) - traceback.print_exc() - pytest.fail(f"Got exception {e}") - - -@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).") -@pytest.mark.asyncio -async def test_delete_project(prisma_client): - """ - Test deleting a project. - """ - try: - print("prisma client=", prisma_client) - - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - - await litellm.proxy.proxy_server.prisma_client.connect() - - # Create a team first - _team_id = f"project-test-team_{uuid.uuid4()}" - await new_team( - NewTeamRequest( - team_id=_team_id, - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - # Create a project - project_data = NewProjectRequest( - project_alias="test-project-delete", - team_id=_team_id, - models=["gpt-5.5"], - max_budget=50.0, - ) - - create_response = await new_project( - data=project_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Created project:", create_response) - project_id = create_response.project_id - - # Delete the project - delete_data = DeleteProjectRequest(project_ids=[project_id]) - - delete_response = await delete_project( - data=delete_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Delete project response:", delete_response) - - # Assertions - delete_project returns a list of deleted project objects - assert isinstance(delete_response, list) - assert len(delete_response) == 1 - assert delete_response[0].project_id == project_id - - # Try to get info on the deleted project - should fail or return None - try: - await project_info( - project_id=project_id, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - pytest.fail("Expected to fail when fetching deleted project") - except Exception as e: - print("Expected error when fetching deleted project:", e) - # This is expected behavior - - except Exception as e: - print("Got Exception", e) - traceback.print_exc() - pytest.fail(f"Got exception {e}") - - -@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).") -@pytest.mark.asyncio -async def test_project_info(prisma_client): - """ - Test getting project info. - """ - try: - print("prisma client=", prisma_client) - - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - - await litellm.proxy.proxy_server.prisma_client.connect() - - # Create a team first - _team_id = f"project-test-team_{uuid.uuid4()}" - await new_team( - NewTeamRequest( - team_id=_team_id, - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - # Create a project - project_data = NewProjectRequest( - project_alias="test-project-info", - description="Test project info endpoint", - team_id=_team_id, - metadata={"use_case_id": "TEST-003", "cost_center": "engineering"}, - models=["gpt-5.5", "claude-3"], - max_budget=150.0, - model_rpm_limit={"gpt-5.5": 150}, - model_tpm_limit={"gpt-5.5": 1500}, - ) - - create_response = await new_project( - data=project_data, - http_request=Request(scope={"type": "http"}), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Created project:", create_response) - project_id = create_response.project_id - - # Get project info - info_response = await project_info( - project_id=project_id, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - print("Project info response:", info_response) - - # Assertions - project_info returns the project object directly - assert info_response.project_id == project_id - assert info_response.project_alias == "test-project-info" - assert info_response.description == "Test project info endpoint" - assert info_response.team_id == _team_id - assert info_response.models == ["gpt-5.5", "claude-3"] - # model_rpm_limit and model_tpm_limit are stored in metadata - assert info_response.metadata["use_case_id"] == "TEST-003" - assert info_response.metadata["cost_center"] == "engineering" - assert info_response.metadata["model_rpm_limit"] == {"gpt-5.5": 150} - assert info_response.metadata["model_tpm_limit"] == {"gpt-5.5": 1500} - assert info_response.litellm_budget_table is not None - assert info_response.litellm_budget_table.max_budget == 150.0 - - except Exception as e: - print("Got Exception", e) - traceback.print_exc() - pytest.fail(f"Got exception {e}") - - -### VALIDATION TESTS ### - - -def test_check_team_project_limits_models_not_in_team(): - """ - Test that creating a project with models not in the team raises an error. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5", "gpt-5-mini"], - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5", "claude-3"], # claude-3 not in team - ) - - with pytest.raises(Exception) as exc_info: - _check_team_project_limits(team_object=team, data=data) - - assert "claude-3" in str(exc_info.value.detail) - assert "not in team's allowed models" in str(exc_info.value.detail) - - -def test_check_team_project_limits_budget_exceeds_team(): - """ - Test that creating a project with budget > team budget raises an error. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5"], - max_budget=100.0, - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5"], - max_budget=150.0, # exceeds team's 100.0 - ) - - with pytest.raises(Exception) as exc_info: - _check_team_project_limits(team_object=team, data=data) - - assert "exceeds team's max_budget" in str(exc_info.value.detail) - - -def test_check_team_project_limits_valid_subset(): - """ - Test that a valid project (models subset, budget within limit) passes. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5", "gpt-5-mini", "claude-3"], - max_budget=1000.0, - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5", "gpt-5-mini"], - max_budget=500.0, - ) - - # Should not raise - _check_team_project_limits(team_object=team, data=data) - - -def test_check_team_project_limits_all_proxy_models(): - """ - Test that team with 'all-proxy-models' allows any project models. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["all-proxy-models"], - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5", "claude-3", "anything-goes"], - ) - - # Should not raise - team allows all models - _check_team_project_limits(team_object=team, data=data) - - -def test_check_team_project_limits_tpm_exceeds_team(): - """ - Test that project tpm_limit exceeding team tpm_limit raises an error. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5"], - tpm_limit=10000, - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5"], - tpm_limit=20000, # exceeds team's 10000 - ) - - with pytest.raises(Exception) as exc_info: - _check_team_project_limits(team_object=team, data=data) - - assert "exceeds team's tpm_limit" in str(exc_info.value.detail) - - -def test_check_team_project_limits_negative_budget(): - """ - Test that negative budget values raise an error. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5"], - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5"], - max_budget=-10.0, - ) - - with pytest.raises(Exception) as exc_info: - _check_team_project_limits(team_object=team, data=data) - - assert "cannot be negative" in str(exc_info.value.detail) - - -def test_check_team_project_limits_soft_budget_gte_max(): - """ - Test that soft_budget >= max_budget raises an error. - """ - from litellm_enterprise.proxy.management_endpoints.project_endpoints import ( - _check_team_project_limits, - ) - from litellm.proxy._types import LiteLLM_TeamTable - - team = LiteLLM_TeamTable( - team_id="test-team", - models=["gpt-5.5"], - ) - - data = NewProjectRequest( - team_id="test-team", - models=["gpt-5.5"], - max_budget=100.0, - soft_budget=100.0, # equal to max, should fail - ) - - with pytest.raises(Exception) as exc_info: - _check_team_project_limits(team_object=team, data=data) - - assert "must be strictly lower" in str(exc_info.value.detail) - - -def test_premium_user_gate(): - """ - Test that project endpoints require premium_user=True. - """ - - # This test just validates the premium_user check exists - # The actual endpoint test would need prisma, but we can verify - # the import path works - setattr(litellm.proxy.proxy_server, "premium_user", False) - - # Verify that CommonProxyErrors.not_premium_user exists - from litellm.proxy._types import CommonProxyErrors - - assert hasattr(CommonProxyErrors, "not_premium_user") - - # Reset - setattr(litellm.proxy.proxy_server, "premium_user", True) - - -def test_project_model_access_denied_error_type(): - """ - Test that ProxyErrorTypes.project_model_access_denied exists. - """ - from litellm.proxy._types import ProxyErrorTypes - - assert hasattr(ProxyErrorTypes, "project_model_access_denied") - assert ( - ProxyErrorTypes.project_model_access_denied.value - == "project_model_access_denied" - ) - - # Test the classmethod resolves correctly - result = ProxyErrorTypes.get_model_access_error_type_for_object("project") - assert result == ProxyErrorTypes.project_model_access_denied - - -def test_project_cached_obj_has_last_refreshed_at(): - """ - Test that LiteLLM_ProjectTableCachedObj has last_refreshed_at field - matching LiteLLM_TeamTableCachedObj pattern. - """ - from litellm.proxy._types import ( - LiteLLM_ProjectTableCachedObj, - LiteLLM_ProjectTable, - ) - - # Verify inheritance - assert issubclass(LiteLLM_ProjectTableCachedObj, LiteLLM_ProjectTable) - - # Verify last_refreshed_at field exists and defaults to None - obj = LiteLLM_ProjectTableCachedObj( - project_id="test", - created_by="admin", - updated_by="admin", - ) - assert obj.last_refreshed_at is None - - # Verify it can be set - obj.last_refreshed_at = 1234567890.0 - assert obj.last_refreshed_at == 1234567890.0 - - -@pytest.mark.asyncio -async def test_project_max_budget_check_fires_alert(): - """ - Test that _project_max_budget_check fires a budget alert - when project exceeds its max budget (matches _team_max_budget_check pattern). - """ - from litellm.proxy.auth.auth_checks import _project_max_budget_check - from litellm.proxy._types import ( - LiteLLM_BudgetTable, - LiteLLM_ProjectTableCachedObj, - ) - - project = LiteLLM_ProjectTableCachedObj( - project_id="test-project", - spend=150.0, - created_by="admin", - updated_by="admin", - litellm_budget_table=LiteLLM_BudgetTable(max_budget=100.0), - ) - - valid_token = UserAPIKeyAuth( - token="test-token", - user_id="user-1", - team_id="team-1", - ) - - mock_proxy_logging = mock.AsyncMock(spec=ProxyLogging) - mock_proxy_logging.budget_alerts = mock.AsyncMock() - - with pytest.raises(litellm.BudgetExceededError) as exc_info: - await _project_max_budget_check( - project_object=project, - valid_token=valid_token, - proxy_logging_obj=mock_proxy_logging, - ) - - assert "Project=test-project" in str(exc_info.value) - assert "150.0" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_project_soft_budget_check(): - """ - Test that _project_soft_budget_check triggers alert when soft budget is exceeded. - """ - from litellm.proxy.auth.auth_checks import _project_soft_budget_check - from litellm.proxy._types import ( - LiteLLM_BudgetTable, - LiteLLM_ProjectTableCachedObj, - ) - - project = LiteLLM_ProjectTableCachedObj( - project_id="test-project", - spend=80.0, - created_by="admin", - updated_by="admin", - litellm_budget_table=LiteLLM_BudgetTable(soft_budget=75.0), - ) - - valid_token = UserAPIKeyAuth( - token="test-token", - user_id="user-1", - team_id="team-1", - ) - - mock_proxy_logging = mock.AsyncMock(spec=ProxyLogging) - mock_proxy_logging.budget_alerts = mock.AsyncMock() - - # Should not raise (soft budget only alerts, doesn't block) - await _project_soft_budget_check( - project_object=project, - valid_token=valid_token, - proxy_logging_obj=mock_proxy_logging, - ) - - -@pytest.mark.asyncio -async def test_project_soft_budget_check_no_alert_under_budget(): - """ - Test that _project_soft_budget_check does NOT trigger alert when under soft budget. - """ - from litellm.proxy.auth.auth_checks import _project_soft_budget_check - from litellm.proxy._types import ( - LiteLLM_BudgetTable, - LiteLLM_ProjectTableCachedObj, - ) - - project = LiteLLM_ProjectTableCachedObj( - project_id="test-project", - spend=50.0, - created_by="admin", - updated_by="admin", - litellm_budget_table=LiteLLM_BudgetTable(soft_budget=75.0), - ) - - valid_token = UserAPIKeyAuth( - token="test-token", - user_id="user-1", - team_id="team-1", - ) - - mock_proxy_logging = mock.AsyncMock(spec=ProxyLogging) - mock_proxy_logging.budget_alerts = mock.AsyncMock() - - # Should not raise and should not alert - await _project_soft_budget_check( - project_object=project, - valid_token=valid_token, - proxy_logging_obj=mock_proxy_logging, - ) - - -def test_litellm_entity_type_has_project(): - """ - Test that Litellm_EntityType has PROJECT member for budget alerts. - """ - from litellm.proxy._types import Litellm_EntityType - - assert hasattr(Litellm_EntityType, "PROJECT") - assert Litellm_EntityType.PROJECT.value == "project" - - -@pytest.mark.asyncio -async def test_list_projects_returns_timestamps(): - """ - Test that /project/list returns created_at and updated_at for each project. - """ - from datetime import datetime, timezone - from unittest.mock import AsyncMock, MagicMock, patch - - from litellm_enterprise.proxy.management_endpoints.project_endpoints import list_projects - from litellm.proxy._types import LiteLLM_ProjectTable - - now = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) - - # Build a fake DB row that includes created_at and updated_at - fake_project = MagicMock() - fake_project.model_dump.return_value = { - "project_id": "proj-1", - "project_alias": "test-project", - "team_id": "team-1", - "created_by": "admin", - "updated_by": "admin", - "created_at": now, - "updated_at": now, - "models": [], - "spend": 0.0, - "blocked": False, - "budget_id": None, - "description": None, - "metadata": None, - "model_spend": None, - "model_rpm_limit": None, - "model_tpm_limit": None, - "object_permission_id": None, - "litellm_budget_table": None, - "object_permission": None, - } - # Make the fake row behave like a Pydantic model for FastAPI serialization - fake_project.project_id = "proj-1" - fake_project.created_at = now - fake_project.updated_at = now - - mock_prisma = MagicMock() - mock_prisma.db.litellm_projecttable.find_many = AsyncMock( - return_value=[fake_project] - ) - - with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): - response = await list_projects( - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - assert len(response) == 1 - project = response[0] - assert project.created_at == now - assert project.updated_at == now - - -def test_litellm_project_table_has_timestamp_fields(): - """ - Test that LiteLLM_ProjectTable model includes created_at and updated_at fields, - so the /project/list response_model exposes them. - """ - from litellm.proxy._types import LiteLLM_ProjectTable - - fields = LiteLLM_ProjectTable.model_fields - assert "created_at" in fields, "LiteLLM_ProjectTable must have created_at field" - assert "updated_at" in fields, "LiteLLM_ProjectTable must have updated_at field" diff --git a/tests/enterprise/litellm_enterprise/proxy/test_audit_logging_endpoints.py b/tests/enterprise/litellm_enterprise/proxy/test_audit_logging_endpoints.py deleted file mode 100644 index a0a26c089e..0000000000 --- a/tests/enterprise/litellm_enterprise/proxy/test_audit_logging_endpoints.py +++ /dev/null @@ -1,132 +0,0 @@ -from datetime import datetime, timedelta -from unittest.mock import AsyncMock, patch - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient -from litellm_enterprise.proxy.audit_logging_endpoints import router as audit_router -from litellm_enterprise.types.proxy.audit_logging_endpoints import AuditLogResponse - -from litellm.proxy._types import UserAPIKeyAuth - -# Create an app with just the audit router for testing -app = FastAPI() -app.include_router(audit_router) -client = TestClient(app) - -# Mock data for testing -MOCK_AUDIT_LOG = { - "id": "test-audit-log-1", - "updated_at": datetime.utcnow(), - "changed_by": "test-user", - "changed_by_api_key": "test-api-key-hash", - "action": "create", - "table_name": "test_table", - "object_id": "test-object-1", - "before_value": None, - "updated_values": {"name": "test", "value": 123}, -} - - -@pytest.fixture -def mock_prisma_client(): - with patch("litellm.proxy.proxy_server.prisma_client") as mock: - mock.db.litellm_auditlog.find_many = AsyncMock() - mock.db.litellm_auditlog.find_unique = AsyncMock() - mock.db.litellm_auditlog.count = AsyncMock() - yield mock - - -@pytest.mark.asyncio -async def test_get_audit_logs(mock_prisma_client): - """Test successful retrieval of audit logs with pagination""" - # Mock the database responses - mock_prisma_client.db.litellm_auditlog.find_many.return_value = [ - AuditLogResponse(**MOCK_AUDIT_LOG) - ] - mock_prisma_client.db.litellm_auditlog.count.return_value = 1 - - # Mock the auth dependency - with patch("litellm.proxy.auth.user_api_key_auth.user_api_key_auth") as mock_auth: - mock_auth.return_value = UserAPIKeyAuth( - api_key="test-key", - user_id="test-user", - team_id=None, - organization_id=None, - user_role="proxy_admin", - ) - - # Make the request - response = client.get("/audit?page=1&page_size=10") - - # Assert response - assert response.status_code == 200 - data = response.json() - assert "audit_logs" in data - assert len(data["audit_logs"]) == 1 - assert data["total"] == 1 - assert data["page"] == 1 - assert data["page_size"] == 10 - assert data["total_pages"] == 1 - - # Verify the audit log data - audit_log = data["audit_logs"][0] - assert audit_log["id"] == MOCK_AUDIT_LOG["id"] - assert audit_log["action"] == MOCK_AUDIT_LOG["action"] - assert audit_log["table_name"] == MOCK_AUDIT_LOG["table_name"] - - -@pytest.mark.asyncio -async def test_get_audit_log_by_id(mock_prisma_client): - """Test successful retrieval of a specific audit log by ID""" - # Mock the database response - mock_prisma_client.db.litellm_auditlog.find_unique.return_value = AuditLogResponse( - **MOCK_AUDIT_LOG - ) - - # Mock the auth dependency - with patch("litellm.proxy.auth.user_api_key_auth.user_api_key_auth") as mock_auth: - mock_auth.return_value = UserAPIKeyAuth( - api_key="test-key", - user_id="test-user", - team_id=None, - organization_id=None, - user_role="proxy_admin", - ) - - # Make the request - response = client.get(f"/audit/{MOCK_AUDIT_LOG['id']}") - - # Assert response - assert response.status_code == 200 - data = response.json() - assert data["id"] == MOCK_AUDIT_LOG["id"] - assert data["action"] == MOCK_AUDIT_LOG["action"] - assert data["table_name"] == MOCK_AUDIT_LOG["table_name"] - assert data["object_id"] == MOCK_AUDIT_LOG["object_id"] - - -@pytest.mark.asyncio -async def test_get_audit_log_by_id_not_found(mock_prisma_client): - """Test error handling when audit log is not found""" - # Mock the database response to return None - mock_prisma_client.db.litellm_auditlog.find_unique.return_value = None - - # Mock the auth dependency - with patch("litellm.proxy.auth.user_api_key_auth.user_api_key_auth") as mock_auth: - mock_auth.return_value = UserAPIKeyAuth( - api_key="test-key", - user_id="test-user", - team_id=None, - organization_id=None, - user_role="proxy_admin", - ) - - # Make the request - response = client.get("/audit/non-existent-id") - - # Assert response - assert response.status_code == 404 - data = response.json() - assert "message" in data["detail"] - assert "not found" in data["detail"]["message"].lower()