fix(proxy): run model-level post_call guardrails on streaming requests (#26922)

This commit is contained in:
michelligabriele 2026-05-07 20:53:03 +02:00 committed by GitHub
parent fee5900acc
commit 9f1b41d206
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 0 deletions

View File

@ -2274,6 +2274,13 @@ class ProxyLogging:
Covers:
1. /chat/completions
"""
from litellm.proxy.proxy_server import llm_router
# Merge model-level guardrails before checking which guardrails to run
request_data = _check_and_merge_model_level_guardrails(
data=request_data, llm_router=llm_router
)
current_response = response
for callback in litellm.callbacks:

View File

@ -294,3 +294,134 @@ async def test_post_call_success_hook_skips_guardrail_not_on_model():
)
assert guardrail.was_called is False
# ---------------------------------------------------------------------------
# Integration test: async_post_call_streaming_iterator_hook with model-level guardrails
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_streaming_iterator_hook_runs_model_level_guardrail():
    """
    A post_call guardrail listed on the deployment must be invoked by
    ProxyLogging.async_post_call_streaming_iterator_hook (the streaming
    path) even though the request body never names that guardrail, and
    the streamed chunks must pass through unchanged.
    """
    from litellm.caching.caching import DualCache
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.guardrails import GuardrailEventHooks

    # Pass-through guardrail that only records whether its streaming hook ran.
    class TestStreamingGuardrail(CustomGuardrail):
        def __init__(self):
            super().__init__(
                guardrail_name="test-model-guardrail",
                event_hook=GuardrailEventHooks.post_call,
            )
            self.was_called = False

        async def async_post_call_streaming_iterator_hook(
            self, user_api_key_dict, response, request_data
        ):
            # Being an async generator, this body (and the flag flip) only
            # executes once the dispatcher starts iterating the result.
            self.was_called = True
            async for item in response:
                yield item

    async def fake_response():
        for text in ("chunk-1", "chunk-2"):
            yield text

    guardrail = TestStreamingGuardrail()

    # Router stub: the looked-up deployment reports the guardrail under test
    # for any litellm_params.get(...) call.
    router_stub = MagicMock()
    router_stub.get_deployment.return_value.litellm_params.get.return_value = [
        "test-model-guardrail"
    ]

    # Note: the payload carries no "guardrails" entry of its own.
    payload = {
        "model": "gpt-4",
        "metadata": {"model_info": {"id": "model-uuid-123"}},
    }

    with patch("litellm.callbacks", [guardrail]):
        with patch("litellm.proxy.proxy_server.llm_router", router_stub):
            dispatcher = ProxyLogging(user_api_key_cache=DualCache())
            auth = UserAPIKeyAuth(api_key="test-key")
            received = [
                piece
                async for piece in dispatcher.async_post_call_streaming_iterator_hook(
                    response=fake_response(),
                    user_api_key_dict=auth,
                    request_data=payload,
                )
            ]

    assert guardrail.was_called is True
    assert received == ["chunk-1", "chunk-2"]
@pytest.mark.asyncio
async def test_streaming_iterator_hook_skips_guardrail_not_on_model():
    """
    A registered streaming guardrail that is neither listed on the
    deployment nor named in the request must stay dormant when
    ProxyLogging.async_post_call_streaming_iterator_hook runs; the stream
    itself must still be delivered intact.
    """
    from litellm.caching.caching import DualCache
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.guardrails import GuardrailEventHooks

    # Pass-through guardrail that only records whether its streaming hook ran.
    class TestStreamingGuardrail(CustomGuardrail):
        def __init__(self):
            super().__init__(
                guardrail_name="unrelated-guardrail",
                event_hook=GuardrailEventHooks.post_call,
            )
            self.was_called = False

        async def async_post_call_streaming_iterator_hook(
            self, user_api_key_dict, response, request_data
        ):
            # Being an async generator, this body (and the flag flip) only
            # executes once the dispatcher starts iterating the result.
            self.was_called = True
            async for item in response:
                yield item

    async def fake_response():
        yield "chunk-1"

    guardrail = TestStreamingGuardrail()

    # Router stub: the deployment advertises a DIFFERENT guardrail name than
    # the one registered in litellm.callbacks.
    router_stub = MagicMock()
    router_stub.get_deployment.return_value.litellm_params.get.return_value = [
        "some-other-guardrail"
    ]

    payload = {
        "model": "gpt-4",
        "metadata": {"model_info": {"id": "model-uuid-123"}},
    }

    with patch("litellm.callbacks", [guardrail]):
        with patch("litellm.proxy.proxy_server.llm_router", router_stub):
            dispatcher = ProxyLogging(user_api_key_cache=DualCache())
            auth = UserAPIKeyAuth(api_key="test-key")
            received = [
                piece
                async for piece in dispatcher.async_post_call_streaming_iterator_hook(
                    response=fake_response(),
                    user_api_key_dict=auth,
                    request_data=payload,
                )
            ]

    assert guardrail.was_called is False
    assert received == ["chunk-1"]