fix(proxy): run model-level post_call guardrails on streaming requests (#26922)
This commit is contained in:
parent
fee5900acc
commit
9f1b41d206
@ -2274,6 +2274,13 @@ class ProxyLogging:
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
# Merge model-level guardrails before checking which guardrails to run
|
||||
request_data = _check_and_merge_model_level_guardrails(
|
||||
data=request_data, llm_router=llm_router
|
||||
)
|
||||
|
||||
current_response = response
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
|
||||
@ -294,3 +294,134 @@ async def test_post_call_success_hook_skips_guardrail_not_on_model():
|
||||
)
|
||||
|
||||
assert guardrail.was_called is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration test: async_post_call_streaming_iterator_hook with model-level guardrails
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_iterator_hook_runs_model_level_guardrail():
|
||||
"""
|
||||
Model-level guardrails configured on a deployment should execute in
|
||||
async_post_call_streaming_iterator_hook (streaming path) — even when
|
||||
`default_on: false` and the guardrail is not in the request body.
|
||||
"""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
class TestStreamingGuardrail(CustomGuardrail):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
guardrail_name="test-model-guardrail",
|
||||
event_hook=GuardrailEventHooks.post_call,
|
||||
)
|
||||
self.was_called = False
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self, user_api_key_dict, response, request_data
|
||||
):
|
||||
self.was_called = True
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
guardrail = TestStreamingGuardrail()
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_deployment = MagicMock()
|
||||
mock_deployment.litellm_params.get.return_value = ["test-model-guardrail"]
|
||||
mock_router.get_deployment.return_value = mock_deployment
|
||||
|
||||
async def fake_response():
|
||||
yield "chunk-1"
|
||||
yield "chunk-2"
|
||||
|
||||
with (
|
||||
patch("litellm.callbacks", [guardrail]),
|
||||
patch("litellm.proxy.proxy_server.llm_router", mock_router),
|
||||
):
|
||||
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
|
||||
|
||||
request_data = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {"model_info": {"id": "model-uuid-123"}},
|
||||
}
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="test-key")
|
||||
|
||||
chunks = []
|
||||
async for chunk in proxy_logging.async_post_call_streaming_iterator_hook(
|
||||
response=fake_response(),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=request_data,
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert guardrail.was_called is True
|
||||
assert chunks == ["chunk-1", "chunk-2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_iterator_hook_skips_guardrail_not_on_model():
|
||||
"""
|
||||
Streaming guardrails NOT configured on the model (and not in the request
|
||||
body / key / team) should not execute, even after the dispatcher merge
|
||||
runs. Confirms the gate stays closed for unrelated guardrails.
|
||||
"""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
class TestStreamingGuardrail(CustomGuardrail):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
guardrail_name="unrelated-guardrail",
|
||||
event_hook=GuardrailEventHooks.post_call,
|
||||
)
|
||||
self.was_called = False
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self, user_api_key_dict, response, request_data
|
||||
):
|
||||
self.was_called = True
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
guardrail = TestStreamingGuardrail()
|
||||
|
||||
# Deployment has a DIFFERENT guardrail configured
|
||||
mock_router = MagicMock()
|
||||
mock_deployment = MagicMock()
|
||||
mock_deployment.litellm_params.get.return_value = ["some-other-guardrail"]
|
||||
mock_router.get_deployment.return_value = mock_deployment
|
||||
|
||||
async def fake_response():
|
||||
yield "chunk-1"
|
||||
|
||||
with (
|
||||
patch("litellm.callbacks", [guardrail]),
|
||||
patch("litellm.proxy.proxy_server.llm_router", mock_router),
|
||||
):
|
||||
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
|
||||
|
||||
request_data = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {"model_info": {"id": "model-uuid-123"}},
|
||||
}
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="test-key")
|
||||
|
||||
chunks = []
|
||||
async for chunk in proxy_logging.async_post_call_streaming_iterator_hook(
|
||||
response=fake_response(),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=request_data,
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert guardrail.was_called is False
|
||||
assert chunks == ["chunk-1"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user