diff --git a/.circleci/config.yml b/.circleci/config.yml index 9fb1aeaf4f..26660e109e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,6 +16,14 @@ commands: echo "nameserver 127.0.0.11" | sudo tee /etc/resolv.conf echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf echo "nameserver 8.8.4.4" | sudo tee -a /etc/resolv.conf + setup_litellm_enterprise_pip: + steps: + - run: + name: "Install local version of litellm-enterprise" + command: | + cd enterprise + python -m pip install -e . + cd .. jobs: # Add Windows testing job @@ -111,6 +119,7 @@ jobs: pip install "pytest-xdist==3.6.1" pip install "websockets==13.1.0" pip uninstall posthog -y + - setup_litellm_enterprise_pip - save_cache: paths: - ./venv @@ -228,6 +237,7 @@ jobs: pip install "Pillow==10.3.0" pip install "jsonschema==4.22.0" pip install "websockets==13.1.0" + - setup_litellm_enterprise_pip - save_cache: paths: - ./venv @@ -334,6 +344,7 @@ jobs: pip install "Pillow==10.3.0" pip install "jsonschema==4.22.0" pip install "websockets==13.1.0" + - setup_litellm_enterprise_pip - save_cache: paths: - ./venv @@ -445,6 +456,7 @@ jobs: pip install "pytest-retry==1.6.3" pip install "pytest-asyncio==0.21.1" # Run pytest and generate JUnit XML report + - setup_litellm_enterprise_pip - run: name: Run tests command: | @@ -589,6 +601,7 @@ jobs: pip install "jsonschema==4.22.0" pip install "pytest-postgresql==7.0.1" pip install "fakeredis==2.28.1" + - setup_litellm_enterprise_pip - save_cache: paths: - ./venv @@ -861,7 +874,7 @@ jobs: pip install "mcp==1.5.0" pip install "requests-mock>=1.12.1" pip install "responses==0.25.7" - + - setup_litellm_enterprise_pip # Run pytest and generate JUnit XML report - run: name: Run tests @@ -1088,6 +1101,7 @@ jobs: pip install "google-cloud-aiplatform==1.43.0" pip install "mlflow==2.17.2" # Run pytest and generate JUnit XML report + - setup_litellm_enterprise_pip - run: name: Run tests command: | @@ -1135,6 +1149,7 @@ jobs: pip install "tokenizers==0.20.0" pip install "uvloop==0.21.0" pip install jsonschema + - setup_litellm_enterprise_pip - run: name: Run tests command: | diff --git a/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl b/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl new file mode 100644 index 0000000000..d9a8ef41e6 Binary files /dev/null and b/enterprise/dist/litellm_enterprise-0.1.1-py3-none-any.whl differ diff --git a/enterprise/dist/litellm_enterprise-0.1.1.tar.gz b/enterprise/dist/litellm_enterprise-0.1.1.tar.gz new file mode 100644 index 0000000000..98cf132b21 Binary files /dev/null and b/enterprise/dist/litellm_enterprise-0.1.1.tar.gz differ diff --git a/enterprise/litellm_enterprise/__init__.py b/enterprise/litellm_enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/enterprise/enterprise_callbacks/example_logging_api.py b/enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py similarity index 83% rename from enterprise/enterprise_callbacks/example_logging_api.py rename to enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py index 2084ffb548..14d34f5d1e 100644 --- a/enterprise/enterprise_callbacks/example_logging_api.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/example_logging_api.py @@ -7,11 +7,11 @@ app = FastAPI() @app.post("/log-event") async def log_event(request: Request): try: - print("Received /log-event request") # noqa + print("Received /log-event request") # noqa # Assuming the incoming request has JSON data data = await request.json() - print("Received request data:") # noqa - print(data) # noqa + print("Received request data:") # noqa + print(data) # noqa # Your additional logic can go here # For now, just printing the received data diff --git a/enterprise/enterprise_callbacks/generic_api_callback.py b/enterprise/litellm_enterprise/enterprise_callbacks/generic_api_callback.py similarity index 100% rename from enterprise/enterprise_callbacks/generic_api_callback.py rename to enterprise/litellm_enterprise/enterprise_callbacks/generic_api_callback.py diff --git a/enterprise/enterprise_callbacks/send_emails/base_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py similarity index 97% rename from enterprise/enterprise_callbacks/send_emails/base_email.py rename to enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py index e476852b2e..d9f3ce46ba 100644 --- a/enterprise/enterprise_callbacks/send_emails/base_email.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/base_email.py @@ -105,9 +105,9 @@ class BaseEmailLogger(CustomLogger): support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", self.DEFAULT_SUPPORT_EMAIL) base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000") - recipient_email: Optional[str] = ( - user_email or await self._lookup_user_email_from_db(user_id=user_id) - ) + recipient_email: Optional[ + str + ] = user_email or await self._lookup_user_email_from_db(user_id=user_id) if recipient_email is None: raise ValueError( f"User email not found for user_id: {user_id}. User email is required to send email." diff --git a/enterprise/enterprise_callbacks/send_emails/endpoints.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/endpoints.py similarity index 100% rename from enterprise/enterprise_callbacks/send_emails/endpoints.py rename to enterprise/litellm_enterprise/enterprise_callbacks/send_emails/endpoints.py diff --git a/enterprise/enterprise_callbacks/send_emails/resend_email.py b/enterprise/litellm_enterprise/enterprise_callbacks/send_emails/resend_email.py similarity index 100% rename from enterprise/enterprise_callbacks/send_emails/resend_email.py rename to enterprise/litellm_enterprise/enterprise_callbacks/send_emails/resend_email.py diff --git a/enterprise/proxy/enterprise_routes.py b/enterprise/proxy/enterprise_routes.py index d7aff1952b..8dcc1c77d2 100644 --- a/enterprise/proxy/enterprise_routes.py +++ b/enterprise/proxy/enterprise_routes.py @@ -1,7 +1,9 @@ from fastapi import APIRouter from fastapi.responses import Response +from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import ( + router as email_events_router, +) -from ..enterprise_callbacks.send_emails.endpoints import router as email_events_router from .utils import _should_block_robots from .vector_stores.endpoints import router as vector_stores_router diff --git a/enterprise/pyproject.toml b/enterprise/pyproject.toml new file mode 100644 index 0000000000..a04accb96c --- /dev/null +++ b/enterprise/pyproject.toml @@ -0,0 +1,30 @@ +[tool.poetry] +name = "litellm-enterprise" +version = "0.1.1" +description = "Package for LiteLLM Enterprise features" +authors = ["BerriAI"] +readme = "README.md" + + +[tool.poetry.urls] +homepage = "https://litellm.ai" +Homepage = "https://litellm.ai" +repository = "https://github.com/BerriAI/litellm" +Repository = "https://github.com/BerriAI/litellm" +documentation = "https://docs.litellm.ai" +Documentation = "https://docs.litellm.ai" + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0, !=3.9.7" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.commitizen] +version = "0.1.18" +version_files = [ + "pyproject.toml:version", + "../requirements.txt:litellm-enterprise==", + "../pyproject.toml:litellm-enterprise = {version = \"" +] \ No newline at end of file diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 89702e0345..354cde5739 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -134,8 +134,10 @@ from .initialize_dynamic_callback_params import ( from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache try: - from enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger - from enterprise.enterprise_callbacks.send_emails.resend_email import ( + from litellm_enterprise.enterprise_callbacks.generic_api_callback import ( + GenericAPILogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( ResendEmailLogger, ) except Exception as e: @@ -257,9 +259,9 @@ class Logging(LiteLLMLoggingBaseClass): self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4()) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[Any] = ( - [] - ) # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -610,9 +612,9 @@ class Logging(LiteLLMLoggingBaseClass): if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( non_default_params ): - self.model_call_details["prompt_integration"] = ( - anthropic_cache_control_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = anthropic_cache_control_logger.__class__.__name__ return anthropic_cache_control_logger ######################################################### @@ -630,9 +632,9 @@ class Logging(LiteLLMLoggingBaseClass): ), ) ) - self.model_call_details["prompt_integration"] = ( - vector_store_custom_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = vector_store_custom_logger.__class__.__name__ return vector_store_custom_logger return None @@ -684,9 +686,9 @@ class Logging(LiteLLMLoggingBaseClass): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"]["api_base"] = ( - self._get_masked_api_base(additional_args.get("api_base", "")) - ) + self.model_call_details["litellm_params"][ + "api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -715,10 +717,10 @@ class Logging(LiteLLMLoggingBaseClass): try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata["raw_request"] = ( - "redacted by litellm. \ + _metadata[ + "raw_request" + ] = "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" - ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -729,32 +731,32 @@ class Logging(LiteLLMLoggingBaseClass): _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ignore_sensitive_headers=True, - ), - error=None, - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ignore_sensitive_headers=True, + ), + error=None, ) except Exception as e: - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - error=str(e), - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), ) - _metadata["raw_request"] = ( - "Unable to Log \ + _metadata[ + "raw_request" + ] = "Unable to Log \ raw request: {}".format( - str(e) - ) + str(e) ) if self.logger_fn and callable(self.logger_fn): try: @@ -1085,9 +1087,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None try: @@ -1112,9 +1114,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None @@ -1174,9 +1176,9 @@ class Logging(LiteLLMLoggingBaseClass): end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details["completion_start_time"] = ( - self.completion_start_time - ) + self.model_call_details[ + "completion_start_time" + ] = self.completion_start_time self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit @@ -1256,39 +1258,39 @@ class Logging(LiteLLMLoggingBaseClass): "response_cost" ] else: - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=logging_result) - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=logging_result) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=logging_result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=logging_result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif isinstance(result, dict) or isinstance(result, list): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object - ) + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object else: # streaming chunks + image gen. self.model_call_details["response_cost"] = None @@ -1344,23 +1346,23 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details["complete_streaming_response"] = ( - complete_streaming_response - ) - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=complete_streaming_response) - ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, @@ -1680,10 +1682,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -1723,10 +1725,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] @@ -1833,9 +1835,9 @@ class Logging(LiteLLMLoggingBaseClass): if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["async_complete_streaming_response"] = ( - complete_streaming_response - ) + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response try: if self.model_call_details.get("cache_hit", False) is True: self.model_call_details["response_cost"] = 0.0 @@ -1845,10 +1847,10 @@ class Logging(LiteLLMLoggingBaseClass): model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details["response_cost"] = ( - self._response_cost_calculator( - result=complete_streaming_response - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response ) verbose_logger.debug( @@ -1861,16 +1863,16 @@ class Logging(LiteLLMLoggingBaseClass): self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, @@ -2076,18 +2078,18 @@ class Logging(LiteLLMLoggingBaseClass): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, ) return start_time, end_time @@ -2861,9 +2863,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_config.endpoint, ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"space_key={arize_config.space_key},api_key={arize_config.api_key}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}" for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -2887,9 +2889,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - arize_phoenix_config.otlp_auth_headers - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers for callback in _in_memory_loggers: if ( @@ -2980,9 +2982,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"api_key={os.getenv('LANGTRACE_API_KEY')}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -3526,10 +3528,10 @@ class StandardLoggingPayloadSetup: for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params["additional_headers"] = ( - StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -3901,9 +3903,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[k] = ( - "scrubbed_by_litellm_for_sensitive_keys" - ) + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index 047dd97692..082b506ddd 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -297,9 +297,10 @@ class KeyManagementEventHooks: @staticmethod async def _send_key_created_email(response: dict): - from enterprise.enterprise_callbacks.send_emails.base_email import ( + from litellm_enterprise.enterprise_callbacks.send_emails.base_email import ( BaseEmailLogger, ) + from litellm.proxy.proxy_server import general_settings, proxy_logging_obj from litellm.types.enterprise.enterprise_callbacks.send_emails import ( SendKeyCreatedEmailEvent, diff --git a/litellm/proxy/hooks/user_management_event_hooks.py b/litellm/proxy/hooks/user_management_event_hooks.py index b56d3143da..eb0b003d23 100644 --- a/litellm/proxy/hooks/user_management_event_hooks.py +++ b/litellm/proxy/hooks/user_management_event_hooks.py @@ -7,10 +7,12 @@ import uuid from datetime import datetime, timezone from typing import Optional +from litellm_enterprise.enterprise_callbacks.send_emails.base_email import ( + BaseEmailLogger, +) from pydantic import BaseModel import litellm -from enterprise.enterprise_callbacks.send_emails.base_email import BaseEmailLogger from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ( AUDIT_ACTIONS, diff --git a/poetry.lock b/poetry.lock index afd8290a83..f6b6d5cd38 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1709,6 +1709,18 @@ files = [ importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} referencing = ">=0.31.0" +[[package]] +name = "litellm-enterprise" +version = "0.1.1" +description = "Package for LiteLLM Enterprise features" +optional = true +python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" +groups = ["main"] +markers = "extra == \"proxy\"" +files = [ + {file = "litellm_enterprise-0.1.1.tar.gz", hash = "sha256:58465200b1ab8e8c3b5e8a4ba08267502ac35dc42bc05e3a388575d02a5219b6"}, +] + [[package]] name = "litellm-proxy-extras" version = "0.1.18" @@ -4669,9 +4681,9 @@ type = ["pytest-mypy"] [extras] extra-proxy = ["azure-identity", "azure-keyvault-secrets", "google-cloud-kms", "prisma", "redisvl", "resend"] -proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", "fastapi-sso", "gunicorn", "litellm-proxy-extras", "mcp", "orjson", "pynacl", "python-multipart", "pyyaml", "rich", "rq", "uvicorn", "uvloop", "websockets"] +proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", "fastapi-sso", "gunicorn", "litellm-enterprise", "litellm-proxy-extras", "mcp", "orjson", "pynacl", "python-multipart", "pyyaml", "rich", "rq", "uvicorn", "uvloop", "websockets"] [metadata] lock-version = "2.1" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "b0e92d112a8265cbd9a75ad30b2a0bb2ad429ea86ebc34c79f2b4eba03c1d364" +content-hash = "2eb67698c4810b12b9c31ef478721b952236438764b5b52bcb67b97884b2320c" diff --git a/pyproject.toml b/pyproject.toml index 9af558ad30..771bcc8fff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3. mcp = {version = "1.5.0", optional = true, python = ">=3.10"} litellm-proxy-extras = {version = "0.1.18", optional = true} rich = {version = "13.7.1", optional = true} +litellm-enterprise = {version = "0.1.1", optional = true} [tool.poetry.extras] proxy = [ @@ -78,6 +79,7 @@ proxy = [ "boto3", "mcp", "litellm-proxy-extras", + "litellm-enterprise", "rich", ] diff --git a/requirements.txt b/requirements.txt index b803516d3e..08ed4c7534 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,4 +52,8 @@ tenacity==8.2.3 # for retrying requests, when litellm.num_retries set pydantic==2.10.2 # proxy + openai req. jsonschema==4.22.0 # validating json schema websockets==13.1.0 # for realtime API -#### + +######################## +# LITELLM ENTERPRISE DEPENDENCIES +######################## +litellm-enterprise==0.1.1 diff --git a/tests/code_coverage_tests/liccheck.ini b/tests/code_coverage_tests/liccheck.ini index c5ecc898f2..e3fd3922ca 100644 --- a/tests/code_coverage_tests/liccheck.ini +++ b/tests/code_coverage_tests/liccheck.ini @@ -86,4 +86,5 @@ detect-secrets: >=1.5.0 # MIT License importlib-metadata: >=6.8.0 # Apache 2.0 License tokenizers: >=0.20.2 # Apache 2.0 License jinja2: >=3.1.4 # BSD 3-Clause License -litellm-proxy-extras: >=0.1.1 # MIT License \ No newline at end of file +litellm-proxy-extras: >=0.1.1 # MIT License +litellm-enterprise: >=0.1.1 # LiteLLM Enterprise License \ No newline at end of file diff --git a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py index c46912e3a8..22d553e87c 100644 --- a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py +++ b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py @@ -8,7 +8,10 @@ from fastapi.testclient import TestClient sys.path.insert(0, os.path.abspath("../../..")) -from enterprise.enterprise_callbacks.send_emails.base_email import BaseEmailLogger +from litellm_enterprise.enterprise_callbacks.send_emails.base_email import ( + BaseEmailLogger, +) + from litellm.proxy._types import Litellm_EntityType, WebhookEvent from litellm.types.enterprise.enterprise_callbacks.send_emails import ( SendKeyCreatedEmailEvent, diff --git a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_endpoints.py b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_endpoints.py index 2012b728b7..8169c457c2 100644 --- a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_endpoints.py +++ b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_endpoints.py @@ -9,7 +9,7 @@ from fastapi.testclient import TestClient sys.path.insert(0, os.path.abspath("../../..")) -from enterprise.enterprise_callbacks.send_emails.endpoints import ( +from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import ( _get_email_settings, _save_email_settings, get_email_event_settings, @@ -17,6 +17,7 @@ from enterprise.enterprise_callbacks.send_emails.endpoints import ( router, update_event_settings, ) + from litellm.types.enterprise.enterprise_callbacks.send_emails import ( DefaultEmailSettings, EmailEvent, @@ -146,7 +147,7 @@ async def test_get_email_event_settings(mock_prisma_client, mock_user_api_key_au # Setup mocks with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client): with mock.patch( - "enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings", + "litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings", side_effect=mock_get_settings, ): # Call the endpoint function directly @@ -185,11 +186,11 @@ async def test_update_event_settings(mock_prisma_client, mock_user_api_key_auth) # Setup mocks with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client): with mock.patch( - "enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings", + "litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings", side_effect=mock_get_settings, ): with mock.patch( - "enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings", + "litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings", side_effect=mock_save_settings, ): # Create request with updated settings @@ -227,7 +228,7 @@ async def test_reset_event_settings(mock_prisma_client, mock_user_api_key_auth): # Setup mocks with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client): with mock.patch( - "enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings", + "litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings", side_effect=mock_save_settings, ): # Call the endpoint function directly diff --git a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py index 980fd49b94..dffa343d99 100644 --- a/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py +++ b/tests/litellm/enterprise/enterprise_callbacks/send_emails/test_resend_email.py @@ -7,7 +7,9 @@ from httpx import Response sys.path.insert(0, os.path.abspath("../../..")) -from enterprise.enterprise_callbacks.send_emails.resend_email import ResendEmailLogger +from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( + ResendEmailLogger, +) @pytest.fixture @@ -19,7 +21,7 @@ def mock_env_vars(): @pytest.fixture def mock_httpx_client(): with mock.patch( - "enterprise.enterprise_callbacks.send_emails.resend_email.get_async_httpx_client" + "litellm_enterprise.enterprise_callbacks.send_emails.resend_email.get_async_httpx_client" ) as mock_client: # Create a mock response mock_response = mock.AsyncMock(spec=Response) diff --git a/tests/logging_callback_tests/test_generic_api_callback.py b/tests/logging_callback_tests/test_generic_api_callback.py index 0bdec437fe..c033f323ec 100644 --- a/tests/logging_callback_tests/test_generic_api_callback.py +++ b/tests/logging_callback_tests/test_generic_api_callback.py @@ -28,7 +28,7 @@ from litellm.types.utils import ( ) verbose_logger.setLevel(logging.DEBUG) -from enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger +from litellm_enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index f6edf71498..71aece0407 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -42,8 +42,8 @@ from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLog from litellm.integrations.agentops import AgentOps from litellm.integrations.humanloop import HumanloopLogger from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler -from enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger -from enterprise.enterprise_callbacks.send_emails.resend_email import ResendEmailLogger +from litellm_enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger +from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ResendEmailLogger from unittest.mock import patch # clear prometheus collectors / registry