[Feat] Backend - Add support for disabling callbacks in request body (#12762)
* allow using standard_callback_dynamic_params to disable callbacks * fix is_callback_disabled_dynamically * test_callback_disabled_via_request_body_multiple
This commit is contained in:
parent
657ca3b81a
commit
06574a72b5
@ -1,3 +1,5 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import X_LITELLM_DISABLE_CALLBACKS
|
||||
@ -6,15 +8,18 @@ from litellm.litellm_core_utils.llm_request_utils import (
|
||||
get_proxy_server_request_headers,
|
||||
)
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class EnterpriseCallbackControls:
|
||||
@staticmethod
|
||||
def is_callback_disabled_via_headers(
|
||||
callback: litellm.CALLBACK_TYPES, litellm_params: dict
|
||||
def is_callback_disabled_dynamically(
|
||||
callback: litellm.CALLBACK_TYPES,
|
||||
litellm_params: dict,
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header.
|
||||
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.
|
||||
|
||||
Args:
|
||||
callback: The callback to check (can be string, CustomLogger instance, or callable)
|
||||
@ -28,8 +33,7 @@ class EnterpriseCallbackControls:
|
||||
)
|
||||
|
||||
try:
|
||||
request_headers = get_proxy_server_request_headers(litellm_params)
|
||||
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
||||
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
|
||||
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
|
||||
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
|
||||
if disabled_callbacks is not None:
|
||||
@ -39,7 +43,6 @@ class EnterpriseCallbackControls:
|
||||
if not EnterpriseCallbackControls._premium_user_check():
|
||||
return False
|
||||
#########################################################
|
||||
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
|
||||
if isinstance(callback, str):
|
||||
if callback.lower() in disabled_callbacks:
|
||||
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
|
||||
@ -56,6 +59,29 @@ class EnterpriseCallbackControls:
|
||||
f"Error checking disabled callbacks header: {str(e)}"
|
||||
)
|
||||
return False
|
||||
@staticmethod
|
||||
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
|
||||
"""
|
||||
Get the disabled callbacks from the standard callback dynamic params.
|
||||
"""
|
||||
|
||||
#########################################################
|
||||
# check if disabled via headers
|
||||
#########################################################
|
||||
request_headers = get_proxy_server_request_headers(litellm_params)
|
||||
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
|
||||
if disabled_callbacks is not None:
|
||||
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
|
||||
return list(disabled_callbacks)
|
||||
|
||||
|
||||
#########################################################
|
||||
# check if disabled via request body
|
||||
#########################################################
|
||||
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
|
||||
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _premium_user_check():
|
||||
|
||||
@ -1313,8 +1313,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
# Check for dynamically disabled callbacks via headers
|
||||
if (
|
||||
EnterpriseCallbackControls is not None
|
||||
and EnterpriseCallbackControls.is_callback_disabled_via_headers(
|
||||
callback, litellm_params
|
||||
and EnterpriseCallbackControls.is_callback_disabled_dynamically(
|
||||
callback=callback,
|
||||
litellm_params=litellm_params,
|
||||
standard_callback_dynamic_params = self.standard_callback_dynamic_params
|
||||
)
|
||||
):
|
||||
verbose_logger.debug(
|
||||
|
||||
@ -2086,6 +2086,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
|
||||
|
||||
# Logging settings
|
||||
turn_off_message_logging: Optional[bool] # when true will not log messages
|
||||
litellm_disabled_callbacks: Optional[List[str]]
|
||||
|
||||
|
||||
all_litellm_params = [
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import unittest.mock as mock
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -13,6 +14,7 @@ from litellm.integrations.langfuse.langfuse_prompt_management import (
|
||||
LangfusePromptManagement,
|
||||
)
|
||||
from litellm.integrations.s3_v2 import S3Logger
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class TestEnterpriseCallbackControls:
|
||||
@ -39,131 +41,145 @@ class TestEnterpriseCallbackControls:
|
||||
"""Test that 'langfuse' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_langfuse_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that LangfusePromptManagement CustomLogger instance is disabled when 'langfuse' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
langfuse_logger = LangfusePromptManagement()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers(langfuse_logger, litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(langfuse_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_s3_v2_string(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that 's3_v2' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("s3_v2", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_s3_v2_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that S3Logger CustomLogger instance is disabled when 's3_v2' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Mock S3Logger to avoid async initialization issues
|
||||
with patch('litellm.integrations.s3_v2.S3Logger.__init__', return_value=None):
|
||||
s3_logger = S3Logger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers(s3_logger, litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(s3_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_datadog_string(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that 'datadog' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("datadog", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_datadog_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that DataDogLogger CustomLogger instance is disabled when 'datadog' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Mock DataDogLogger to avoid async initialization issues
|
||||
with patch('litellm.integrations.datadog.datadog.DataDogLogger.__init__', return_value=None):
|
||||
datadog_logger = DataDogLogger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers(datadog_logger, litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(datadog_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_multiple_callbacks_disabled(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that multiple callbacks can be disabled with comma-separated list"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse,datadog,s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Test each callback is disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("datadog", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("s3_v2", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
# Test non-disabled callback is not disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("prometheus", litellm_params) is False
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False
|
||||
|
||||
def test_callback_not_disabled_when_not_in_list(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks not in the disabled list are not disabled"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("datadog", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_not_disabled_when_no_header(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks are not disabled when the header is not present"""
|
||||
mock_request_headers.return_value = {}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_not_disabled_when_header_none(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks are not disabled when the header value is None"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: None}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_non_premium_user_cannot_disable_callbacks(self, mock_non_premium_user, mock_request_headers):
|
||||
"""Test that non-premium users cannot disable callbacks even with the header"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_case_insensitive_callback_matching(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callback matching is case insensitive"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "LANGFUSE,DataDog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Test lowercase callbacks are disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("datadog", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
def test_whitespace_handling_in_disabled_callbacks(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that whitespace around callback names is handled correctly"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: " langfuse , datadog , s3_v2 "}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("datadog", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_via_headers("s3_v2", litellm_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
def test_custom_logger_not_in_registry(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that CustomLogger not in registry is not disabled"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "unknown_logger"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Create a mock CustomLogger that's not in the registry
|
||||
class UnknownLogger(CustomLogger):
|
||||
pass
|
||||
|
||||
unknown_logger = UnknownLogger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers(unknown_logger, litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(unknown_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_exception_handling(self, mock_premium_user, mock_request_headers):
|
||||
@ -171,6 +187,30 @@ class TestEnterpriseCallbackControls:
|
||||
# Make get_proxy_server_request_headers raise an exception
|
||||
mock_request_headers.side_effect = Exception("Test exception")
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_via_headers("langfuse", litellm_params)
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_disabled_via_request_body_langfuse(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks can be disabled via request body litellm_disabled_callbacks"""
|
||||
mock_request_headers.return_value = {} # No headers
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse"])
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_via_request_body_multiple(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that multiple callbacks can be disabled via request body"""
|
||||
mock_request_headers.return_value = {} # No headers
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse", "datadog", "s3_v2"])
|
||||
|
||||
# Test each callback is disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
# Test non-disabled callback is not disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user