test(proxy/utils): pin ProxyLogging behavior (#29485)

* test(proxy/utils): pin ProxyLogging behavior

Add behavior-pinning tests for the ProxyLogging cluster in
litellm/proxy/utils.py under tests/test_litellm/proxy/utils/proxy_logging/.
Covers InternalUsageCache, _CallbackCapabilities, top-of-file helpers
(print_verbose, _get_email_logger_class, _accepts_litellm_call_info,
_enrich_http_exception_with_guardrail_context), the full ProxyLogging
class (lifecycle, MCP-LLM bridging, capability probes, guardrail
pipeline, pre/during/post/streaming hooks, alerting), plus the
bottom-of-region helpers (on_backoff, jsonify_object, _lookup_deprecated_key).

Each pinned symbol has happy-path and error-path coverage; happy paths
use direct dict-equality with three or more keys (or HiddenParams /
Pydantic model_validate where the surface is a Pydantic shape). The
subdirectory carries a local _pin_check.py and _coverage_check.py that
enforce the gate without surfacing numeric thresholds in CI logs.

Wires tests/test_litellm/proxy/utils into the existing test-path block
in .github/workflows/test-unit-proxy-endpoints.yml.

* test(proxy/utils): drop unused mock_httpx_client fixture

Declared in conftest.py but never referenced by any test. Removing
the dead fixture per Greptile P2 feedback.

* test(proxy/utils): drop local-only gate scripts from PR

_pin_check.py and _coverage_check.py are local stopping signals (not
wired into CI, consume a gitignored .pin_list.txt). They served their
purpose telling the engineer when to stop writing tests; the pytest
suite is the artifact that belongs in the repo.

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
yuneng-jiang 2026-06-02 17:45:39 -07:00 committed by GitHub
parent 457f65eff9
commit b175990b4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 3830 additions and 0 deletions

View File

@ -0,0 +1,56 @@
"""Sanity tests for the proxy_logging conftest fixtures.
Excluded from the pin-check by name.
"""
from __future__ import annotations
import pytest
def test_normalize_replaces_volatile_keys(normalize_fn):
    """Volatile keys are masked at the top level and inside nested dicts."""
    payload = {"id": 7, "name": "x", "nested": {"created_at": 1, "value": 2}}
    masked = normalize_fn(payload)
    assert masked == {
        "id": "<VOLATILE>",
        "name": "x",
        "nested": {"created_at": "<VOLATILE>", "value": 2},
    }
def test_normalize_handles_lists(normalize_fn):
    """Dicts held inside a list are normalized element by element."""
    placeholder = {"id": "<VOLATILE>"}
    result = normalize_fn([{"id": 1}, {"id": 2}])
    assert result == [placeholder, placeholder]
def test_mock_dual_cache_is_dual_cache(mock_dual_cache):
    """The ``mock_dual_cache`` fixture hands back a ``DualCache`` instance."""
    from litellm.caching.caching import DualCache

    expected_type = DualCache
    assert isinstance(mock_dual_cache, expected_type)
def test_make_user_api_key_auth_returns_correct_type(make_user_api_key_auth):
    """Called with no overrides, the factory builds a ``UserAPIKeyAuth`` for "test-user"."""
    from litellm.proxy._types import UserAPIKeyAuth

    built = make_user_api_key_auth()

    assert isinstance(built, UserAPIKeyAuth)
    assert built.user_id == "test-user"
def test_make_user_api_key_auth_overrides_apply(make_user_api_key_auth):
    """A keyword override replaces the factory's default for that field."""
    overridden = make_user_api_key_auth(user_id="custom-id")
    assert overridden.user_id == "custom-id"
def test_proxy_logging_fixture_is_initialized(proxy_logging):
    """The fixture is a ``ProxyLogging`` with an internal usage cache and no hooks mapped."""
    from litellm.proxy.utils import InternalUsageCache, ProxyLogging

    usage_cache = proxy_logging.internal_usage_cache

    assert isinstance(proxy_logging, ProxyLogging)
    assert isinstance(usage_cache, InternalUsageCache)
    assert proxy_logging.proxy_hook_mapping == {}
def test_make_mcp_request_obj_default(make_mcp_request_obj):
    """Without arguments, the factory uses the calculator tool and ``x``/``y`` arguments."""
    request = make_mcp_request_obj()

    assert request.tool_name == "calculator"
    assert request.arguments == {"x": 1, "y": 2}
def test_mock_router_has_guardrail_list(mock_router):
    """The ``mock_router`` fixture starts with an empty ``guardrail_list``."""
    guardrails = mock_router.guardrail_list
    assert guardrails == []

View File

@ -0,0 +1,136 @@
"""Shared fixtures for tests/test_litellm/proxy/utils/proxy_logging/.
All fixtures used by PR1 of the proxy/utils.py behavior-pinning project
live here. Tests should not declare fixtures inline.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[5]))
# Dict keys whose values ``normalize`` swaps for the "<VOLATILE>" placeholder.
# Listed alphabetically.
VOLATILE_KEYS = frozenset(
    (
        "created",
        "created_at",
        "duration",
        "end_time",
        "expires",
        "expires_at",
        "guardrail_duration",
        "guardrail_end_time",
        "guardrail_start_time",
        "id",
        "key_alias",
        "litellm_call_id",
        "request_id",
        "start_time",
        "token",
        "updated_at",
    )
)
def normalize(data: Any, volatile: frozenset = VOLATILE_KEYS) -> Any:
    """Return a copy of ``data`` with volatile dict values masked.

    Dicts and lists are walked recursively. For every dict key found in
    ``volatile`` the value is replaced by the string "<VOLATILE>" (and is
    not descended into). Anything that is neither a dict nor a list is
    returned unchanged.
    """
    if isinstance(data, list):
        return [normalize(item, volatile) for item in data]
    if not isinstance(data, dict):
        return data

    masked = {}
    for key, value in data.items():
        masked[key] = "<VOLATILE>" if key in volatile else normalize(value, volatile)
    return masked
@pytest.fixture
def mock_dual_cache():
    """Provide a ``DualCache`` built with ``default_in_memory_ttl=1``."""
    from litellm.caching.caching import DualCache

    return DualCache(default_in_memory_ttl=1)
@pytest.fixture
def mock_router():
    """Provide a ``MagicMock`` router stub.

    The stub has an empty ``guardrail_list`` and a ``get_available_guardrail``
    mock that returns ``{"callback": None}``.
    """
    lookup = MagicMock(return_value={"callback": None})
    stub = MagicMock()
    stub.guardrail_list = []
    stub.get_available_guardrail = lookup
    return stub
@pytest.fixture
def mock_callbacks_disabled(monkeypatch):
    """Replace litellm's module-level callback lists with empty lists.

    ``monkeypatch`` restores the original values when the test finishes.
    """
    import litellm

    callback_attrs = (
        "callbacks",
        "success_callback",
        "failure_callback",
        "_async_success_callback",
        "_async_failure_callback",
    )
    for attr_name in callback_attrs:
        # A fresh list per attribute so the patched lists never alias each other.
        monkeypatch.setattr(litellm, attr_name, [])
    yield
@pytest.fixture
def make_user_api_key_auth():
    """Provide a factory that builds ``UserAPIKeyAuth`` objects.

    The factory starts from a fixed set of default field values; keyword
    arguments passed to it replace or extend those defaults.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    def _make(**overrides) -> UserAPIKeyAuth:
        fields: Dict[str, Any] = {
            "api_key": "sk-test-1234",
            "user_id": "test-user",
            "team_id": "test-team",
            "user_role": None,
            "max_budget": None,
            "spend": 0.0,
            **overrides,
        }
        return UserAPIKeyAuth(**fields)

    return _make
@pytest.fixture
def proxy_logging(mock_callbacks_disabled):
    """Provide a ``ProxyLogging`` constructed around a fresh ``UserApiKeyCache``.

    litellm's callback lists are emptied first via ``mock_callbacks_disabled``.
    Nothing beyond the constructor is invoked here; tests that need
    ``startup_event`` should call it themselves with the dependencies they
    want to control.
    """
    from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
    from litellm.proxy.utils import ProxyLogging

    key_cache = UserApiKeyCache()
    return ProxyLogging(user_api_key_cache=key_cache)
@pytest.fixture
def normalize_fn():
    """Expose the module-level ``normalize`` helper to tests as a fixture."""
    helper = normalize
    return helper
@pytest.fixture
def make_mcp_request_obj():
    """Provide a factory that builds ``MCPPreCallRequestObject`` instances.

    Defaults: tool "calculator", arguments ``{"x": 1, "y": 2}`` (used only
    when ``arguments`` is ``None``), server "math-server", an empty
    ``user_api_key_auth`` dict and a default ``HiddenParams``.
    """
    from litellm.types.llms.base import HiddenParams
    from litellm.types.mcp import MCPPreCallRequestObject

    def _make(
        tool_name: str = "calculator",
        arguments: Optional[dict] = None,
        server_name: Optional[str] = "math-server",
    ) -> MCPPreCallRequestObject:
        if arguments is None:
            # Built per call so callers never share one default dict.
            arguments = {"x": 1, "y": 2}
        return MCPPreCallRequestObject(
            tool_name=tool_name,
            arguments=arguments,
            server_name=server_name,
            user_api_key_auth={},
            hidden_params=HiddenParams(),
        )

    return _make

View File

@ -0,0 +1,262 @@
"""Pin alerting helpers on ``ProxyLogging``.
Covers ``failed_tracking_alert``, ``budget_alerts``, ``alerting_handler``,
``failure_handler``.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
import litellm
from litellm.proxy._types import AlertType, CallInfo
# ---------------------------------------------------------------------------
# failed_tracking_alert
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_failed_tracking_alert_no_op_when_alerting_none(proxy_logging):
    """With ``alerting`` set to ``None`` the Slack hook is never called."""
    slack_stub = MagicMock(failed_tracking_alert=AsyncMock())
    proxy_logging.alerting = None
    proxy_logging.slack_alerting_instance = slack_stub

    await proxy_logging.failed_tracking_alert(error_message="x", failing_model="m")

    proxy_logging.slack_alerting_instance.failed_tracking_alert.assert_not_called()
@pytest.mark.asyncio
async def test_failed_tracking_alert_forwards_to_slack(proxy_logging):
    """With Slack alerting on, exactly ``error_message`` and ``failing_model`` are forwarded as keywords."""
    received: Dict[str, Any] = {}

    async def _record(**kwargs):
        received.update(kwargs)

    proxy_logging.alerting = ["slack"]
    proxy_logging.slack_alerting_instance = MagicMock(failed_tracking_alert=_record)

    await proxy_logging.failed_tracking_alert(error_message="db down", failing_model="gpt-4")

    expected = {
        "error_message": "db down",
        "failing_model": "gpt-4",
        "captured_keys": ["error_message", "failing_model"],
    }
    observed = {
        "error_message": received["error_message"],
        "failing_model": received["failing_model"],
        "captured_keys": sorted(received),
    }
    assert observed == expected
@pytest.mark.asyncio
async def test_failed_tracking_alert_slack_error_raises(proxy_logging):
    """A ``RuntimeError`` raised by the Slack hook reaches the caller."""
    failing_hook = AsyncMock(side_effect=RuntimeError("slack down"))
    proxy_logging.alerting = ["slack"]
    proxy_logging.slack_alerting_instance = MagicMock(failed_tracking_alert=failing_hook)

    with pytest.raises(RuntimeError):
        await proxy_logging.failed_tracking_alert(error_message="x", failing_model="m")
# ---------------------------------------------------------------------------
# budget_alerts
# ---------------------------------------------------------------------------
def _user_info(alert_emails=None):
    """Build the ``CallInfo`` used by the budget-alert tests.

    Every field is fixed except ``alert_emails``, which is passed through.
    """
    fields = {
        "spend": 0.0,
        "max_budget": 1.0,
        "token": "tok",
        "user_id": "u1",
        "team_id": "t1",
        "team_alias": None,
        "user_email": None,
        "key_alias": None,
        "projected_exceeded_date": None,
        "projected_spend": None,
        "event_group": "user",
        "event": "threshold_crossed",
        "alert_emails": alert_emails,
    }
    return CallInfo(**fields)
@pytest.mark.asyncio
async def test_budget_alerts_no_op_when_alerting_off_and_no_emails(proxy_logging):
    """With alerting off and no alert emails, neither Slack nor email is invoked."""
    proxy_logging.alerting = None
    for channel in ("slack_alerting_instance", "email_logging_instance"):
        setattr(proxy_logging, channel, MagicMock(budget_alerts=AsyncMock()))

    await proxy_logging.budget_alerts(type="user_budget", user_info=_user_info())

    proxy_logging.slack_alerting_instance.budget_alerts.assert_not_called()
    proxy_logging.email_logging_instance.budget_alerts.assert_not_called()
@pytest.mark.asyncio
async def test_budget_alerts_slack_when_slack_alerting(proxy_logging):
    """With Slack alerting on, the Slack hook receives the alert type and the ``CallInfo``."""
    received: Dict[str, Any] = {}

    async def _record(**kwargs):
        received.update(kwargs)

    proxy_logging.alerting = ["slack"]
    proxy_logging.slack_alerting_instance = MagicMock(budget_alerts=_record)
    proxy_logging.email_logging_instance = None

    info = _user_info()
    await proxy_logging.budget_alerts(type="user_budget", user_info=info)

    observed = {
        "type": received["type"],
        "user_info_is_callinfo": isinstance(received["user_info"], CallInfo),
        "user_id": received["user_info"].user_id,
    }
    assert observed == {
        "type": "user_budget",
        "user_info_is_callinfo": True,
        "user_id": "u1",
    }
@pytest.mark.asyncio
async def test_budget_alerts_soft_budget_with_alert_emails_bypasses_global(proxy_logging):
    """A soft-budget alert with ``alert_emails`` goes to email even when alerting is off."""
    proxy_logging.alerting = None
    for channel in ("slack_alerting_instance", "email_logging_instance"):
        setattr(proxy_logging, channel, MagicMock(budget_alerts=AsyncMock()))

    call_info = _user_info(alert_emails=["a@b.c"])
    await proxy_logging.budget_alerts(type="soft_budget", user_info=call_info)

    proxy_logging.email_logging_instance.budget_alerts.assert_called_once()
    proxy_logging.slack_alerting_instance.budget_alerts.assert_not_called()
@pytest.mark.asyncio
async def test_budget_alerts_slack_failure_raises(proxy_logging):
    """A ``ConnectionError`` from the Slack hook reaches the caller."""
    failing_hook = AsyncMock(side_effect=ConnectionError("slack"))
    proxy_logging.alerting = ["slack"]
    proxy_logging.slack_alerting_instance = MagicMock(budget_alerts=failing_hook)
    proxy_logging.email_logging_instance = None

    with pytest.raises(ConnectionError):
        await proxy_logging.budget_alerts(type="user_budget", user_info=_user_info())
# ---------------------------------------------------------------------------
# alerting_handler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_alerting_handler_no_op_when_alerting_is_none(proxy_logging):
    """With ``alerting`` set to ``None``, ``send_alert`` is never called."""
    slack_stub = MagicMock(send_alert=AsyncMock())
    proxy_logging.alerting = None
    proxy_logging.slack_alerting_instance = slack_stub

    await proxy_logging.alerting_handler(
        message="x",
        level="High",
        alert_type=AlertType.db_exceptions,
    )

    proxy_logging.slack_alerting_instance.send_alert.assert_not_called()
@pytest.mark.asyncio
async def test_alerting_handler_sends_to_slack(proxy_logging):
    """With Slack alerting on, ``send_alert`` gets the message, level, alert type and a ``None`` user_info."""
    received: Dict[str, Any] = {}

    async def _record(**kwargs):
        received.update(kwargs)

    proxy_logging.alerting = ["slack"]
    proxy_logging.slack_alerting_instance = MagicMock(send_alert=_record)

    await proxy_logging.alerting_handler(
        message="hi",
        level="High",
        alert_type=AlertType.db_exceptions,
        request_data={"metadata": {}},
    )

    pinned_keys = ("message", "level", "alert_type", "user_info")
    observed = {key: received[key] for key in pinned_keys}
    assert observed == {
        "message": "hi",
        "level": "High",
        "alert_type": AlertType.db_exceptions,
        "user_info": None,
    }
@pytest.mark.asyncio
async def test_alerting_handler_sentry_without_sdk_error_raises(proxy_logging, monkeypatch):
    """Sentry alerting with no SDK instance raises an error mentioning SENTRY_DSN."""
    monkeypatch.setattr(litellm.utils, "sentry_sdk_instance", None)
    proxy_logging.alerting = ["sentry"]

    with pytest.raises(Exception, match="SENTRY_DSN"):
        await proxy_logging.alerting_handler(
            message="x",
            level="Low",
            alert_type=AlertType.db_exceptions,
        )
# ---------------------------------------------------------------------------
# failure_handler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_failure_handler_skips_when_db_exceptions_not_in_alert_types(proxy_logging):
    """When ``alert_types`` lacks db_exceptions, neither alerting nor service logging runs."""
    proxy_logging.alert_types = ["llm_too_slow"]  # type: ignore[list-item]
    proxy_logging.alerting_handler = AsyncMock()
    proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())

    await proxy_logging.failure_handler(
        original_exception=Exception("x"),
        duration=1.0,
        call_type="db_read",
    )

    proxy_logging.alerting_handler.assert_not_called()
    proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_failure_handler_logs_db_error_and_calls_service_logging(proxy_logging, monkeypatch):
    """A DB failure is reported to service logging with service, duration and call type."""
    monkeypatch.setattr(litellm.utils, "capture_exception", None)
    proxy_logging.alert_types = [AlertType.db_exceptions]
    proxy_logging.alerting_handler = AsyncMock()
    proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())

    await proxy_logging.failure_handler(
        original_exception=HTTPException(status_code=500, detail="boom"),
        duration=1.5,
        call_type="db_write",
    )

    hook_kwargs = proxy_logging.service_logging_obj.async_service_failure_hook.call_args.kwargs
    service = hook_kwargs["service"]
    observed = {
        # Use ``.value`` when the service has one; otherwise compare it as-is.
        "service": getattr(service, "value", service),
        "duration": hook_kwargs["duration"],
        "call_type": hook_kwargs["call_type"],
    }
    assert observed == {
        "service": "postgres",
        "duration": 1.5,
        "call_type": "db_write",
    }
@pytest.mark.asyncio
async def test_failure_handler_with_capture_exception_invoked(proxy_logging, monkeypatch):
    """When ``capture_exception`` is set it gets the original error; the other hooks are also called."""
    seen: Dict[str, Any] = {}

    def _capture(error):
        seen["error"] = error

    proxy_logging.alert_types = [AlertType.db_exceptions]
    proxy_logging.alerting_handler = AsyncMock()
    proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())
    monkeypatch.setattr(litellm.utils, "capture_exception", _capture)

    original = RuntimeError("real")
    await proxy_logging.failure_handler(
        original_exception=original,
        duration=1.0,
        call_type="db_read",
    )

    service_hook = proxy_logging.service_logging_obj.async_service_failure_hook
    observed = {
        "captured_is_input": seen["error"] is original,
        "service_failure_called": service_hook.called,
        "alerting_handler_scheduled": proxy_logging.alerting_handler.called,
    }
    assert observed == {
        "captured_is_input": True,
        "service_failure_called": True,
        "alerting_handler_scheduled": True,
    }
@pytest.mark.asyncio
async def test_failure_handler_propagates_service_logging_error_raises(proxy_logging, monkeypatch):
    """A ``RuntimeError`` from the service-failure hook reaches the caller."""
    failing_hook = AsyncMock(side_effect=RuntimeError("svc"))
    monkeypatch.setattr(litellm.utils, "capture_exception", None)
    proxy_logging.alert_types = [AlertType.db_exceptions]
    proxy_logging.alerting_handler = AsyncMock()
    proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=failing_hook)

    with pytest.raises(RuntimeError):
        await proxy_logging.failure_handler(
            original_exception=Exception("x"),
            duration=0.0,
            call_type="db_read",
        )

View File

@ -0,0 +1,338 @@
"""Pin the ``ProxyLogging`` capability-probe family.
Covers ``_callback_capabilities`` (the cached deriver),
``has_post_call_response_headers_callbacks``, ``has_streaming_callbacks``,
``has_streaming_chunk_hook_overrides``, ``needs_iterator_wrap``,
``needs_per_chunk_streaming_hook``, ``has_during_call_guardrails``, and
``get_combined_callback_list``.
"""
from __future__ import annotations
from typing import Any
import pytest
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.utils import ProxyLogging, _CallbackCapabilities
class _PlainLogger(CustomLogger):
    """``CustomLogger`` subclass that overrides no hooks."""
class _OverridesResponseHeaders(CustomLogger):
    """Logger that overrides only ``async_post_call_response_headers_hook``."""

    async def async_post_call_response_headers_hook(self, *args, **kwargs):  # type: ignore[override]
        # Only the presence of the override matters; the hook does nothing.
        return None
class _OverridesIterator(CustomLogger):
    """Logger that overrides only ``async_post_call_streaming_iterator_hook``."""

    async def async_post_call_streaming_iterator_hook(self, *args, **kwargs):  # type: ignore[override]
        # Only the presence of the override matters; the hook does nothing.
        return None
class _OverridesPerChunk(CustomLogger):
    """Logger that overrides only ``async_post_call_streaming_hook``."""

    async def async_post_call_streaming_hook(self, *args, **kwargs):  # type: ignore[override]
        # Only the presence of the override matters; the hook does nothing.
        return None
class _OverridesPreCall(CustomLogger):
    """Logger that overrides only ``async_pre_call_hook``."""

    async def async_pre_call_hook(self, *args, **kwargs):  # type: ignore[override]
        # Only the presence of the override matters; the hook does nothing.
        return None
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
def test_callback_capabilities_with_no_callbacks_returns_defaults(mock_callbacks_disabled):
    """With no callbacks registered, every flag is ``False`` and both tuples are empty."""
    caps = ProxyLogging._callback_capabilities()

    attr_by_label = {
        "headers": "has_post_call_response_headers",
        "iterator": "has_iterator_override",
        "chunk": "has_streaming_chunk_override",
        "guardrail": "has_guardrail",
        "pre_call": "has_pre_call_override",
        "callbacks": "resolved_callbacks",
        "overrides": "iterator_overrides",
    }
    observed = {label: getattr(caps, attr) for label, attr in attr_by_label.items()}

    assert observed == {
        "headers": False,
        "iterator": False,
        "chunk": False,
        "guardrail": False,
        "pre_call": False,
        "callbacks": (),
        "overrides": (),
    }
def test_callback_capabilities_detects_overrides(monkeypatch):
    """One logger per overridable hook sets the matching capability flag."""
    loggers = [
        _OverridesResponseHeaders(),
        _OverridesIterator(),
        _OverridesPerChunk(),
        _OverridesPreCall(),
    ]
    monkeypatch.setattr(litellm, "callbacks", loggers)

    caps = ProxyLogging._callback_capabilities()

    observed = {
        "headers": caps.has_post_call_response_headers,
        "iterator": caps.has_iterator_override,
        "chunk": caps.has_streaming_chunk_override,
        "pre_call": caps.has_pre_call_override,
    }
    assert observed == dict.fromkeys(("headers", "iterator", "chunk", "pre_call"), True)
def test_callback_capabilities_caches_result(monkeypatch):
    """Two calls with the same callback list return the same object."""
    monkeypatch.setattr(litellm, "callbacks", [_OverridesResponseHeaders()])

    initial = ProxyLogging._callback_capabilities()
    repeated = ProxyLogging._callback_capabilities()

    assert initial is repeated
def test_callback_capabilities_invalidates_on_change(monkeypatch):
    """Swapping ``litellm.callbacks`` yields a new capabilities object with updated flags."""
    monkeypatch.setattr(litellm, "callbacks", [_OverridesResponseHeaders()])
    before = ProxyLogging._callback_capabilities()

    monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
    after = ProxyLogging._callback_capabilities()

    assert before is not after
    assert before.has_post_call_response_headers is True
    assert after.has_post_call_response_headers is False
    assert after.has_iterator_override is True
def test_callback_capabilities_callback_resolution_error_raises(monkeypatch):
    """An error raised while resolving a string callback is not swallowed."""

    def _explode(*args, **kwargs):
        raise RuntimeError("bad")

    monkeypatch.setattr(litellm, "callbacks", ["unknown-string"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(RuntimeError):
        ProxyLogging._callback_capabilities()
# ---------------------------------------------------------------------------
# Individual capability probes
# ---------------------------------------------------------------------------
def test_has_post_call_response_headers_callbacks_truth_table(monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, an overriding logger, and a plain logger."""
    probe = ProxyLogging.has_post_call_response_headers_callbacks

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return probe()

    results = {"empty_returns_false": probe()}
    results["override_returns_true"] = _probe_with([_OverridesResponseHeaders()])
    results["plain_logger_false"] = _probe_with([_PlainLogger()])

    assert results == {
        "empty_returns_false": False,
        "override_returns_true": True,
        "plain_logger_false": False,
    }
def test_has_post_call_response_headers_callbacks_error_when_bad_callback(monkeypatch):
    """A resolver failure for a string callback surfaces as ``RuntimeError``."""

    def _explode(*args, **kwargs):
        raise RuntimeError("kaboom")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(RuntimeError):
        ProxyLogging.has_post_call_response_headers_callbacks()
def test_has_streaming_callbacks_truth_table(monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, an iterator override, and a per-chunk override."""
    probe = ProxyLogging.has_streaming_callbacks

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return probe()

    results = {"empty_false": probe()}
    results["iterator_override_true"] = _probe_with([_OverridesIterator()])
    results["per_chunk_override_true"] = _probe_with([_OverridesPerChunk()])

    assert results == {
        "empty_false": False,
        "iterator_override_true": True,
        "per_chunk_override_true": True,
    }
def test_has_streaming_callbacks_error_when_resolution_fails(monkeypatch):
    """A resolver failure for a string callback surfaces as ``ValueError``."""

    def _explode(*args, **kwargs):
        raise ValueError("nope")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(ValueError):
        ProxyLogging.has_streaming_callbacks()
def test_has_streaming_chunk_hook_overrides_truth_table(monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, a per-chunk override, and an iterator-only override."""
    probe = ProxyLogging.has_streaming_chunk_hook_overrides

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return probe()

    results = {"empty_false": probe()}
    results["per_chunk_override_true"] = _probe_with([_OverridesPerChunk()])
    results["only_iterator_false"] = _probe_with([_OverridesIterator()])

    assert results == {
        "empty_false": False,
        "per_chunk_override_true": True,
        "only_iterator_false": False,
    }
def test_has_streaming_chunk_hook_overrides_error_raises(monkeypatch):
    """A resolver failure for a string callback surfaces as ``TypeError``."""

    def _explode(*args, **kwargs):
        raise TypeError("nope")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(TypeError):
        ProxyLogging.has_streaming_chunk_hook_overrides()
def test_needs_iterator_wrap_truth_table(proxy_logging, monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, an iterator override, and a per-chunk-only override."""

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return proxy_logging.needs_iterator_wrap()

    results = {"empty_false": proxy_logging.needs_iterator_wrap()}
    results["with_iter_override_true"] = _probe_with([_OverridesIterator()])
    results["only_per_chunk_false"] = _probe_with([_OverridesPerChunk()])

    assert results == {
        "empty_false": False,
        "with_iter_override_true": True,
        "only_per_chunk_false": False,
    }
def test_needs_iterator_wrap_error_raises(proxy_logging, monkeypatch):
    """A resolver failure for a string callback surfaces as ``RuntimeError``."""

    def _explode(*args, **kwargs):
        raise RuntimeError("oops")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(RuntimeError):
        proxy_logging.needs_iterator_wrap()
def test_needs_per_chunk_streaming_hook_truth_table(proxy_logging, monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, a per-chunk override, and an iterator-only override."""

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return proxy_logging.needs_per_chunk_streaming_hook()

    results = {"empty_false": proxy_logging.needs_per_chunk_streaming_hook()}
    results["per_chunk_override_true"] = _probe_with([_OverridesPerChunk()])
    results["only_iter_override_false"] = _probe_with([_OverridesIterator()])

    assert results == {
        "empty_false": False,
        "per_chunk_override_true": True,
        "only_iter_override_false": False,
    }
def test_needs_per_chunk_streaming_hook_error_raises(proxy_logging, monkeypatch):
    """A resolver failure for a string callback surfaces as ``KeyError``."""

    def _explode(*args, **kwargs):
        raise KeyError("oops")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(KeyError):
        proxy_logging.needs_per_chunk_streaming_hook()
def test_has_during_call_guardrails_truth_table(monkeypatch, mock_callbacks_disabled):
    """Pin the probe for no callbacks, a ``CustomGuardrail`` subclass, and a plain logger."""
    from litellm.integrations.custom_guardrail import CustomGuardrail

    class _G(CustomGuardrail):
        def __init__(self):
            super().__init__(guardrail_name="g", event_hook="pre_call")

    probe = ProxyLogging.has_during_call_guardrails

    def _probe_with(callbacks):
        # Install the callbacks, drop the cached capabilities, then probe again.
        monkeypatch.setattr(litellm, "callbacks", callbacks)
        ProxyLogging._callback_capabilities_cache.clear()
        return probe()

    results = {"empty_false": probe()}
    results["with_guardrail_true"] = _probe_with([_G()])
    results["only_plain_logger_false"] = _probe_with([_PlainLogger()])

    assert results == {
        "empty_false": False,
        "with_guardrail_true": True,
        "only_plain_logger_false": False,
    }
def test_has_during_call_guardrails_resolution_error_raises(monkeypatch):
    """A resolver failure for a string callback surfaces as ``RuntimeError``."""

    def _explode(*args, **kwargs):
        raise RuntimeError("oops")

    monkeypatch.setattr(litellm, "callbacks", ["x"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "get_custom_logger_compatible_class",
        _explode,
    )

    with pytest.raises(RuntimeError):
        ProxyLogging.has_during_call_guardrails()
# ---------------------------------------------------------------------------
# get_combined_callback_list
# ---------------------------------------------------------------------------
def test_get_combined_callback_list_matrix(proxy_logging):
    """Pin merge-with-dedupe, ``None`` dynamic callbacks, and the both-empty case."""
    combine = proxy_logging.get_combined_callback_list

    merged = combine(
        dynamic_success_callbacks=["dyn-1", "shared"],
        global_callbacks=["glob-1", "shared"],
    )
    global_only = combine(
        dynamic_success_callbacks=None,
        global_callbacks=["a", "b", "c"],
    )
    nothing = combine(dynamic_success_callbacks=[], global_callbacks=[])

    observed = {
        # Sorted because only the membership of the merged list is pinned.
        "merge_dedupes_shared": sorted(merged),
        "none_dynamic_returns_global_copy": global_only,
        "empty_both": nothing,
    }
    assert observed == {
        "merge_dedupes_shared": ["dyn-1", "glob-1", "shared"],
        "none_dynamic_returns_global_copy": ["a", "b", "c"],
        "empty_both": [],
    }
def test_get_combined_callback_list_unhashable_dynamic_raises(proxy_logging):
    """A dict among the dynamic callbacks raises ``TypeError``."""
    unhashable_callbacks = [{"unhashable": True}]

    with pytest.raises(TypeError):
        proxy_logging.get_combined_callback_list(
            dynamic_success_callbacks=unhashable_callbacks,
            global_callbacks=[],
        )

View File

@ -0,0 +1,59 @@
"""Pin the ``_CallbackCapabilities`` dataclass shape and defaults."""
from __future__ import annotations
import dataclasses
import pytest
from litellm.proxy.utils import _CallbackCapabilities
def test_callback_capabilities_default_values():
    """A default-constructed ``_CallbackCapabilities`` has all flags ``False`` and empty tuples."""
    expected = {
        "has_post_call_response_headers": False,
        "has_iterator_override": False,
        "has_streaming_chunk_override": False,
        "has_guardrail": False,
        "has_pre_call_override": False,
        "iterator_overrides": (),
        "resolved_callbacks": (),
    }

    caps = _CallbackCapabilities()
    observed = {field_name: getattr(caps, field_name) for field_name in expected}

    assert observed == expected
def test_callback_capabilities_explicit_values_preserved():
    """Values passed to the constructor are readable back unchanged."""
    first, second = object(), object()
    overrides = ((first, "override"), (second, "apply_guardrail"))
    resolved = (first, second)

    caps = _CallbackCapabilities(
        has_post_call_response_headers=True,
        has_iterator_override=True,
        has_streaming_chunk_override=False,
        has_guardrail=True,
        has_pre_call_override=False,
        iterator_overrides=overrides,
        resolved_callbacks=resolved,
    )

    assert caps.has_post_call_response_headers is True
    assert caps.iterator_overrides == ((first, "override"), (second, "apply_guardrail"))
    assert caps.resolved_callbacks == (first, second)
def test_callback_capabilities_is_frozen_error_on_mutation_raises():
    """Assigning to a field raises ``FrozenInstanceError``."""
    frozen = _CallbackCapabilities()

    with pytest.raises(dataclasses.FrozenInstanceError):
        setattr(frozen, "has_post_call_response_headers", True)
def test_callback_capabilities_invalid_field_error_raises():
    """An unknown keyword argument is rejected with ``TypeError``."""
    bogus_kwargs = {"unknown_field": True}

    with pytest.raises(TypeError):
        _CallbackCapabilities(**bogus_kwargs)  # type: ignore[arg-type]

View File

@ -0,0 +1,86 @@
"""Pin ``ProxyLogging.during_call_hook``."""
from __future__ import annotations
from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock
import pytest
import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy.utils import ProxyLogging
from litellm.types.guardrails import GuardrailEventHooks
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
def _make_guardrail(name="g1", should_run=True, response=None):
    """Build a ``MagicMock`` stand-in for a during-call ``CustomGuardrail``.

    ``should_run`` is what ``should_run_guardrail`` returns and ``response``
    is what the async ``async_moderation_hook`` returns.
    """
    guardrail = MagicMock(spec=CustomGuardrail)
    guardrail.__class__ = CustomGuardrail

    # Identity and configuration attributes.
    guardrail.guardrail_name = name
    guardrail.event_hook = GuardrailEventHooks.during_call
    guardrail.use_native_during_call_hook = False

    # Behavior stubs.
    guardrail.should_run_guardrail = MagicMock(return_value=should_run)
    guardrail.async_moderation_hook = AsyncMock(return_value=response)
    return guardrail
@pytest.mark.asyncio
async def test_during_call_hook_no_guardrail_fast_path_returns_data(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """With no callbacks registered, the hook returns the very same ``data`` object."""
    payload = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
    auth = make_user_api_key_auth()

    result = await proxy_logging.during_call_hook(
        data=payload,
        user_api_key_dict=auth,
        call_type="completion",
    )

    assert result is payload
@pytest.mark.asyncio
async def test_during_call_hook_runs_guardrails_in_parallel(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """Both registered guardrails have their moderation hook called; ``data`` is returned as-is."""
    first = _make_guardrail("a")
    second = _make_guardrail("b")
    monkeypatch.setattr(litellm, "callbacks", [first, second])

    payload = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
    auth = make_user_api_key_auth()
    result = await proxy_logging.during_call_hook(
        data=payload,
        user_api_key_dict=auth,
        call_type="completion",
    )

    observed = {
        "out_is_data": result is payload,
        "a_called": first.async_moderation_hook.called,
        "b_called": second.async_moderation_hook.called,
    }
    assert observed == {"out_is_data": True, "a_called": True, "b_called": True}
@pytest.mark.asyncio
async def test_during_call_hook_guardrail_skipped_when_should_not_run(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A guardrail whose ``should_run_guardrail`` returns ``False`` is not invoked."""
    skipped = _make_guardrail("g", should_run=False)
    monkeypatch.setattr(litellm, "callbacks", [skipped])

    auth = make_user_api_key_auth()
    await proxy_logging.during_call_hook(
        data={"model": "m"},
        user_api_key_dict=auth,
        call_type="completion",
    )

    skipped.async_moderation_hook.assert_not_called()
@pytest.mark.asyncio
async def test_during_call_hook_guardrail_error_raises(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` from a guardrail's moderation hook reaches the caller."""
    failing = _make_guardrail("bad")
    failing.async_moderation_hook = AsyncMock(side_effect=RuntimeError("blocked"))
    monkeypatch.setattr(litellm, "callbacks", [failing])

    with pytest.raises(RuntimeError):
        await proxy_logging.during_call_hook(
            data={"model": "m"},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
        )

View File

@ -0,0 +1,559 @@
"""Pin ProxyLogging guardrail pipeline helpers.
Covers ``_should_use_guardrail_load_balancing``, ``_execute_guardrail_hook``,
``_execute_guardrail_with_load_balancing``, ``_process_guardrail_callback``,
``_process_prompt_template``, ``_process_guardrail_metadata``,
``_maybe_execute_pipelines``, ``_handle_pipeline_result``,
``_run_guardrail_task_with_enrichment``.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
import litellm
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
ModifyResponseException,
)
from litellm.proxy.utils import ProxyLogging
from litellm.types.guardrails import GuardrailEventHooks
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` around every test.

    The cache lives on the class, so entries would otherwise carry over from
    one test to the next.
    """

    def _reset():
        # Look the attribute up on each call so a rebound cache is still cleared.
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
# ---------------------------------------------------------------------------
# _should_use_guardrail_load_balancing
# ---------------------------------------------------------------------------
def test_should_use_guardrail_load_balancing_truth_table(proxy_logging):
    """Load balancing is reported only when the router lists the guardrail more than once."""
    router = MagicMock()

    def probe(deployment_names):
        """Query the helper for ``"g1"`` with the router patched in.

        ``None`` patches the proxy server's router to ``None``; otherwise the
        shared router mock is given one ``guardrail_list`` entry per name.
        """
        if deployment_names is None:
            patched_router = None
        else:
            router.guardrail_list = [
                {"guardrail_name": name} for name in deployment_names
            ]
            patched_router = router
        with patch("litellm.proxy.proxy_server.llm_router", patched_router):
            return proxy_logging._should_use_guardrail_load_balancing("g1")

    observed = {}
    observed["multiple_deployments"] = probe(["g1", "g1"])
    observed["single_deployment"] = probe(["g1"])
    observed["no_router"] = probe(None)
    observed["unmatched_name"] = probe(["other", "other"])

    assert observed == {
        "multiple_deployments": True,
        "single_deployment": False,
        "no_router": False,
        "unmatched_name": False,
    }
def test_should_use_guardrail_load_balancing_error_on_bad_guardrail_list(proxy_logging):
    """A ``guardrail_list`` that is a string rather than a list of dicts raises."""
    malformed_router = MagicMock()
    malformed_router.guardrail_list = "not a list"

    with patch("litellm.proxy.proxy_server.llm_router", malformed_router), pytest.raises(
        (TypeError, AttributeError)
    ):
        proxy_logging._should_use_guardrail_load_balancing("g1")
# ---------------------------------------------------------------------------
# _execute_guardrail_hook
# ---------------------------------------------------------------------------
def _make_guardrail():
    """Build a ``CustomGuardrail``-shaped mock with a distinct payload per hook.

    Each async hook returns its own three-key dict so a test can tell which
    hook produced the value it got back.
    """
    guardrail = MagicMock(spec=CustomGuardrail)
    guardrail.__class__ = CustomGuardrail
    guardrail.guardrail_name = "g"
    guardrail.event_hook = GuardrailEventHooks.pre_call
    guardrail.use_native_during_call_hook = False

    hook_payloads = {
        "async_pre_call_hook": {"a": 1, "b": 2, "c": 3},
        "async_moderation_hook": {"x": 1, "y": 2, "z": 3},
        "async_post_call_success_hook": {"p": 1, "q": 2, "r": 3},
    }
    for hook_name, payload in hook_payloads.items():
        setattr(guardrail, hook_name, AsyncMock(return_value=payload))
    return guardrail
@pytest.mark.asyncio
async def test_execute_guardrail_hook_pre_call(proxy_logging, make_user_api_key_auth):
    """``hook_type="pre_call"`` returns the payload of ``async_pre_call_hook``."""
    guardrail = _make_guardrail()

    result = await proxy_logging._execute_guardrail_hook(
        callback=guardrail,
        hook_type="pre_call",
        data={"model": "m"},
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
    )

    expected_pre_call_payload = {"a": 1, "b": 2, "c": 3}
    assert result == expected_pre_call_payload
@pytest.mark.asyncio
async def test_execute_guardrail_hook_during_call(proxy_logging, make_user_api_key_auth):
    """``hook_type="during_call"`` returns the payload of ``async_moderation_hook``."""
    guardrail = _make_guardrail()

    result = await proxy_logging._execute_guardrail_hook(
        callback=guardrail,
        hook_type="during_call",
        data={"model": "m"},
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
    )

    expected_moderation_payload = {"x": 1, "y": 2, "z": 3}
    assert result == expected_moderation_payload
@pytest.mark.asyncio
async def test_execute_guardrail_hook_post_call(proxy_logging, make_user_api_key_auth):
    """``hook_type="post_call"`` returns the payload of ``async_post_call_success_hook``."""
    guardrail = _make_guardrail()

    result = await proxy_logging._execute_guardrail_hook(
        callback=guardrail,
        hook_type="post_call",
        data={"model": "m"},
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
        response={"original": True},
    )

    expected_post_call_payload = {"p": 1, "q": 2, "r": 3}
    assert result == expected_post_call_payload
@pytest.mark.asyncio
async def test_execute_guardrail_hook_unknown_hook_type_raises(
    proxy_logging, make_user_api_key_auth
):
    """An unrecognised ``hook_type`` raises ``ValueError`` mentioning "Unknown hook_type"."""
    guardrail = _make_guardrail()

    with pytest.raises(ValueError, match="Unknown hook_type"):
        await proxy_logging._execute_guardrail_hook(
            callback=guardrail,
            hook_type="weird",  # type: ignore[arg-type]
            data={},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
        )
# ---------------------------------------------------------------------------
# _execute_guardrail_with_load_balancing
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_execute_guardrail_with_load_balancing_routes_through_router(
    proxy_logging, make_user_api_key_auth
):
    """The callback selected by the router's ``get_available_guardrail`` is the one executed."""
    selected_guardrail = _make_guardrail()
    router = MagicMock()
    router.get_available_guardrail = MagicMock(
        return_value={"callback": selected_guardrail}
    )

    with patch("litellm.proxy.proxy_server.llm_router", router):
        result = await proxy_logging._execute_guardrail_with_load_balancing(
            guardrail_name="g",
            hook_type="pre_call",
            data={"model": "m"},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
        )

    # The pre-call payload proves the router-selected callback was executed.
    assert result == {"a": 1, "b": 2, "c": 3}
@pytest.mark.asyncio
async def test_execute_guardrail_with_load_balancing_router_none_raises(
    proxy_logging, make_user_api_key_auth
):
    """With the proxy router patched to ``None`` the call raises "Router not initialized"."""
    with patch("litellm.proxy.proxy_server.llm_router", None), pytest.raises(
        ValueError, match="Router not initialized"
    ):
        await proxy_logging._execute_guardrail_with_load_balancing(
            guardrail_name="g",
            hook_type="pre_call",
            data={},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
        )
@pytest.mark.asyncio
async def test_execute_guardrail_with_load_balancing_no_callback_raises(
    proxy_logging, make_user_api_key_auth
):
    """A router deployment whose ``callback`` is ``None`` raises "No callback found"."""
    router = MagicMock()
    router.get_available_guardrail = MagicMock(return_value={"callback": None})

    with patch("litellm.proxy.proxy_server.llm_router", router), pytest.raises(
        ValueError, match="No callback found"
    ):
        await proxy_logging._execute_guardrail_with_load_balancing(
            guardrail_name="g",
            hook_type="pre_call",
            data={},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
        )
# ---------------------------------------------------------------------------
# _process_guardrail_callback
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_guardrail_callback_skipped_when_should_run_false(
    proxy_logging, make_user_api_key_auth
):
    """When ``should_run_guardrail`` returns False the helper returns ``None``."""
    guardrail = _make_guardrail()
    guardrail.should_run_guardrail = MagicMock(return_value=False)

    result = await proxy_logging._process_guardrail_callback(
        callback=guardrail,
        data={"model": "m"},
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
        event_type=GuardrailEventHooks.pre_call,
    )

    assert result is None
@pytest.mark.asyncio
async def test_process_guardrail_callback_returns_data_on_success(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A guardrail that runs returns the payload of its pre-call hook.

    Load balancing is forced off so the callback passed in is the one executed.
    """
    cb = _make_guardrail()
    cb.should_run_guardrail = MagicMock(return_value=True)
    # Patch through ``monkeypatch`` so the override is undone at teardown
    # instead of staying on the fixture object.
    monkeypatch.setattr(
        proxy_logging,
        "_should_use_guardrail_load_balancing",
        MagicMock(return_value=False),
    )

    out = await proxy_logging._process_guardrail_callback(
        callback=cb,
        data={"model": "m", "messages": [{"role": "user"}], "temperature": 0.1},
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
        event_type=GuardrailEventHooks.pre_call,
    )

    assert out == {"a": 1, "b": 2, "c": 3}
@pytest.mark.asyncio
async def test_process_guardrail_callback_enriches_and_reraises_http_exception(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """An ``HTTPException`` from the guardrail is re-raised with guardrail context.

    The exception's ``detail`` dict is checked after the call for a
    ``guardrail_name`` entry, so the enrichment must mutate that same dict.
    """
    cb = _make_guardrail()
    cb.should_run_guardrail = MagicMock(return_value=True)
    detail = {"error": "blocked"}
    cb.async_pre_call_hook = AsyncMock(
        side_effect=HTTPException(status_code=400, detail=detail)
    )
    cb.event_hook = "pre_call"
    # Patch through ``monkeypatch`` so the override is undone at teardown
    # instead of staying on the fixture object.
    monkeypatch.setattr(
        proxy_logging,
        "_should_use_guardrail_load_balancing",
        MagicMock(return_value=False),
    )

    with pytest.raises(HTTPException):
        await proxy_logging._process_guardrail_callback(
            callback=cb,
            data={"model": "m"},
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
            event_type=GuardrailEventHooks.pre_call,
        )

    assert detail["guardrail_name"] == "g"
# ---------------------------------------------------------------------------
# _process_guardrail_metadata
# ---------------------------------------------------------------------------
def test_process_guardrail_metadata_calls_header_helper(proxy_logging, monkeypatch):
    """Each guardrail named in ``metadata["guardrails"]`` is passed to the header helper.

    The helper must be called once per name, in order, with the original
    request dict as its first argument.
    """
    calls: List[Dict[str, Any]] = []

    def fake_add(request_data, guardrail_name):
        calls.append({"data": request_data, "name": guardrail_name})

    from litellm.proxy.common_utils import callback_utils

    monkeypatch.setattr(
        callback_utils, "add_guardrail_to_applied_guardrails_header", fake_add
    )
    data = {"metadata": {"guardrails": ["g1", "g2"]}}

    proxy_logging._process_guardrail_metadata(data)

    # Collect names as a list rather than indexing calls[0]/calls[1]: if the
    # helper is called fewer than twice, the failure is a readable dict diff
    # instead of an IndexError raised while building the snapshot.
    snapshot = {
        "call_count": len(calls),
        "names": [c["name"] for c in calls],
        "data_passed_is_input": all(c["data"] is data for c in calls),
    }
    assert snapshot == {
        "call_count": 2,
        "names": ["g1", "g2"],
        "data_passed_is_input": True,
    }
def test_process_guardrail_metadata_skips_already_applied(proxy_logging, monkeypatch):
    """Guardrails already in ``applied_guardrails`` are not sent to the header helper again."""
    from litellm.proxy.common_utils import callback_utils

    seen_names: List[str] = []

    def record_name(request_data, guardrail_name):
        seen_names.append(guardrail_name)

    monkeypatch.setattr(
        callback_utils, "add_guardrail_to_applied_guardrails_header", record_name
    )
    request_body = {
        "metadata": {"guardrails": ["g1", "g2"], "applied_guardrails": ["g1"]}
    }

    proxy_logging._process_guardrail_metadata(request_body)

    assert seen_names == ["g2"]
def test_process_guardrail_metadata_no_metadata_is_noop(proxy_logging, monkeypatch):
    """A request dict without ``metadata`` never reaches the header helper."""
    from litellm.proxy.common_utils import callback_utils

    helper = MagicMock(side_effect=AssertionError("should not be called"))
    monkeypatch.setattr(
        callback_utils,
        "add_guardrail_to_applied_guardrails_header",
        helper,
    )

    proxy_logging._process_guardrail_metadata({})

    # Assert explicitly as well: the side_effect alone would go unnoticed if
    # the code under test ever caught the exception it raises.
    helper.assert_not_called()
def test_process_guardrail_metadata_invalid_data_raises(proxy_logging):
    """Passing ``None`` instead of a request dict raises ``AttributeError``."""
    not_a_dict = None
    with pytest.raises(AttributeError):
        proxy_logging._process_guardrail_metadata(not_a_dict)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _maybe_execute_pipelines
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maybe_execute_pipelines_no_pipelines_returns_data(
    proxy_logging, make_user_api_key_auth
):
    """Without pipelines in the request metadata the data comes back unchanged."""
    request_body = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}

    result = await proxy_logging._maybe_execute_pipelines(
        data=request_body,
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
        event_hook="pre_call",
    )

    # Compared against a fresh literal so in-place mutation would also be caught.
    assert result == {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
@pytest.mark.asyncio
async def test_maybe_execute_pipelines_skips_pipelines_with_other_mode(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A pipeline whose ``mode`` differs from the current event hook is not executed."""
    post_call_pipeline = MagicMock()
    post_call_pipeline.mode = "post_call"  # the hook under test is pre_call
    request_body = {
        "metadata": {"_guardrail_pipelines": [("p1", post_call_pipeline)]},
        "model": "m",
        "messages": [],
    }
    execute_steps_spy = MagicMock()
    monkeypatch.setattr(
        "litellm.proxy.policy_engine.pipeline_executor.PipelineExecutor.execute_steps",
        execute_steps_spy,
    )

    result = await proxy_logging._maybe_execute_pipelines(
        data=request_body,
        user_api_key_dict=make_user_api_key_auth(),
        call_type="completion",
        event_hook="pre_call",
    )

    execute_steps_spy.assert_not_called()
    assert result is request_body
@pytest.mark.asyncio
async def test_maybe_execute_pipelines_blocks_on_block_terminal_action_raises(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A pipeline result with terminal action ``"block"`` surfaces as ``HTTPException``."""
    blocking_pipeline = MagicMock()
    blocking_pipeline.mode = "pre_call"
    blocking_pipeline.steps = []

    block_result = MagicMock()
    block_result.terminal_action = "block"
    block_result.step_results = []

    async def stub_execute_steps(**kwargs):
        # Stands in for the executor: always reports a blocking result.
        return block_result

    monkeypatch.setattr(
        "litellm.proxy.policy_engine.pipeline_executor.PipelineExecutor.execute_steps",
        stub_execute_steps,
    )
    request_body = {
        "metadata": {"_guardrail_pipelines": [("policy-1", blocking_pipeline)]},
        "messages": [],
        "model": "m",
    }

    with pytest.raises(HTTPException):
        await proxy_logging._maybe_execute_pipelines(
            data=request_body,
            user_api_key_dict=make_user_api_key_auth(),
            call_type="completion",
            event_hook="pre_call",
        )
# ---------------------------------------------------------------------------
# _handle_pipeline_result
# ---------------------------------------------------------------------------
def test_handle_pipeline_result_allow_with_modifications():
    """An ``"allow"`` result merges ``modified_data`` into the request data."""
    allow_result = MagicMock(terminal_action="allow", modified_data={"b": 2, "c": 3})
    original = {"a": 1}

    merged = ProxyLogging._handle_pipeline_result(
        result=allow_result, data=original, policy_name="p"
    )

    assert merged == {"a": 1, "b": 2, "c": 3}
def test_handle_pipeline_result_block_raises_http_exception():
    """A ``"block"`` result raises ``HTTPException`` with structured pipeline detail."""
    block_result = MagicMock(terminal_action="block", step_results=[])

    with pytest.raises(HTTPException) as raised:
        ProxyLogging._handle_pipeline_result(
            result=block_result, data={"model": "m"}, policy_name="p"
        )

    detail = raised.value.detail
    error_body = detail["error"]
    observed = {
        "is_dict": isinstance(detail, dict),
        "error_type": error_body["type"],
        "policy": error_body["pipeline_context"]["policy"],
    }
    expected = {
        "is_dict": True,
        "error_type": "guardrail_pipeline_error",
        "policy": "p",
    }
    assert observed == expected
def test_handle_pipeline_result_modify_response_raises_modify_exception():
    """A ``"modify_response"`` result raises ``ModifyResponseException``."""
    modify_result = MagicMock(
        terminal_action="modify_response",
        modify_response_message="filtered",
    )

    with pytest.raises(ModifyResponseException):
        ProxyLogging._handle_pipeline_result(
            result=modify_result, data={"model": "m"}, policy_name="p"
        )
def test_handle_pipeline_result_unknown_action_returns_data():
    """An unrecognised terminal action returns the input dict itself."""
    unknown_result = MagicMock(terminal_action="something_else")
    payload = {"a": 1, "b": 2, "c": 3}

    returned = ProxyLogging._handle_pipeline_result(
        result=unknown_result, data=payload, policy_name="p"
    )

    assert returned is payload
# ---------------------------------------------------------------------------
# _run_guardrail_task_with_enrichment
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_guardrail_task_with_enrichment_passes_result():
    """The wrapped coroutine's return value is returned unchanged on success."""

    async def succeeding_task():
        return {"a": 1, "b": 2, "c": 3}

    guardrail = MagicMock(guardrail_name="g")
    result = await ProxyLogging._run_guardrail_task_with_enrichment(
        callback=guardrail, coro=succeeding_task()
    )

    assert result == {"a": 1, "b": 2, "c": 3}
@pytest.mark.asyncio
async def test_run_guardrail_task_with_enrichment_enriches_http_exception_raises():
    """An ``HTTPException`` from the task is re-raised with the guardrail name added.

    The shared ``detail`` dict is inspected after the call, so the enrichment
    must write into that same dict.
    """
    shared_detail = {"error": "blocked"}

    async def failing_task():
        raise HTTPException(status_code=400, detail=shared_detail)

    guardrail = MagicMock()
    guardrail.guardrail_name = "presidio"
    guardrail.event_hook = "pre_call"

    with pytest.raises(HTTPException):
        await ProxyLogging._run_guardrail_task_with_enrichment(
            callback=guardrail, coro=failing_task()
        )

    assert shared_detail["guardrail_name"] == "presidio"
# ---------------------------------------------------------------------------
# _process_prompt_template
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_prompt_template_no_op_when_no_prompt_spec(proxy_logging, monkeypatch):
    """When both registry lookups return ``None`` the request data is left untouched."""
    from litellm.proxy.prompts import prompt_registry

    registry = prompt_registry.IN_MEMORY_PROMPT_REGISTRY
    for lookup_name in ("get_prompt_callback_by_id", "get_prompt_by_id"):
        monkeypatch.setattr(registry, lookup_name, lambda *a, **kw: None)

    request_body: Dict[str, Any] = {
        "messages": [{"role": "user"}],
        "model": "m",
        "temperature": 0.1,
    }

    await proxy_logging._process_prompt_template(
        data=request_body,
        litellm_logging_obj=MagicMock(),
        prompt_id="some-id",
        prompt_version=1,
        call_type="completion",
    )

    assert request_body == {
        "messages": [{"role": "user"}],
        "model": "m",
        "temperature": 0.1,
    }
@pytest.mark.asyncio
async def test_process_prompt_template_applies_when_spec_resolves(proxy_logging, monkeypatch):
    """A resolved prompt overwrites model and messages and merges extra params into ``data``."""
    from litellm.proxy.prompts import prompt_registry

    registry = prompt_registry.IN_MEMORY_PROMPT_REGISTRY
    resolved_logger = MagicMock()
    resolved_spec = MagicMock()
    resolved_spec.litellm_params = MagicMock(prompt_id="resolved-id")
    monkeypatch.setattr(
        registry, "get_prompt_callback_by_id", lambda *a, **kw: resolved_logger
    )
    monkeypatch.setattr(registry, "get_prompt_by_id", lambda *a, **kw: resolved_spec)

    # The logging object returns the (model, messages, optional params) triple.
    rendered = (
        "model-out",
        [{"role": "user", "content": "rendered"}],
        {"temperature": 0.5, "top_p": 1},
    )
    logging_obj = MagicMock()
    logging_obj.async_get_chat_completion_prompt = AsyncMock(return_value=rendered)

    request_body: Dict[str, Any] = {
        "messages": [{"role": "user", "content": "orig"}],
        "model": "m",
        "prompt_id": "x",
    }

    await proxy_logging._process_prompt_template(
        data=request_body,
        litellm_logging_obj=logging_obj,
        prompt_id="x",
        prompt_version=None,
        call_type="completion",
    )

    observed = {
        "model": request_body["model"],
        "messages": request_body["messages"],
        "temperature": request_body["temperature"],
        "top_p": request_body["top_p"],
    }
    assert observed == {
        "model": "model-out",
        "messages": [{"role": "user", "content": "rendered"}],
        "temperature": 0.5,
        "top_p": 1,
    }
@pytest.mark.asyncio
async def test_process_prompt_template_async_get_prompt_error_raises(proxy_logging, monkeypatch):
    """An error from ``async_get_chat_completion_prompt`` propagates to the caller."""
    from litellm.proxy.prompts import prompt_registry

    registry = prompt_registry.IN_MEMORY_PROMPT_REGISTRY
    resolved_logger = MagicMock()
    resolved_spec = MagicMock()
    resolved_spec.litellm_params = MagicMock(prompt_id="x")
    monkeypatch.setattr(
        registry, "get_prompt_callback_by_id", lambda *a, **kw: resolved_logger
    )
    monkeypatch.setattr(registry, "get_prompt_by_id", lambda *a, **kw: resolved_spec)

    failing_logging_obj = MagicMock()
    failing_logging_obj.async_get_chat_completion_prompt = AsyncMock(
        side_effect=RuntimeError("bad prompt")
    )

    with pytest.raises(RuntimeError):
        await proxy_logging._process_prompt_template(
            data={"messages": [], "model": "m", "prompt_id": "x"},
            litellm_logging_obj=failing_logging_obj,
            prompt_id="x",
            prompt_version=None,
            call_type="completion",
        )

View File

@ -0,0 +1,186 @@
"""Pin behavior of ``InternalUsageCache``: a thin adapter over ``DualCache``.
Each method should pass-through to the underlying ``DualCache`` with
exactly the same arguments, mapping ``litellm_parent_otel_span`` to the
``DualCache`` kw it expects.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from litellm.caching.caching import DualCache
from litellm.proxy.utils import InternalUsageCache
def _kwargs_snapshot(call):
return dict(call.kwargs)
def test_internal_usage_cache_init_stores_dual_cache():
    """The constructor keeps the supplied ``DualCache`` as ``dual_cache``."""
    wrapped = DualCache(default_in_memory_ttl=1)
    adapter = InternalUsageCache(dual_cache=wrapped)

    expected = {
        "is_internal_usage_cache": True,
        "dual_cache_is_inner": True,
        "ttl_is_one": True,
    }
    observed = {
        "is_internal_usage_cache": isinstance(adapter, InternalUsageCache),
        "dual_cache_is_inner": adapter.dual_cache is wrapped,
        "ttl_is_one": wrapped.default_in_memory_ttl == 1,
    }
    assert observed == expected
def test_internal_usage_cache_init_error_requires_dual_cache():
    """Constructing the adapter without ``dual_cache`` raises ``TypeError``."""
    build_without_args = InternalUsageCache
    with pytest.raises(TypeError):
        build_without_args()  # type: ignore[call-arg]
@pytest.mark.asyncio
async def test_async_get_cache_forwards_args():
    """``async_get_cache`` forwards its kwargs, renaming the span to ``parent_otel_span``."""
    cached_value = {"hit": True, "value": 42, "source": "redis"}
    backend = MagicMock()
    backend.async_get_cache = AsyncMock(return_value=cached_value)
    adapter = InternalUsageCache(dual_cache=backend)

    returned = await adapter.async_get_cache(
        key="k", litellm_parent_otel_span="span", local_only=True, extra="x"
    )

    forwarded_kwargs = _kwargs_snapshot(backend.async_get_cache.call_args)
    assert forwarded_kwargs == {
        "key": "k",
        "local_only": True,
        "parent_otel_span": "span",
        "extra": "x",
    }
    assert returned == {"hit": True, "value": 42, "source": "redis"}
@pytest.mark.asyncio
async def test_async_get_cache_propagates_underlying_error_raises():
    """An error from the wrapped cache's ``async_get_cache`` is not swallowed."""
    backend = MagicMock()
    backend.async_get_cache = AsyncMock(side_effect=RuntimeError("redis down"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(RuntimeError, match="redis down"):
        await adapter.async_get_cache(key="k", litellm_parent_otel_span=None)
@pytest.mark.asyncio
async def test_async_set_cache_forwards_args():
    """``async_set_cache`` forwards key, value, span, ``local_only`` and extras as kwargs."""
    backend = MagicMock()
    backend.async_set_cache = AsyncMock()
    adapter = InternalUsageCache(dual_cache=backend)

    await adapter.async_set_cache(
        key="k", value="v", litellm_parent_otel_span="span", local_only=False, ttl=60
    )

    expected_kwargs = {
        "key": "k",
        "value": "v",
        "local_only": False,
        "litellm_parent_otel_span": "span",
        "ttl": 60,
    }
    assert _kwargs_snapshot(backend.async_set_cache.call_args) == expected_kwargs
@pytest.mark.asyncio
async def test_async_set_cache_propagates_error_raises():
    """An error from the wrapped cache's ``async_set_cache`` is not swallowed."""
    backend = MagicMock()
    backend.async_set_cache = AsyncMock(side_effect=ValueError("bad value"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(ValueError, match="bad value"):
        await adapter.async_set_cache(key="k", value="v", litellm_parent_otel_span=None)
@pytest.mark.asyncio
async def test_async_batch_set_cache_forwards_pipeline():
    """``async_batch_set_cache`` delegates to the wrapped ``async_set_cache_pipeline``."""
    backend = MagicMock()
    backend.async_set_cache_pipeline = AsyncMock()
    adapter = InternalUsageCache(dual_cache=backend)
    entries = [("a", 1), ("b", 2)]

    await adapter.async_batch_set_cache(
        cache_list=entries, litellm_parent_otel_span=None, local_only=True, ttl=10
    )

    expected_kwargs = {
        "cache_list": entries,
        "local_only": True,
        "litellm_parent_otel_span": None,
        "ttl": 10,
    }
    assert _kwargs_snapshot(backend.async_set_cache_pipeline.call_args) == expected_kwargs
@pytest.mark.asyncio
async def test_async_batch_set_cache_propagates_error_raises():
    """An error from the wrapped ``async_set_cache_pipeline`` is not swallowed."""
    backend = MagicMock()
    backend.async_set_cache_pipeline = AsyncMock(side_effect=ConnectionError("network"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(ConnectionError):
        await adapter.async_batch_set_cache(cache_list=[], litellm_parent_otel_span=None)
@pytest.mark.asyncio
async def test_async_batch_get_cache_forwards_args():
    """``async_batch_get_cache`` forwards keys, span and ``local_only`` unchanged."""
    backend = MagicMock()
    backend.async_batch_get_cache = AsyncMock(return_value=[1, 2, 3])
    adapter = InternalUsageCache(dual_cache=backend)

    returned = await adapter.async_batch_get_cache(
        keys=["a", "b", "c"], parent_otel_span="span", local_only=False
    )

    forwarded_kwargs = _kwargs_snapshot(backend.async_batch_get_cache.call_args)
    assert forwarded_kwargs == {
        "keys": ["a", "b", "c"],
        "parent_otel_span": "span",
        "local_only": False,
    }
    assert returned == [1, 2, 3]
@pytest.mark.asyncio
async def test_async_batch_get_cache_invalid_input_raises():
    """A ``TypeError`` raised by the wrapped cache for bad keys reaches the caller."""
    backend = MagicMock()
    backend.async_batch_get_cache = AsyncMock(side_effect=TypeError("not a list"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(TypeError):
        await adapter.async_batch_get_cache(keys=None)  # type: ignore[arg-type]
@pytest.mark.asyncio
async def test_async_increment_cache_forwards_args():
    """``async_increment_cache`` forwards with ``local_only=False`` and a renamed span kwarg."""
    backend = MagicMock()
    backend.async_increment_cache = AsyncMock(return_value=5.0)
    adapter = InternalUsageCache(dual_cache=backend)

    new_total = await adapter.async_increment_cache(
        key="counter", value=1.5, litellm_parent_otel_span="span"
    )

    forwarded_kwargs = _kwargs_snapshot(backend.async_increment_cache.call_args)
    assert forwarded_kwargs == {
        "key": "counter",
        "value": 1.5,
        "local_only": False,
        "parent_otel_span": "span",
    }
    assert new_total == 5.0
@pytest.mark.asyncio
async def test_async_increment_cache_propagates_error_raises():
    """An error from the wrapped ``async_increment_cache`` is not swallowed."""
    backend = MagicMock()
    backend.async_increment_cache = AsyncMock(side_effect=OverflowError())
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(OverflowError):
        await adapter.async_increment_cache(
            key="x", value=1.0, litellm_parent_otel_span=None
        )
def test_set_cache_forwards_args():
    """Synchronous ``set_cache`` forwards its kwargs to the wrapped cache unchanged."""
    backend = MagicMock()
    adapter = InternalUsageCache(dual_cache=backend)

    adapter.set_cache(key="k", value="v", local_only=True, ttl=30)

    expected_kwargs = {"key": "k", "value": "v", "local_only": True, "ttl": 30}
    assert _kwargs_snapshot(backend.set_cache.call_args) == expected_kwargs
def test_set_cache_propagates_error_raises():
    """An error from the wrapped cache's ``set_cache`` is not swallowed."""
    backend = MagicMock()
    backend.set_cache = MagicMock(side_effect=RuntimeError("no redis"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(RuntimeError):
        adapter.set_cache(key="k", value="v")
def test_get_cache_forwards_args_and_returns_inner_result():
    """Synchronous ``get_cache`` forwards its kwargs and returns the wrapped cache's result."""
    backend = MagicMock()
    backend.get_cache = MagicMock(return_value={"k": "v", "ttl": 60, "source": "mem"})
    adapter = InternalUsageCache(dual_cache=backend)

    returned = adapter.get_cache(key="k", local_only=False)

    forwarded_kwargs = _kwargs_snapshot(backend.get_cache.call_args)
    assert forwarded_kwargs == {"key": "k", "local_only": False}
    assert returned == {"k": "v", "ttl": 60, "source": "mem"}
def test_get_cache_propagates_error_raises():
    """An error from the wrapped cache's ``get_cache`` is not swallowed."""
    backend = MagicMock()
    backend.get_cache = MagicMock(side_effect=KeyError("missing"))
    adapter = InternalUsageCache(dual_cache=backend)

    with pytest.raises(KeyError):
        adapter.get_cache(key="missing")

View File

@ -0,0 +1,403 @@
"""Pin ProxyLogging lifecycle: ``__init__``, ``startup_event``,
``update_values``, ``_add_proxy_hooks``, ``get_proxy_hook``, and
``_init_litellm_callbacks``.
Also covers ``update_request_status`` and ``_convert_user_api_key_auth_to_dict``
because they depend directly on the lifecycle state.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.utils import (
InternalUsageCache,
ProxyLogging,
)
# ---------------------------------------------------------------------------
# __init__
# ---------------------------------------------------------------------------
def test_proxy_logging_init_sets_default_state(mock_callbacks_disabled):
    """A freshly constructed ``ProxyLogging`` exposes the pinned default attributes."""
    instance = ProxyLogging(user_api_key_cache=UserApiKeyCache())

    expected_defaults = {
        "internal_usage_cache_type": "InternalUsageCache",
        "alerting_is_none": True,
        "alerting_threshold": 300,
        "premium_user": False,
        "proxy_hook_mapping": {},
        "daily_report_started": False,
        "hanging_requests_check_started": False,
    }
    observed = {
        "internal_usage_cache_type": type(instance.internal_usage_cache).__name__,
        "alerting_is_none": instance.alerting is None,
        "alerting_threshold": instance.alerting_threshold,
        "premium_user": instance.premium_user,
        "proxy_hook_mapping": instance.proxy_hook_mapping,
        "daily_report_started": instance.daily_report_started,
        "hanging_requests_check_started": instance.hanging_requests_check_started,
    }
    assert observed == expected_defaults
def test_proxy_logging_init_premium_user_flag(mock_callbacks_disabled):
    """``premium_user=True`` passed to the constructor is stored on the instance."""
    key_cache = UserApiKeyCache()
    instance = ProxyLogging(user_api_key_cache=key_cache, premium_user=True)
    assert instance.premium_user is True
def test_proxy_logging_init_missing_cache_raises():
    """Omitting the required ``user_api_key_cache`` argument raises ``TypeError``."""
    build_without_args = ProxyLogging
    with pytest.raises(TypeError):
        build_without_args()  # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# update_values
# ---------------------------------------------------------------------------
def test_update_values_stores_alerting_state(proxy_logging):
    """``update_values`` stores each alerting setting on the instance."""
    proxy_logging.slack_alerting_instance = MagicMock()
    new_settings = {
        "alerting": ["slack"],
        "alerting_threshold": 42.0,
        "alert_types": ["llm_too_slow"],
        "alert_to_webhook_url": {"key": "value"},
    }

    proxy_logging.update_values(**new_settings)

    observed = {
        "alerting": proxy_logging.alerting,
        "threshold": proxy_logging.alerting_threshold,
        "alert_types": proxy_logging.alert_types,
        "webhook_url": proxy_logging.alert_to_webhook_url,
    }
    assert observed == {
        "alerting": ["slack"],
        "threshold": 42.0,
        "alert_types": ["llm_too_slow"],
        "webhook_url": {"key": "value"},
    }
def test_update_values_with_only_redis_cache_does_not_touch_slack(proxy_logging):
    """Passing only ``redis_cache`` rewires the dual cache and leaves Slack alerting alone."""
    slack_mock = MagicMock()
    proxy_logging.slack_alerting_instance = slack_mock
    redis_stub = MagicMock()

    proxy_logging.update_values(redis_cache=redis_stub)

    slack_mock.update_values.assert_not_called()
    assert proxy_logging.internal_usage_cache.dual_cache.redis_cache is redis_stub
def test_update_values_with_no_args_is_noop(proxy_logging):
    """Calling ``update_values()`` with no arguments does not touch Slack alerting."""
    slack_mock = MagicMock()
    proxy_logging.slack_alerting_instance = slack_mock

    proxy_logging.update_values()

    slack_mock.update_values.assert_not_called()
def test_update_values_invalid_type_for_alerting_raises(proxy_logging):
    """A ``TypeError`` raised by the Slack instance's ``update_values`` reaches the caller."""
    rejecting_update = MagicMock(side_effect=TypeError("bad type"))
    proxy_logging.slack_alerting_instance = MagicMock(update_values=rejecting_update)

    with pytest.raises(TypeError):
        proxy_logging.update_values(alerting={"not": "a list"})  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# startup_event
# ---------------------------------------------------------------------------
def test_startup_event_initializes_slack_and_callbacks(proxy_logging):
    """``startup_event`` calls ``update_values``, callback init and the Slack update."""
    slack_mock = MagicMock()
    slack_mock.alert_types = []
    proxy_logging.slack_alerting_instance = slack_mock
    # Replace the collaborators with mocks so only the wiring is observed.
    proxy_logging._init_litellm_callbacks = MagicMock()
    proxy_logging.update_values = MagicMock()

    proxy_logging.startup_event(llm_router=None, redis_usage_cache=None)

    observed = {
        "update_called": proxy_logging.update_values.called,
        "init_called": proxy_logging._init_litellm_callbacks.called,
        "slack_update_called": proxy_logging.slack_alerting_instance.update_values.called,
    }
    expected = {
        "update_called": True,
        "init_called": True,
        "slack_update_called": True,
    }
    assert observed == expected
def test_startup_event_propagates_init_callbacks_failure_raises(proxy_logging):
    """A failure inside ``_init_litellm_callbacks`` aborts ``startup_event``."""
    slack_mock = MagicMock()
    slack_mock.alert_types = []
    proxy_logging.slack_alerting_instance = slack_mock
    proxy_logging._init_litellm_callbacks = MagicMock(side_effect=RuntimeError("boom"))

    with pytest.raises(RuntimeError, match="boom"):
        proxy_logging.startup_event(llm_router=None, redis_usage_cache=None)
# ---------------------------------------------------------------------------
# _add_proxy_hooks
# ---------------------------------------------------------------------------
def test_add_proxy_hooks_registers_callbacks(proxy_logging, monkeypatch):
    """Every name in ``PROXY_HOOKS`` is instantiated, mapped and registered.

    ``PROXY_HOOKS`` and the resolver are both patched so the test controls
    exactly which hooks exist. The resulting instances must land in
    ``proxy_logging.proxy_hook_mapping`` keyed by hook name and be passed to
    ``add_litellm_callback``.
    """
    from litellm.proxy import utils as utils_mod

    hook_keys = ["cache_control_check", "max_budget_limiter"]
    registered: List[Any] = []

    def fake_get_proxy_hook(hook_name):
        # Returns a stub class; each instance remembers the hook it was built for.
        class _Stub:
            __name__ = hook_name

            def __init__(self, **kwargs):
                self.hook_name = hook_name

        return _Stub

    def record_registration(cb):
        registered.append(cb)

    monkeypatch.setattr(utils_mod, "PROXY_HOOKS", hook_keys)
    monkeypatch.setattr(utils_mod, "get_proxy_hook", fake_get_proxy_hook)
    monkeypatch.setattr(
        litellm.logging_callback_manager,
        "add_litellm_callback",
        record_registration,
    )

    with patch("litellm.proxy.proxy_server.prisma_client", None):
        proxy_logging._add_proxy_hooks(llm_router=None)

    observed = {
        "mapping_keys": list(proxy_logging.proxy_hook_mapping.keys()),
        "registered_count": len(registered),
        "registered_hook_names": [getattr(r, "hook_name", None) for r in registered],
    }
    assert observed == {
        "mapping_keys": hook_keys,
        "registered_count": len(hook_keys),
        "registered_hook_names": hook_keys,
    }
def test_add_proxy_hooks_unknown_hook_raises(proxy_logging, monkeypatch):
    """A ``KeyError`` raised by the hook resolver propagates out of ``_add_proxy_hooks``."""
    from litellm.proxy import utils as utils_mod

    def failing_resolver(name):
        raise KeyError(name)

    monkeypatch.setattr(utils_mod, "PROXY_HOOKS", ["bogus_hook"])
    monkeypatch.setattr(utils_mod, "get_proxy_hook", failing_resolver)

    with pytest.raises(KeyError):
        proxy_logging._add_proxy_hooks(llm_router=None)
# ---------------------------------------------------------------------------
# get_proxy_hook
# ---------------------------------------------------------------------------
def test_get_proxy_hook_returns_registered_instance(proxy_logging):
    """``get_proxy_hook`` returns the mapped instance, or ``None`` for an unknown name."""
    registered = {
        "cache_control_check": MagicMock(),
        "max_budget_limiter": MagicMock(),
        "max_parallel_request_limiter": MagicMock(),
    }
    # Hand the instance its own dict so ``registered`` stays the reference copy.
    proxy_logging.proxy_hook_mapping = dict(registered)

    observed = {
        hook_name: proxy_logging.get_proxy_hook(hook_name) is instance
        for hook_name, instance in registered.items()
    }
    observed["unknown_returns_none"] = proxy_logging.get_proxy_hook("unknown") is None

    assert observed == {
        "cache_control_check": True,
        "max_budget_limiter": True,
        "max_parallel_request_limiter": True,
        "unknown_returns_none": True,
    }
def test_get_proxy_hook_unknown_returns_none(proxy_logging):
    """Looking up a name in an empty hook mapping returns ``None``."""
    proxy_logging.proxy_hook_mapping = {}
    looked_up = proxy_logging.get_proxy_hook("does-not-exist")
    assert looked_up is None
def test_get_proxy_hook_non_string_key_raises(proxy_logging):
    """Passing an unhashable key to ``get_proxy_hook`` raises ``TypeError``."""
    # A dict is unhashable, so it can never be looked up in a mapping: both
    # ``mapping[key]`` and ``mapping.get(key)`` raise ``TypeError`` for it.
    # (``None``, by contrast, is hashable and would simply miss.)
    # The pin: the error is not swallowed by ``get_proxy_hook``.
    proxy_logging.proxy_hook_mapping = {"k": object()}
    with pytest.raises(TypeError):
        proxy_logging.get_proxy_hook({"unhashable": True})  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _init_litellm_callbacks
# ---------------------------------------------------------------------------
def test_init_litellm_callbacks_replaces_string_with_instance(proxy_logging, monkeypatch):
    """A string callback is swapped for the resolved logger and a service logger appears."""
    from litellm.proxy import utils as utils_mod

    custom_logger_cls = litellm.integrations.custom_logger.CustomLogger
    sentinel_instance = MagicMock(spec=custom_logger_cls)
    sentinel_instance.__class__ = custom_logger_cls

    def resolve_to_sentinel(*a, **kw):
        # Whatever string is asked for, hand back the sentinel logger.
        return sentinel_instance

    monkeypatch.setattr(litellm, "callbacks", ["some-string-logger"])
    monkeypatch.setattr(
        litellm.litellm_core_utils.litellm_logging,
        "_init_custom_logger_compatible_class",
        resolve_to_sentinel,
    )
    monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])

    proxy_logging._init_litellm_callbacks(llm_router=None)

    callbacks = litellm.callbacks
    class_names = [type(cb).__name__ for cb in callbacks]
    snapshot = {
        "replaced_first_item": callbacks[0] is sentinel_instance,
        "callbacks_grew_with_service": len(callbacks) >= 2,
        "service_logging_appended": any("ServiceLogging" in n for n in class_names),
    }
    assert snapshot == {
        "replaced_first_item": True,
        "callbacks_grew_with_service": True,
        "service_logging_appended": True,
    }
def test_init_litellm_callbacks_string_resolution_failure_keeps_string(proxy_logging, monkeypatch):
    """When the resolver returns ``None`` the original string stays at index 0."""
    from litellm.proxy import utils as utils_mod

    logging_mod = litellm.litellm_core_utils.litellm_logging
    monkeypatch.setattr(litellm, "callbacks", ["unknown-logger"])
    monkeypatch.setattr(
        logging_mod,
        "_init_custom_logger_compatible_class",
        lambda *a, **kw: None,
    )
    monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])

    proxy_logging._init_litellm_callbacks(llm_router=None)

    first_callback = litellm.callbacks[0]
    assert first_callback == "unknown-logger"
def test_init_litellm_callbacks_propagates_resolver_error_raises(proxy_logging, monkeypatch):
    """A ``RuntimeError`` raised by the resolver is not swallowed."""
    from litellm.proxy import utils as utils_mod

    exploding_resolver = MagicMock(side_effect=RuntimeError("bad init"))
    logging_mod = litellm.litellm_core_utils.litellm_logging

    monkeypatch.setattr(litellm, "callbacks", ["raises-on-init"])
    monkeypatch.setattr(
        logging_mod, "_init_custom_logger_compatible_class", exploding_resolver
    )
    monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])

    with pytest.raises(RuntimeError):
        proxy_logging._init_litellm_callbacks(llm_router=None)
# ---------------------------------------------------------------------------
# update_request_status
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_request_status_when_alerting_set_writes_cache(proxy_logging):
    """With alerting configured, the request status is written to the usage cache."""
    recorded: Dict[str, Any] = {}

    async def fake_set_cache(**kwargs):
        # Record the keyword arguments the cache write was issued with.
        recorded.update(kwargs)

    proxy_logging.alerting = ["slack"]
    proxy_logging.alerting_threshold = 5.0
    proxy_logging.internal_usage_cache.async_set_cache = fake_set_cache  # type: ignore[assignment]

    await proxy_logging.update_request_status(litellm_call_id="call-1", status="success")

    snapshot = {field: recorded[field] for field in ("key", "value", "local_only", "ttl")}
    assert snapshot == {
        "key": "request_status:call-1",
        "value": "success",
        "local_only": True,
        "ttl": 105.0,
    }
@pytest.mark.asyncio
async def test_update_request_status_no_alerting_skips_cache(proxy_logging):
    """Without alerting, no cache write is attempted."""
    usage_cache = proxy_logging.internal_usage_cache
    usage_cache.async_set_cache = AsyncMock()
    proxy_logging.alerting = None

    await proxy_logging.update_request_status(litellm_call_id="call-1", status="success")

    usage_cache.async_set_cache.assert_not_called()
@pytest.mark.asyncio
async def test_update_request_status_cache_error_raises(proxy_logging):
    """A ``ConnectionError`` from the cache write reaches the caller."""
    failing_write = AsyncMock(side_effect=ConnectionError("redis"))
    proxy_logging.alerting = ["slack"]
    proxy_logging.internal_usage_cache.async_set_cache = failing_write

    with pytest.raises(ConnectionError):
        await proxy_logging.update_request_status(litellm_call_id="x", status="fail")
# ---------------------------------------------------------------------------
# _convert_user_api_key_auth_to_dict
# ---------------------------------------------------------------------------
def test_convert_user_api_key_auth_to_dict_pydantic_uses_model_dump(proxy_logging, make_user_api_key_auth):
    """A key-auth object built by the fixture converts to a dict exposing its ids."""
    auth = make_user_api_key_auth(user_id="u-1", team_id="t-1")
    converted = proxy_logging._convert_user_api_key_auth_to_dict(auth)

    expected = {"user_id": "u-1", "team_id": "t-1", "is_dict": True}
    snapshot = {
        "user_id": converted["user_id"],
        "team_id": converted["team_id"],
        "is_dict": isinstance(converted, dict),
    }
    assert snapshot == expected
def test_convert_user_api_key_auth_to_dict_plain_object_uses_dict(proxy_logging):
    """A plain object's instance attributes come back as a dict."""

    class Obj:
        pass

    plain = Obj()
    for attr_name, attr_value in (("a", 1), ("b", 2), ("c", 3)):
        setattr(plain, attr_name, attr_value)

    converted = proxy_logging._convert_user_api_key_auth_to_dict(plain)
    assert converted == {"a": 1, "b": 2, "c": 3}
def test_convert_user_api_key_auth_to_dict_none_returns_empty_dict(proxy_logging):
    """``None`` input converts to an empty dict."""
    converted = proxy_logging._convert_user_api_key_auth_to_dict(None)
    assert converted == {}
def test_convert_user_api_key_auth_to_dict_unconvertible_object_returns_empty(proxy_logging):
    """An object with empty ``__slots__`` (hence no ``__dict__``) converts to ``{}``."""

    class NoDict:
        __slots__ = ()

    slotted = NoDict()
    assert proxy_logging._convert_user_api_key_auth_to_dict(slotted) == {}
def test_convert_user_api_key_auth_to_dict_pydantic_error_raises(proxy_logging):
    """An exception raised inside ``model_dump`` is not swallowed by the converter."""

    class _Boom:
        def model_dump(self):
            raise RuntimeError("model_dump failure")

    exploding = _Boom()
    with pytest.raises(RuntimeError):
        proxy_logging._convert_user_api_key_auth_to_dict(exploding)

View File

@ -0,0 +1,426 @@
"""Pin ProxyLogging's MCP-LLM bridging helpers.
Covers:
- ``_convert_mcp_to_llm_format``
- ``_convert_llm_result_to_mcp_response``
- ``_extract_modified_arguments_from_content``
- ``_parse_arguments_manually``
- ``_convert_llm_result_to_mcp_during_response``
- ``_parse_pre_mcp_call_hook_response``
- ``_create_mcp_request_object_from_kwargs``
- ``_convert_mcp_hook_response_to_kwargs``
"""
from __future__ import annotations
import pytest
from litellm.types.mcp import (
MCPDuringCallResponseObject,
MCPPreCallRequestObject,
MCPPreCallResponseObject,
)
# ---------------------------------------------------------------------------
# _convert_mcp_to_llm_format
# ---------------------------------------------------------------------------
def test_convert_mcp_to_llm_format_returns_synthetic_data(proxy_logging, make_mcp_request_obj):
    """Request fields and caller kwargs are carried into the LLM-shaped payload."""
    call_kwargs = {
        "model": "gpt-4o-mini",
        "user_api_key_user_id": "u-1",
        "user_api_key_team_id": "t-1",
        "user_api_key_end_user_id": "eu-1",
        "user_api_key_hash": "hash",
        "user_api_key_request_route": "/mcp",
        "incoming_bearer_token": "tok",
    }
    request = make_mcp_request_obj(tool_name="search", arguments={"q": "hello"})

    converted = proxy_logging._convert_mcp_to_llm_format(
        request_obj=request,
        kwargs=call_kwargs,
    )

    snapshot = {
        "model": converted["model"],
        "user_id": converted["user_api_key_user_id"],
        "mcp_tool_name": converted["mcp_tool_name"],
        "mcp_arguments": converted["mcp_arguments"],
        "incoming_bearer_token": converted["incoming_bearer_token"],
        "message_role": converted["messages"][0]["role"],
    }
    expected = {
        "model": "gpt-4o-mini",
        "user_id": "u-1",
        "mcp_tool_name": "search",
        "mcp_arguments": {"q": "hello"},
        "incoming_bearer_token": "tok",
        "message_role": "user",
    }
    assert snapshot == expected
def test_convert_mcp_to_llm_format_defaults_model(proxy_logging, make_mcp_request_obj):
    """Empty kwargs fall back to the placeholder model and ``None`` identity fields."""
    converted = proxy_logging._convert_mcp_to_llm_format(
        request_obj=make_mcp_request_obj(), kwargs={}
    )

    wanted_keys = {
        "model": "model",
        "mcp_tool_name": "mcp_tool_name",
        "incoming_bearer_token": "incoming_bearer_token",
        "user_id": "user_api_key_user_id",
    }
    snapshot = {label: converted[key] for label, key in wanted_keys.items()}
    assert snapshot == {
        "model": "mcp-tool-call",
        "mcp_tool_name": "calculator",
        "incoming_bearer_token": None,
        "user_id": None,
    }
def test_convert_mcp_to_llm_format_missing_request_obj_raises(proxy_logging):
    """A ``None`` request object results in ``AttributeError``."""
    convert = proxy_logging._convert_mcp_to_llm_format
    with pytest.raises(AttributeError):
        convert(request_obj=None, kwargs={})
# ---------------------------------------------------------------------------
# _convert_llm_result_to_mcp_response
# ---------------------------------------------------------------------------
def test_convert_llm_result_to_mcp_response_exception_blocks(proxy_logging, make_mcp_request_obj):
    """An exception result becomes a non-proceeding response carrying its message."""
    response = proxy_logging._convert_llm_result_to_mcp_response(
        llm_result=ValueError("boom"),
        request_obj=make_mcp_request_obj(),
    )

    assert isinstance(response, MCPPreCallResponseObject)
    observed = {
        "should_proceed": response.should_proceed,
        "error_message": response.error_message,
        "modified_arguments": response.modified_arguments,
    }
    expected = {"should_proceed": False, "error_message": "boom", "modified_arguments": None}
    assert observed == expected
def test_convert_llm_result_to_mcp_response_blocked_content(proxy_logging, make_mcp_request_obj):
    """Content mentioning "blocked" yields a non-proceeding response naming the block."""
    request = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
    blocked_payload = {"messages": [{"content": "this is blocked"}]}

    response = proxy_logging._convert_llm_result_to_mcp_response(
        llm_result=blocked_payload, request_obj=request
    )

    assert isinstance(response, MCPPreCallResponseObject)
    assert response.should_proceed is False
    message = (response.error_message or "").lower()
    assert "blocked" in message
def test_convert_llm_result_to_mcp_response_modified_content_redacted(proxy_logging, make_mcp_request_obj):
    """Redacted arguments in the content come back as ``modified_arguments``."""
    request = make_mcp_request_obj(tool_name="search", arguments={"q": "ssn 123"})
    redacted_payload = {
        "messages": [{"content": "Tool: search\nArguments: {\"q\": \"[REDACTED]\"}"}]
    }

    response = proxy_logging._convert_llm_result_to_mcp_response(
        llm_result=redacted_payload, request_obj=request
    )

    assert isinstance(response, MCPPreCallResponseObject)
    modified = response.modified_arguments or {}
    observed = {
        "should_proceed": response.should_proceed,
        "modified_q": modified.get("q"),
        "error": response.error_message,
    }
    assert observed == {"should_proceed": True, "modified_q": "[REDACTED]", "error": None}
def test_convert_llm_result_to_mcp_response_string_blocks(proxy_logging, make_mcp_request_obj):
    """A bare string result blocks the call and is used as the error message."""
    response = proxy_logging._convert_llm_result_to_mcp_response(
        llm_result="bad input", request_obj=make_mcp_request_obj()
    )

    assert isinstance(response, MCPPreCallResponseObject)
    fields = ("should_proceed", "error_message", "modified_arguments")
    observed = {field: getattr(response, field) for field in fields}
    assert observed == {
        "should_proceed": False,
        "error_message": "bad input",
        "modified_arguments": None,
    }
def test_convert_llm_result_to_mcp_response_unmodified_returns_none(proxy_logging, make_mcp_request_obj):
    """Content in the ``Tool: … / Arguments: …`` form with the request's values yields ``None``."""
    request = make_mcp_request_obj(tool_name="x", arguments={"a": 1})
    untouched_payload = {"messages": [{"content": "Tool: x\nArguments: {'a': 1}"}]}

    response = proxy_logging._convert_llm_result_to_mcp_response(
        llm_result=untouched_payload,
        request_obj=request,
    )
    assert response is None
def test_convert_llm_result_to_mcp_response_no_request_obj_raises(proxy_logging):
    """Dict content with a ``None`` request object results in ``AttributeError``."""
    payload = {"messages": [{"content": "x"}]}
    with pytest.raises(AttributeError):
        proxy_logging._convert_llm_result_to_mcp_response(
            llm_result=payload, request_obj=None
        )
# ---------------------------------------------------------------------------
# _extract_modified_arguments_from_content
# ---------------------------------------------------------------------------
def test_extract_modified_arguments_from_content_parses_json(proxy_logging, make_mcp_request_obj):
    """A JSON object after ``Arguments:`` is parsed into a dict."""
    content = "Tool: x\nArguments: {\"a\": 1, \"b\": 2, \"c\": 3}"
    parsed = proxy_logging._extract_modified_arguments_from_content(
        masked_content=content,
        request_obj=make_mcp_request_obj(),
    )
    assert parsed == {"a": 1, "b": 2, "c": 3}
def test_extract_modified_arguments_from_content_no_arguments_line_returns_none(proxy_logging, make_mcp_request_obj):
    """Content without an ``Arguments:`` line gives ``None``."""
    extract = proxy_logging._extract_modified_arguments_from_content
    parsed = extract(
        masked_content="random content with no arguments",
        request_obj=make_mcp_request_obj(),
    )
    assert parsed is None
def test_extract_modified_arguments_from_content_empty_string_returns_none(proxy_logging, make_mcp_request_obj):
    """Empty content gives ``None``."""
    extract = proxy_logging._extract_modified_arguments_from_content
    parsed = extract(masked_content="", request_obj=make_mcp_request_obj())
    assert parsed is None
def test_extract_modified_arguments_from_content_invalid_json_falls_back(proxy_logging, make_mcp_request_obj):
    """Non-JSON argument text still produces a dict that keeps the ``name`` key."""
    request = make_mcp_request_obj(arguments={"name": "alice"})
    parsed = proxy_logging._extract_modified_arguments_from_content(
        masked_content="Tool: x\nArguments: {name: REDACTED}",
        request_obj=request,
    )
    # Two separate asserts so a failure says which property broke.
    assert isinstance(parsed, dict)
    assert "name" in parsed
def test_extract_modified_arguments_from_content_error_swallowed_returns_none(proxy_logging):
    """``None`` for both inputs is tolerated: the helper returns ``None`` rather than raising."""
    parsed = proxy_logging._extract_modified_arguments_from_content(
        masked_content=None, request_obj=None
    )
    assert parsed is None
# ---------------------------------------------------------------------------
# _parse_arguments_manually
# ---------------------------------------------------------------------------
def test_parse_arguments_manually_applies_overrides(proxy_logging):
    """Quoted ``"key": "value"`` pairs override the originals without mutating them."""
    original = {"name": "alice", "ssn": "123-45-6789"}
    parsed = proxy_logging._parse_arguments_manually(
        args_text='"name": "[REDACTED]", "ssn": "[REDACTED]"',
        original_args=original,
    )

    observed = {
        "name": parsed["name"],
        "ssn": parsed["ssn"],
        # Read from the input dict: proves the caller's data was left alone.
        "original_unchanged": original["name"],
    }
    assert observed == {
        "name": "[REDACTED]",
        "ssn": "[REDACTED]",
        "original_unchanged": "alice",
    }
def test_parse_arguments_manually_returns_original_if_no_match(proxy_logging):
    """Text with no key/value pairs leaves the arguments equal to the originals."""
    baseline = {"foo": "bar"}
    parsed = proxy_logging._parse_arguments_manually(
        args_text="nothing here", original_args=baseline
    )
    assert parsed == {"foo": "bar"}
def test_parse_arguments_manually_error_swallowed_returns_none(proxy_logging):
    """``original_args=None`` does not raise; the helper returns ``None``."""
    parsed = proxy_logging._parse_arguments_manually(
        args_text="x", original_args=None  # type: ignore[arg-type]
    )
    assert parsed is None
# ---------------------------------------------------------------------------
# _convert_llm_result_to_mcp_during_response
# ---------------------------------------------------------------------------
def test_convert_llm_result_to_mcp_during_response_exception(proxy_logging, make_mcp_request_obj):
    """An exception result becomes a non-continuing during-call response."""
    response = proxy_logging._convert_llm_result_to_mcp_during_response(
        llm_result=ValueError("during boom"),
        request_obj=make_mcp_request_obj(),
    )

    assert isinstance(response, MCPDuringCallResponseObject)
    observed = {
        "should_continue": response.should_continue,
        "error_message": response.error_message,
        "type": type(response).__name__,
    }
    expected = {
        "should_continue": False,
        "error_message": "during boom",
        "type": "MCPDuringCallResponseObject",
    }
    assert observed == expected
def test_convert_llm_result_to_mcp_during_response_blocked_content(proxy_logging, make_mcp_request_obj):
    """Content mentioning "blocked" stops the call and names the block in the error."""
    request = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
    blocked_payload = {"messages": [{"content": "blocked content"}]}

    response = proxy_logging._convert_llm_result_to_mcp_during_response(
        llm_result=blocked_payload,
        request_obj=request,
    )

    assert isinstance(response, MCPDuringCallResponseObject)
    assert response.should_continue is False
    message = (response.error_message or "").lower()
    assert "blocked" in message
def test_convert_llm_result_to_mcp_during_response_modified_stops(proxy_logging, make_mcp_request_obj):
    """Modified arguments during a call stop it, with "modified" in the error."""
    request = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
    modified_payload = {
        "messages": [{"content": "Tool: t\nArguments: {\"a\": \"[REDACTED]\"}"}]
    }

    response = proxy_logging._convert_llm_result_to_mcp_during_response(
        llm_result=modified_payload,
        request_obj=request,
    )

    assert isinstance(response, MCPDuringCallResponseObject)
    assert response.should_continue is False
    message = (response.error_message or "").lower()
    assert "modified" in message
def test_convert_llm_result_to_mcp_during_response_string_blocks(proxy_logging, make_mcp_request_obj):
    """A bare string result stops the call and is used as the error message."""
    response = proxy_logging._convert_llm_result_to_mcp_during_response(
        llm_result="kill switch",
        request_obj=make_mcp_request_obj(),
    )

    assert isinstance(response, MCPDuringCallResponseObject)
    observed = {
        "should_continue": response.should_continue,
        "error_message": response.error_message,
    }
    assert observed == {"should_continue": False, "error_message": "kill switch"}
def test_convert_llm_result_to_mcp_during_response_unmodified_returns_none(proxy_logging, make_mcp_request_obj):
    """Content in the ``Tool: … / Arguments: …`` form with the request's values yields ``None``."""
    request = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
    untouched_payload = {"messages": [{"content": "Tool: t\nArguments: {'a': 1}"}]}

    response = proxy_logging._convert_llm_result_to_mcp_during_response(
        llm_result=untouched_payload,
        request_obj=request,
    )
    assert response is None
def test_convert_llm_result_to_mcp_during_response_no_request_obj_raises(proxy_logging):
    """Dict content with a ``None`` request object results in ``AttributeError``."""
    payload = {"messages": [{"content": "x"}]}
    with pytest.raises(AttributeError):
        proxy_logging._convert_llm_result_to_mcp_during_response(
            llm_result=payload,
            request_obj=None,
        )
# ---------------------------------------------------------------------------
# _parse_pre_mcp_call_hook_response
# ---------------------------------------------------------------------------
def test_parse_pre_mcp_call_hook_response_with_modified_args(proxy_logging, make_mcp_request_obj):
    """Modified arguments on the hook response are surfaced in the parsed dict."""
    hook_response = MCPPreCallResponseObject(
        should_proceed=True,
        modified_arguments={"a": "x", "b": "y"},
        error_message=None,
    )
    request = make_mcp_request_obj(arguments={"a": 1})

    parsed = proxy_logging._parse_pre_mcp_call_hook_response(
        response=hook_response, original_request=request
    )

    observed = {
        key: parsed[key]
        for key in ("should_proceed", "modified_arguments", "error_message")
    }
    # Only the class name is pinned for hidden params, not their contents.
    observed["hidden_params_type"] = type(parsed["hidden_params"]).__name__
    assert observed == {
        "should_proceed": True,
        "modified_arguments": {"a": "x", "b": "y"},
        "error_message": None,
        "hidden_params_type": "HiddenParams",
    }
def test_parse_pre_mcp_call_hook_response_no_modifications_uses_original(proxy_logging, make_mcp_request_obj):
    """With no modified arguments, the original request's arguments are reported."""
    request = make_mcp_request_obj(arguments={"original": True})
    hook_response = MCPPreCallResponseObject(
        should_proceed=True,
        modified_arguments=None,
        error_message=None,
    )

    parsed = proxy_logging._parse_pre_mcp_call_hook_response(
        response=hook_response, original_request=request
    )
    assert parsed["modified_arguments"] == {"original": True}
def test_parse_pre_mcp_call_hook_response_invalid_response_raises(proxy_logging, make_mcp_request_obj):
    """A ``None`` hook response results in ``AttributeError``."""
    with pytest.raises(AttributeError):
        proxy_logging._parse_pre_mcp_call_hook_response(
            response=None,
            original_request=make_mcp_request_obj(),
        )
# ---------------------------------------------------------------------------
# _create_mcp_request_object_from_kwargs
# ---------------------------------------------------------------------------
def test_create_mcp_request_object_from_kwargs_full(proxy_logging, make_user_api_key_auth):
    """All recognised kwargs are mapped onto the request object."""
    call_kwargs = {
        "name": "calc",
        "arguments": {"x": 1},
        "server_name": "math",
        "user_api_key_auth": make_user_api_key_auth(user_id="u-1"),
    }

    request = proxy_logging._create_mcp_request_object_from_kwargs(kwargs=call_kwargs)

    assert isinstance(request, MCPPreCallRequestObject)
    observed = {
        "tool_name": request.tool_name,
        "arguments": request.arguments,
        "server_name": request.server_name,
        # The auth object is exposed as a mapping on the request.
        "auth_user_id": request.user_api_key_auth.get("user_id"),
    }
    assert observed == {
        "tool_name": "calc",
        "arguments": {"x": 1},
        "server_name": "math",
        "auth_user_id": "u-1",
    }
def test_create_mcp_request_object_from_kwargs_empty(proxy_logging):
    """Empty kwargs produce an empty tool name, empty arguments and no server name."""
    request = proxy_logging._create_mcp_request_object_from_kwargs(kwargs={})
    fields = ("tool_name", "arguments", "server_name")
    observed = {field: getattr(request, field) for field in fields}
    assert observed == {"tool_name": "", "arguments": {}, "server_name": None}
def test_create_mcp_request_object_from_kwargs_non_dict_raises(proxy_logging):
    """``kwargs=None`` results in ``AttributeError``."""
    build = proxy_logging._create_mcp_request_object_from_kwargs
    with pytest.raises(AttributeError):
        build(kwargs=None)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _convert_mcp_hook_response_to_kwargs
# ---------------------------------------------------------------------------
def test_convert_mcp_hook_response_to_kwargs_applies_modified_args(proxy_logging):
    """Modified arguments and headers land in the result; the input dict is untouched."""
    original = {"arguments": {"a": 1}, "name": "old"}
    hook_data = {"modified_arguments": {"a": 2}, "extra_headers": {"H": "1"}}

    merged = proxy_logging._convert_mcp_hook_response_to_kwargs(
        response_data=hook_data,
        original_kwargs=original,
    )

    observed = {key: merged[key] for key in ("arguments", "extra_headers", "name")}
    # Read from the input dict to show it was not mutated in place.
    observed["original_unmodified"] = original["arguments"]
    assert observed == {
        "arguments": {"a": 2},
        "extra_headers": {"H": "1"},
        "name": "old",
        "original_unmodified": {"a": 1},
    }
def test_convert_mcp_hook_response_to_kwargs_merges_headers(proxy_logging):
    """Hook headers are merged over existing ones; untouched keys survive."""
    existing = {"extra_headers": {"keep": "yes", "overwrite": "old"}}
    hook_data = {"extra_headers": {"overwrite": "new", "added": "1"}}

    merged = proxy_logging._convert_mcp_hook_response_to_kwargs(
        response_data=hook_data,
        original_kwargs=existing,
    )
    assert merged["extra_headers"] == {"keep": "yes", "overwrite": "new", "added": "1"}
def test_convert_mcp_hook_response_to_kwargs_no_response_data_returns_original(proxy_logging):
    """With no response data the very same kwargs object is handed back."""
    original = {"a": 1}
    returned = proxy_logging._convert_mcp_hook_response_to_kwargs(
        response_data=None, original_kwargs=original
    )
    # Identity, not equality: no copy is made on this path.
    assert returned is original
def test_convert_mcp_hook_response_to_kwargs_invalid_original_raises(proxy_logging):
    """``original_kwargs=None`` with real response data results in ``AttributeError``."""
    hook_data = {"modified_arguments": {"a": 1}}
    with pytest.raises(AttributeError):
        proxy_logging._convert_mcp_hook_response_to_kwargs(
            response_data=hook_data,
            original_kwargs=None,  # type: ignore[arg-type]
        )

View File

@ -0,0 +1,353 @@
"""Pin behavior of top-of-file and bottom-of-region helpers.
Covers ``print_verbose``, ``_get_email_logger_class``,
``_accepts_litellm_call_info``, ``_enrich_http_exception_with_guardrail_context``,
``on_backoff``, ``jsonify_object``, ``_lookup_deprecated_key``.
"""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
import litellm
from litellm.proxy import utils as utils_mod
from litellm.proxy.utils import (
_accepts_litellm_call_info,
_enrich_http_exception_with_guardrail_context,
_get_email_logger_class,
_lookup_deprecated_key,
jsonify_object,
on_backoff,
print_verbose,
)
# ---------------------------------------------------------------------------
# print_verbose
# ---------------------------------------------------------------------------
def test_print_verbose_when_set_verbose_true_prints_redacted(monkeypatch, capsys):
    """With ``set_verbose`` on, the prefixed message goes to stdout and stderr stays empty."""
    monkeypatch.setattr(litellm, "set_verbose", True)
    print_verbose("hello world")

    stdout, stderr = capsys.readouterr()
    observed = {
        "out_has_prefix": "LiteLLM Proxy:" in stdout,
        "out_has_payload": "hello world" in stdout,
        "no_stderr": stderr == "",
    }
    assert observed == {"out_has_prefix": True, "out_has_payload": True, "no_stderr": True}
def test_print_verbose_when_set_verbose_false_no_stdout(monkeypatch, capsys):
    """With ``set_verbose`` off, nothing is written to stdout."""
    monkeypatch.setattr(litellm, "set_verbose", False)
    print_verbose("quiet")
    stdout = capsys.readouterr().out
    assert stdout == ""
def test_print_verbose_handles_unprintable_object_raises(monkeypatch):
    """An object whose ``__str__`` raises makes ``print_verbose`` raise as well."""

    class Bomb:
        def __str__(self):
            raise RuntimeError("bad str")

    monkeypatch.setattr(litellm, "set_verbose", True)
    unprintable = Bomb()
    with pytest.raises(RuntimeError):
        print_verbose(unprintable)
# ---------------------------------------------------------------------------
# _get_email_logger_class
# ---------------------------------------------------------------------------
def test_get_email_logger_class_priority_matrix(monkeypatch):
    """Pin the selection priority: SendGrid > Resend > SMTP > Base.

    Environment variables are added one at a time, lowest priority first, and
    the selected class is recorded after each step.
    """
    sg, rs, smtp, base = object(), object(), object(), object()
    stand_ins = (
        ("BaseEmailLogger", base),
        ("SendGridEmailLogger", sg),
        ("ResendEmailLogger", rs),
        ("SMTPEmailLogger", smtp),
    )
    for attr_name, stand_in in stand_ins:
        monkeypatch.setattr(utils_mod, attr_name, stand_in)

    # Start from a clean environment so only the steps below decide the result.
    for k in ("SENDGRID_API_KEY", "RESEND_API_KEY", "SMTP_HOST"):
        monkeypatch.delenv(k, raising=False)

    snapshot = {}
    snapshot["fallback_to_base"] = _get_email_logger_class() is base

    monkeypatch.setenv("SMTP_HOST", "smtp.example")
    snapshot["smtp_when_smtp_only"] = _get_email_logger_class() is smtp

    monkeypatch.setenv("RESEND_API_KEY", "rs-x")
    snapshot["resend_beats_smtp"] = _get_email_logger_class() is rs

    monkeypatch.setenv("SENDGRID_API_KEY", "sg-x")
    snapshot["sendgrid_wins"] = _get_email_logger_class() is sg

    assert snapshot == {
        "fallback_to_base": True,
        "smtp_when_smtp_only": True,
        "resend_beats_smtp": True,
        "sendgrid_wins": True,
    }
def test_get_email_logger_class_error_when_no_enterprise_module(monkeypatch):
    """With ``BaseEmailLogger`` set to ``None`` the selector returns ``None``, not an error."""
    monkeypatch.setattr(utils_mod, "BaseEmailLogger", None)

    # Returns ``None`` rather than raising; per the original test this is the
    # failure mode when the optional enterprise package is missing.
    without_provider_env = _get_email_logger_class()
    assert without_provider_env is None

    # Setting a SendGrid key while BaseEmailLogger stays None must not change
    # the outcome or blow up on the optional path.
    monkeypatch.setenv("SENDGRID_API_KEY", "sg-x")
    with_sendgrid_env = _get_email_logger_class()
    assert with_sendgrid_env is None
# ---------------------------------------------------------------------------
# _accepts_litellm_call_info
# ---------------------------------------------------------------------------
class _CbAcceptsInfo:
    """Test double whose headers hook declares a ``litellm_call_info`` keyword."""

    async def async_post_call_response_headers_hook(self, *, litellm_call_info=None):
        """Do nothing; only the signature matters to the tests."""
        return None
class _CbRejectsInfo:
    """Test double whose headers hook has no ``litellm_call_info`` parameter."""

    async def async_post_call_response_headers_hook(self, *, response):
        """Do nothing; only the signature matters to the tests."""
        return None
def test_accepts_litellm_call_info_matrix(monkeypatch):
    """A pre-seeded cache entry for the callback's type is honoured and left as-is.

    The cache is seeded with ``id(_CbAcceptsInfo)`` before the call; afterwards
    it must still hold exactly that one entry, so the lookup neither missed nor
    added a second key.
    """
    # Patch once with the seeded cache. An earlier patch with an empty dict
    # would be overwritten immediately and so has no effect on the test.
    cache = {id(_CbAcceptsInfo): True}
    monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", cache)

    snapshot = {
        "cache_hit_returns_true": _accepts_litellm_call_info(_CbAcceptsInfo()),
        "cache_size_after_hit": len(cache),
        "cache_keyed_by_type_id": id(_CbAcceptsInfo) in cache,
    }
    assert snapshot == {
        "cache_hit_returns_true": True,
        "cache_size_after_hit": 1,
        "cache_keyed_by_type_id": True,
    }
def test_accepts_litellm_call_info_signature_inspection(monkeypatch):
    """With an empty cache, each callback type is classified and then cached."""
    monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", {})

    accepts = _accepts_litellm_call_info(_CbAcceptsInfo())
    rejects = _accepts_litellm_call_info(_CbRejectsInfo())
    # One cache entry per callback type probed above.
    cache_size = len(utils_mod._CALLBACK_ACCEPTS_CALL_INFO)

    observed = {
        "accepts_param_true": accepts,
        "rejects_param_false": rejects,
        "cache_populated": cache_size == 2,
    }
    assert observed == {
        "accepts_param_true": True,
        "rejects_param_false": False,
        "cache_populated": True,
    }
def test_accepts_litellm_call_info_error_on_callback_without_hook_raises(monkeypatch):
    """A callback lacking the headers hook results in ``AttributeError``."""

    class _Bad:
        pass

    monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", {})
    hookless = _Bad()
    with pytest.raises(AttributeError):
        _accepts_litellm_call_info(hookless)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _enrich_http_exception_with_guardrail_context
# ---------------------------------------------------------------------------
def test_enrich_http_exception_adds_guardrail_name_and_mode():
    """Guardrail name and event hook are written into the exception's detail dict."""
    detail = {"error": "blocked"}
    exc = HTTPException(status_code=400, detail=detail)
    guardrail = MagicMock(guardrail_name="presidio", event_hook="pre_call")

    _enrich_http_exception_with_guardrail_context(exc, guardrail)

    # ``detail`` is the same object handed to the exception, so in-place
    # enrichment is visible through it.
    observed = {key: detail[key] for key in ("error", "guardrail_name", "guardrail_mode")}
    assert observed == {
        "error": "blocked",
        "guardrail_name": "presidio",
        "guardrail_mode": "pre_call",
    }
def test_enrich_http_exception_does_not_overwrite_existing_keys():
    """Guardrail keys already present in the detail are kept as they are."""
    detail = {"error": "blocked", "guardrail_name": "explicit", "guardrail_mode": "during_call"}
    exc = HTTPException(status_code=400, detail=detail)
    guardrail = MagicMock(
        guardrail_name="should-not-overwrite",
        event_hook="should-not-overwrite",
    )

    _enrich_http_exception_with_guardrail_context(exc, guardrail)

    expected = {"error": "blocked", "guardrail_name": "explicit", "guardrail_mode": "during_call"}
    assert detail == expected
def test_enrich_http_exception_no_op_for_non_http_exception():
    """A non-HTTP exception is accepted without raising (no further assertion)."""
    not_http = ValueError("not http")
    guardrail = MagicMock(guardrail_name="g")
    _enrich_http_exception_with_guardrail_context(not_http, guardrail)
def test_enrich_http_exception_no_op_for_non_dict_detail():
    """A string detail is left exactly as it was."""
    guardrail = MagicMock(guardrail_name="g")
    exc = HTTPException(status_code=400, detail="just a string")
    _enrich_http_exception_with_guardrail_context(exc, guardrail)
    assert exc.detail == "just a string"
def test_enrich_http_exception_error_handling_does_not_raise():
    """Mismatched inputs never make the enrichment helper raise.

    Verified by passing pathological inputs in turn: a bare, non-HTTP exception
    paired with a callback whose ``guardrail_name`` is ``None``, then an
    ``HTTPException`` whose detail is a string rather than a dict.
    """
    # Bare exception with no detail at all should not blow up.
    bare = Exception("bare")
    nameless_guardrail = MagicMock(guardrail_name=None)
    _enrich_http_exception_with_guardrail_context(bare, nameless_guardrail)

    # HTTPException with non-dict detail stays untouched.
    string_detail_exc = HTTPException(status_code=500, detail="str-detail")
    named_guardrail = MagicMock(guardrail_name="g")
    _enrich_http_exception_with_guardrail_context(string_detail_exc, named_guardrail)
    assert string_detail_exc.detail == "str-detail"
def test_enrich_http_exception_with_falsy_attrs_does_not_set():
    """``None`` guardrail attributes add no keys to the detail dict."""
    detail = {"error": "blocked"}
    exc = HTTPException(status_code=400, detail=detail)
    guardrail = MagicMock(guardrail_name=None, event_hook=None)

    _enrich_http_exception_with_guardrail_context(exc, guardrail)

    assert detail == {"error": "blocked"}
# ---------------------------------------------------------------------------
# on_backoff
# ---------------------------------------------------------------------------
def test_on_backoff_invokes_print_verbose(monkeypatch):
    """``on_backoff`` emits exactly one verbose message mentioning the attempt number."""
    captured = []
    monkeypatch.setattr(utils_mod, "print_verbose", lambda s: captured.append(s))

    on_backoff({"tries": 3})

    # Guard the index: if nothing was emitted, the snapshot comparison below
    # reports the regression instead of an ``IndexError`` masking it.
    first = captured[0] if captured else ""
    snapshot = {
        "len": len(captured),
        "first_has_attempt": "attempt" in first,
        "first_has_3": "3" in first,
    }
    assert snapshot == {"len": 1, "first_has_attempt": True, "first_has_3": True}
def test_on_backoff_missing_tries_key_raises():
    """A details dict without ``"tries"`` results in ``KeyError``."""
    empty_details = {}
    with pytest.raises(KeyError):
        on_backoff(empty_details)
# ---------------------------------------------------------------------------
# jsonify_object
# ---------------------------------------------------------------------------
def test_jsonify_object_serializes_nested_dicts():
    """Nested dict values become JSON strings; scalars pass through; input is not mutated."""
    payload = {"plain": "x", "nested": {"a": 1, "b": 2}, "n": 42}

    converted = jsonify_object(payload)

    assert converted == {"plain": "x", "nested": '{"a": 1, "b": 2}', "n": 42}
    # Source is not mutated.
    assert payload == {"plain": "x", "nested": {"a": 1, "b": 2}, "n": 42}
def test_jsonify_object_failed_serialization_marks_value(monkeypatch):
    """A nested dict that cannot be serialized is replaced by a marker string."""

    class Unserialiseable:
        pass

    payload = {"name": "x", "bad": {"obj": Unserialiseable()}, "count": 1}
    converted = jsonify_object(payload)

    expected = {"name": "x", "bad": "failed-to-serialize-json", "count": 1}
    assert converted == expected
def test_jsonify_object_non_dict_input_raises():
    """A string in place of a dict results in ``AttributeError``."""
    not_a_mapping = "not a dict"
    with pytest.raises(AttributeError):
        jsonify_object(not_a_mapping)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _lookup_deprecated_key
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_lookup_deprecated_key_returns_active_token_id_and_caches(monkeypatch):
    """A found deprecated row yields its active token id and is cached as a 3-tuple."""
    from litellm.caching.dual_cache import LimitedSizeOrderedDict

    # Fresh cache so earlier tests cannot influence the lookup.
    fresh = LimitedSizeOrderedDict(max_size=1000)
    monkeypatch.setattr(utils_mod, "_deprecated_key_cache", fresh)

    # Row whose revocation time is still in the future.
    future = datetime.now(timezone.utc) + timedelta(hours=1)
    deprecated_row = MagicMock()
    deprecated_row.active_token_id = "active-123"
    deprecated_row.revoke_at = future

    db = MagicMock()
    db.litellm_deprecatedverificationtoken.find_first = AsyncMock(return_value=deprecated_row)

    result = await _lookup_deprecated_key(db=db, hashed_token="hash-abc")

    cached_value = fresh.get("hash-abc")
    snapshot = {
        "result": result,
        # Guard the subscript: a missing cache entry should show up as a
        # snapshot mismatch, not as ``TypeError`` on ``None[0]``.
        "cache_active_token_id": cached_value[0] if cached_value else None,
        "cache_has_3_tuple": isinstance(cached_value, tuple) and len(cached_value) == 3,
    }
    assert snapshot == {
        "result": "active-123",
        "cache_active_token_id": "active-123",
        "cache_has_3_tuple": True,
    }
@pytest.mark.asyncio
async def test_lookup_deprecated_key_returns_none_when_not_found(monkeypatch):
    """When the DB query finds no row, the lookup returns ``None``."""
    from litellm.caching.dual_cache import LimitedSizeOrderedDict

    empty_cache = LimitedSizeOrderedDict(max_size=10)
    monkeypatch.setattr(utils_mod, "_deprecated_key_cache", empty_cache)

    db = MagicMock()
    db.litellm_deprecatedverificationtoken.find_first = AsyncMock(return_value=None)

    looked_up = await _lookup_deprecated_key(db=db, hashed_token="missing")
    assert looked_up is None
@pytest.mark.asyncio
async def test_lookup_deprecated_key_db_error_returns_none(monkeypatch):
    """A ``RuntimeError`` from ``find_first`` is expected to surface as ``None``, not raise."""
    from litellm.caching.dual_cache import LimitedSizeOrderedDict

    empty_cache = LimitedSizeOrderedDict(max_size=10)
    monkeypatch.setattr(utils_mod, "_deprecated_key_cache", empty_cache)

    failing_find_first = AsyncMock(side_effect=RuntimeError("db down"))
    db = MagicMock()
    db.litellm_deprecatedverificationtoken.find_first = failing_find_first

    assert await _lookup_deprecated_key(db=db, hashed_token="x") is None
@pytest.mark.asyncio
async def test_lookup_deprecated_key_uses_cache_within_ttl(monkeypatch):
    """A pre-seeded cache entry is expected to be returned without querying the DB.

    The seeded tuple's second and third elements are timestamps 100s and 1000s
    after "now"; the test asserts the cached id is returned and ``find_first``
    is never called.
    """
    from litellm.caching.dual_cache import LimitedSizeOrderedDict

    seeded_cache = LimitedSizeOrderedDict(max_size=10)
    now_ts = datetime.now(timezone.utc).timestamp()
    seeded_cache["hashY"] = ("active-from-cache", now_ts + 100, now_ts + 1000)
    monkeypatch.setattr(utils_mod, "_deprecated_key_cache", seeded_cache)

    find_first = AsyncMock(return_value=None)
    db = MagicMock()
    db.litellm_deprecatedverificationtoken.find_first = find_first

    looked_up = await _lookup_deprecated_key(db=db, hashed_token="hashY")

    assert looked_up == "active-from-cache"
    db.litellm_deprecatedverificationtoken.find_first.assert_not_called()

View File

@ -0,0 +1,269 @@
"""Pin ``ProxyLogging.post_call_failure_hook``, ``_is_proxy_only_llm_api_error``,
and ``_handle_logging_proxy_only_error``."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import AlertType, ProxyErrorTypes
from litellm.proxy.utils import ProxyLogging
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        # Look the attribute up on each call, exactly as a direct
        # ``ProxyLogging._callback_capabilities_cache.clear()`` would.
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
# ---------------------------------------------------------------------------
# _is_proxy_only_llm_api_error
# ---------------------------------------------------------------------------
def test_is_proxy_only_llm_api_truth_table(proxy_logging):
    """Snapshot four cases of ``_is_proxy_only_llm_api_error`` in one comparison.

    Cases: no route, an ``HTTPException`` on a non-LLM route, an
    ``HTTPException`` on ``/chat/completions``, and a plain exception on
    ``/chat/completions`` tagged with ``ProxyErrorTypes.auth_error``.
    """
    classify = proxy_logging._is_proxy_only_llm_api_error

    no_route = classify(original_exception=Exception(), route=None)
    non_llm_route = classify(
        original_exception=HTTPException(status_code=429, detail="rate"),
        route="/random/path",
    )
    http_on_llm_route = classify(
        original_exception=HTTPException(status_code=429, detail="rate"),
        route="/chat/completions",
    )
    auth_short_circuit = classify(
        original_exception=Exception("auth"),
        error_type=ProxyErrorTypes.auth_error,
        route="/chat/completions",
    )

    snapshot = {
        "no_route": no_route,
        "non_llm_route": non_llm_route,
        "http_on_llm_route": http_on_llm_route,
        "auth_short_circuit": auth_short_circuit,
    }
    expected = {
        "no_route": False,
        "non_llm_route": False,
        "http_on_llm_route": True,
        "auth_short_circuit": True,
    }
    assert snapshot == expected
def test_is_proxy_only_llm_api_missing_exception_raises(proxy_logging):
    """Calling the method with no arguments at all is expected to raise ``TypeError``."""
    classify = proxy_logging._is_proxy_only_llm_api_error
    with pytest.raises(TypeError):
        classify()  # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# post_call_failure_hook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_post_call_failure_hook_no_callbacks_returns_none(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """With callbacks disabled and no alert types, the hook is expected to return ``None``.

    The snapshot also pins what the call leaves in ``request_data``: no
    ``litellm_logging_obj`` key, the original ``litellm_call_id``, and no
    ``first_api_call_start_time`` key.
    """
    proxy_logging.alert_types = []
    request_data = {"litellm_call_id": "abc", "model": "m", "messages": []}

    hook_result = await proxy_logging.post_call_failure_hook(
        request_data=request_data,
        original_exception=ValueError("oops"),
        user_api_key_dict=make_user_api_key_auth(),
    )

    snapshot = {
        "out_is_none": hook_result is None,
        "litellm_logging_obj_popped": "litellm_logging_obj" not in request_data,
        "call_id_preserved": request_data["litellm_call_id"] == "abc",
        "first_api_call_start_time_present": "first_api_call_start_time"
        in request_data,
    }
    expected = {
        "out_is_none": True,
        "litellm_logging_obj_popped": True,
        "call_id_preserved": True,
        "first_api_call_start_time_present": False,
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_post_call_failure_hook_callback_returns_http_exception(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """An ``HTTPException`` *returned* by a callback is expected back from the hook.

    The check is on identity: the hook must hand back the very object the
    callback returned.
    """
    replacement = HTTPException(status_code=418, detail="teapot")

    class _Cb(CustomLogger):
        async def async_post_call_failure_hook(self, **kwargs):  # type: ignore[override]
            return replacement

    monkeypatch.setattr(litellm, "callbacks", [_Cb()])
    proxy_logging.alert_types = []

    returned = await proxy_logging.post_call_failure_hook(
        request_data={"litellm_call_id": "abc"},
        original_exception=ValueError("oops"),
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned is replacement
@pytest.mark.asyncio
async def test_post_call_failure_hook_callback_raises_http_exception_first_wins(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """An ``HTTPException`` *raised* inside a callback is expected as the hook's return value.

    The hook call itself must not raise; the assertion is that it returns the
    same exception object the callback raised.
    """
    raised_exc = HTTPException(status_code=418, detail="raised teapot")

    class _Cb(CustomLogger):
        async def async_post_call_failure_hook(self, **kwargs):  # type: ignore[override]
            raise raised_exc

    monkeypatch.setattr(litellm, "callbacks", [_Cb()])
    proxy_logging.alert_types = []

    returned = await proxy_logging.post_call_failure_hook(
        request_data={"litellm_call_id": "abc"},
        original_exception=ValueError("oops"),
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned is raised_exc
@pytest.mark.asyncio
async def test_post_call_failure_hook_non_http_exception_in_callback_swallowed(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` raised inside a callback must not escape; the hook returns ``None``."""

    class _Cb(CustomLogger):
        async def async_post_call_failure_hook(self, **kwargs):  # type: ignore[override]
            raise RuntimeError("non-http inside cb")

    monkeypatch.setattr(litellm, "callbacks", [_Cb()])
    proxy_logging.alert_types = []

    returned = await proxy_logging.post_call_failure_hook(
        request_data={"litellm_call_id": "abc"},
        original_exception=ValueError("oops"),
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned is None
# ---------------------------------------------------------------------------
# _handle_logging_proxy_only_error
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_path_uses_existing_logging_obj(
    proxy_logging, make_user_api_key_auth
):
    """Pin what happens to a logging object already present in ``request_data``.

    After the call the snapshot expects a ``"messages"`` key in
    ``model_call_details``, ``call_type`` still ``"acompletion"``, the
    ``LITELLM_LOGGING_NO_UPSTREAM_LLM_CALL`` marker set to ``True``, and
    ``async_failure_handler`` to have been called.
    """
    from litellm.constants import LITELLM_LOGGING_NO_UPSTREAM_LLM_CALL

    existing_obj = MagicMock()
    existing_obj.call_type = "acompletion"
    existing_obj.model_call_details = {}
    existing_obj.async_failure_handler = AsyncMock()

    request_data = {
        "litellm_logging_obj": existing_obj,
        "messages": [{"role": "user", "content": "x"}],
        "model": "m",
        "metadata": {},
    }

    await proxy_logging._handle_logging_proxy_only_error(
        request_data=request_data,
        user_api_key_dict=make_user_api_key_auth(),
        route="/chat/completions",
        original_exception=HTTPException(status_code=429, detail="rate"),
    )

    details = existing_obj.model_call_details
    snapshot = {
        "input_logged": "messages" in details,
        "call_type_normalized": existing_obj.call_type,
        "marker_present": details.get(LITELLM_LOGGING_NO_UPSTREAM_LLM_CALL) is True,
        "async_failure_called": existing_obj.async_failure_handler.called,
    }
    expected = {
        "input_logged": True,
        "call_type_normalized": "acompletion",
        "marker_present": True,
        "async_failure_called": True,
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_path_skips_for_pass_through(
    proxy_logging, make_user_api_key_auth
):
    """With a pass-through ``call_type``, neither ``pre_call`` nor the failure handler may run."""
    from litellm.types.utils import CallTypes

    passthrough_obj = MagicMock()
    passthrough_obj.call_type = CallTypes.pass_through.value
    passthrough_obj.model_call_details = {}
    passthrough_obj.async_failure_handler = AsyncMock()
    passthrough_obj.pre_call = MagicMock()

    request_data = {
        "litellm_logging_obj": passthrough_obj,
        "messages": [{"role": "user"}],
        "model": "m",
    }

    await proxy_logging._handle_logging_proxy_only_error(
        request_data=request_data,
        user_api_key_dict=make_user_api_key_auth(),
        route="/chat/completions",
        original_exception=HTTPException(status_code=429, detail="rate"),
    )

    passthrough_obj.pre_call.assert_not_called()
    passthrough_obj.async_failure_handler.assert_not_called()
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_path_no_logging_obj_creates_one(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """Without a logging object in ``request_data``, one is obtained via ``function_setup``.

    ``litellm.utils.function_setup`` is patched to return a stand-in object.
    The test expects ``request_data`` to gain a ``litellm_call_id`` key and the
    stand-in's ``async_failure_handler`` to be called exactly once.
    """
    stand_in = MagicMock()
    stand_in.call_type = "acompletion"
    stand_in.model_call_details = {}
    stand_in.async_failure_handler = AsyncMock()

    def fake_function_setup(**kwargs):
        return stand_in, {}

    monkeypatch.setattr(litellm.utils, "function_setup", fake_function_setup)

    request_data = {"messages": [{"role": "user"}], "model": "m"}

    await proxy_logging._handle_logging_proxy_only_error(
        request_data=request_data,
        user_api_key_dict=make_user_api_key_auth(),
        route="/chat/completions",
        original_exception=HTTPException(status_code=429, detail="rate"),
    )

    assert "litellm_call_id" in request_data
    stand_in.async_failure_handler.assert_called_once()
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_path_propagates_async_failure_raises(
    proxy_logging, make_user_api_key_auth
):
    """A ``RuntimeError`` from ``async_failure_handler`` is expected to propagate to the caller."""
    exploding_obj = MagicMock()
    exploding_obj.call_type = "acompletion"
    exploding_obj.model_call_details = {}
    exploding_obj.async_failure_handler = AsyncMock(side_effect=RuntimeError("boom"))

    request_data = {
        "litellm_logging_obj": exploding_obj,
        "messages": [{"role": "user"}],
        "model": "m",
    }

    with pytest.raises(RuntimeError):
        await proxy_logging._handle_logging_proxy_only_error(
            request_data=request_data,
            user_api_key_dict=make_user_api_key_auth(),
            route="/chat/completions",
            original_exception=Exception("x"),
        )

View File

@ -0,0 +1,97 @@
"""Pin ``ProxyLogging.post_call_success_hook``."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.utils import ProxyLogging
from litellm.types.guardrails import GuardrailEventHooks
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        # Look the attribute up on each call, exactly as a direct
        # ``ProxyLogging._callback_capabilities_cache.clear()`` would.
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
def _make_guardrail(name="g", should_run=True, override=None):
    """Build a ``MagicMock`` stand-in for a post-call ``CustomGuardrail``.

    Args:
        name: Value assigned to ``guardrail_name``.
        should_run: Return value of the mocked ``should_run_guardrail``.
        override: Value the mocked ``async_post_call_success_hook`` resolves to.

    Returns:
        The configured mock, specced on ``CustomGuardrail`` with its
        ``__class__`` set to ``CustomGuardrail``.
    """
    guardrail = MagicMock(spec=CustomGuardrail)
    guardrail.__class__ = CustomGuardrail
    guardrail.guardrail_name = name
    guardrail.event_hook = GuardrailEventHooks.post_call
    guardrail.should_run_guardrail = MagicMock(return_value=should_run)
    guardrail.async_post_call_success_hook = AsyncMock(return_value=override)
    return guardrail
@pytest.mark.asyncio
async def test_post_call_success_hook_returns_response_when_no_callbacks(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """With callbacks disabled, the hook is expected to return a response equal to its input."""
    original_response = {"original": True, "model": "m", "choices": []}

    returned = await proxy_logging.post_call_success_hook(
        data={},
        response=original_response,
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned == {"original": True, "model": "m", "choices": []}
@pytest.mark.asyncio
async def test_post_call_success_hook_runs_other_callback_and_replaces_response(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``CustomLogger`` returning a new response is expected to replace the original one."""
    replacement = {"modified": True, "kept": "yes", "final": "v"}

    class _CL(CustomLogger):
        async def async_post_call_success_hook(self, **kwargs):  # type: ignore[override]
            return replacement

    monkeypatch.setattr(litellm, "callbacks", [_CL()])

    returned = await proxy_logging.post_call_success_hook(
        data={},
        response={"original": True},
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned == replacement
@pytest.mark.asyncio
async def test_post_call_success_hook_guardrail_should_not_run_skipped(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A guardrail whose ``should_run_guardrail`` returns ``False`` must not be invoked.

    The original response object is expected back by identity.
    """
    skipped_guardrail = _make_guardrail(should_run=False)
    monkeypatch.setattr(litellm, "callbacks", [skipped_guardrail])
    original_response = MagicMock()

    returned = await proxy_logging.post_call_success_hook(
        data={},
        response=original_response,
        user_api_key_dict=make_user_api_key_auth(),
    )

    skipped_guardrail.async_post_call_success_hook.assert_not_called()
    assert returned is original_response
@pytest.mark.asyncio
async def test_post_call_success_hook_guardrail_error_raises(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` raised by a guardrail's success hook is expected to propagate."""
    blocking_guardrail = _make_guardrail()
    blocking_guardrail.async_post_call_success_hook = AsyncMock(
        side_effect=RuntimeError("blocked")
    )
    monkeypatch.setattr(litellm, "callbacks", [blocking_guardrail])

    with pytest.raises(RuntimeError):
        await proxy_logging.post_call_success_hook(
            data={},
            response=MagicMock(),
            user_api_key_dict=make_user_api_key_auth(),
        )
@pytest.mark.asyncio
async def test_post_call_success_hook_guardrail_returns_modified_response(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """The value a guardrail's success hook resolves to is expected as the hook's result."""
    guardrail_output = {"a": 1, "b": 2, "c": 3}
    rewriting_guardrail = _make_guardrail(override=guardrail_output)
    monkeypatch.setattr(litellm, "callbacks", [rewriting_guardrail])

    returned = await proxy_logging.post_call_success_hook(
        data={},
        response={"orig": True},
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert returned == guardrail_output

View File

@ -0,0 +1,168 @@
"""Pin ``ProxyLogging.pre_call_hook`` and ``process_pre_call_hook_response``."""
from __future__ import annotations
from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
import litellm
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.utils import ProxyLogging
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        # Look the attribute up on each call, exactly as a direct
        # ``ProxyLogging._callback_capabilities_cache.clear()`` would.
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
# ---------------------------------------------------------------------------
# process_pre_call_hook_response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_pre_call_hook_response_dict_returns_response(proxy_logging):
    """A dict ``response`` is expected to be returned in place of ``data``."""
    hook_response = {"messages": [{"x": 1}], "model": "m", "temperature": 0.5}

    processed = await proxy_logging.process_pre_call_hook_response(
        response=hook_response,
        data={"original": True},
        call_type="completion",
    )

    assert processed == {"messages": [{"x": 1}], "model": "m", "temperature": 0.5}
@pytest.mark.asyncio
async def test_process_pre_call_hook_response_string_completion_raises_rejected(
    proxy_logging,
):
    """A string ``response`` with ``call_type="completion"`` is expected to raise ``RejectedRequestError``."""
    request_payload = {"model": "m"}
    with pytest.raises(RejectedRequestError):
        await proxy_logging.process_pre_call_hook_response(
            response="rejected",
            data=request_payload,
            call_type="completion",
        )
@pytest.mark.asyncio
async def test_process_pre_call_hook_response_string_other_call_type_raises_http(
    proxy_logging,
):
    """A string ``response`` with ``call_type="embeddings"`` is expected to raise ``HTTPException`` with status 400."""
    with pytest.raises(HTTPException) as raised:
        await proxy_logging.process_pre_call_hook_response(
            response="bad",
            data={},
            call_type="embeddings",
        )

    status = raised.value.status_code
    assert status == 400
@pytest.mark.asyncio
async def test_process_pre_call_hook_response_exception_reraises(proxy_logging):
    """An exception instance passed as ``response`` is expected to be raised."""
    hook_error = RuntimeError("hook said no")

    with pytest.raises(RuntimeError, match="hook said no"):
        await proxy_logging.process_pre_call_hook_response(
            response=hook_error,
            data={},
            call_type="completion",
        )
@pytest.mark.asyncio
async def test_process_pre_call_hook_response_other_type_returns_data(proxy_logging):
    """An ``int`` ``response`` is expected to fall through and return ``data``."""
    request_payload = {"a": 1, "b": 2, "c": 3}

    processed = await proxy_logging.process_pre_call_hook_response(
        response=12345,
        data=request_payload,
        call_type="completion",
    )

    assert processed == {"a": 1, "b": 2, "c": 3}
# ---------------------------------------------------------------------------
# pre_call_hook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_pre_call_hook_returns_data_when_no_callbacks(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """With callbacks disabled and alerting off, the same ``data`` object is expected back."""
    request_payload = {
        "messages": [{"role": "user", "content": "hi"}],
        "model": "m",
        "temperature": 0.7,
    }
    proxy_logging.slack_alerting_instance = MagicMock(alerting=None)

    returned = await proxy_logging.pre_call_hook(
        user_api_key_dict=make_user_api_key_auth(),
        data=request_payload,
        call_type="completion",
    )

    assert returned is request_payload
@pytest.mark.asyncio
async def test_pre_call_hook_returns_none_for_none_data(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """Passing ``data=None`` is expected to return ``None``."""
    proxy_logging.slack_alerting_instance = MagicMock(alerting=None)

    returned = await proxy_logging.pre_call_hook(
        user_api_key_dict=make_user_api_key_auth(),
        data=None,
        call_type="completion",
    )

    assert returned is None
@pytest.mark.asyncio
async def test_pre_call_hook_invokes_pre_call_override(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A callback overriding ``async_pre_call_hook`` is expected to run and shape the result.

    The callback records its keyword arguments and returns a modified payload.
    The snapshot pins the returned messages/model/temperature and the
    ``call_type`` the callback received.
    """
    captured: Dict[str, Any] = {}

    class _Cb(CustomLogger):
        async def async_pre_call_hook(self, **kwargs):  # type: ignore[override]
            captured.update(kwargs)
            return {"messages": [{"x": "modified"}], "model": "m", "temperature": 0.1}

    monkeypatch.setattr(litellm, "callbacks", [_Cb()])
    proxy_logging.slack_alerting_instance = MagicMock(alerting=None)

    returned = await proxy_logging.pre_call_hook(
        user_api_key_dict=make_user_api_key_auth(),
        data={"messages": [{"x": "input"}], "model": "m", "temperature": 0.1},
        call_type="completion",
    )

    snapshot = {
        "out_messages": returned["messages"],
        "out_model": returned["model"],
        "out_temp": returned["temperature"],
        "cb_received_call_type": captured.get("call_type"),
    }
    expected = {
        "out_messages": [{"x": "modified"}],
        "out_model": "m",
        "out_temp": 0.1,
        "cb_received_call_type": "completion",
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_pre_call_hook_propagates_callback_error_raises(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` raised by a callback's ``async_pre_call_hook`` is expected to propagate."""

    class _BadCb(CustomLogger):
        async def async_pre_call_hook(self, **kwargs):  # type: ignore[override]
            raise RuntimeError("rejected")

    monkeypatch.setattr(litellm, "callbacks", [_BadCb()])
    proxy_logging.slack_alerting_instance = MagicMock(alerting=None)

    request_payload = {"model": "m"}
    with pytest.raises(RuntimeError, match="rejected"):
        await proxy_logging.pre_call_hook(
            user_api_key_dict=make_user_api_key_auth(),
            data=request_payload,
            call_type="completion",
        )
@pytest.mark.asyncio
async def test_pre_call_hook_processes_guardrail_metadata_when_no_overrides(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """``_process_guardrail_metadata`` is expected to run even with callbacks disabled.

    The method is replaced on the instance with a recorder. The test asserts
    it received the same ``data`` object that ``pre_call_hook`` then returns.
    """
    request_payload = {
        "messages": [{"role": "user"}],
        "model": "m",
        "metadata": {"guardrails": ["g1"]},
    }
    proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
    invoked = {}

    def fake_process(d):
        invoked["data"] = d

    proxy_logging._process_guardrail_metadata = fake_process  # type: ignore[assignment]

    returned = await proxy_logging.pre_call_hook(
        user_api_key_dict=make_user_api_key_auth(),
        data=request_payload,
        call_type="completion",
    )

    assert returned is request_payload
    assert invoked["data"] is request_payload

View File

@ -0,0 +1,432 @@
"""Pin ProxyLogging streaming + response-headers helpers.
Covers ``_wrap_streaming_iterator_with_enrichment``,
``async_post_call_streaming_hook``,
``async_post_call_streaming_iterator_hook``, ``_fire_deferred_stream_logging``,
``is_a2a_streaming_response``, ``_init_response_taking_too_long_task``,
``post_call_response_headers_hook``, ``_build_litellm_call_info``.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.utils import ProxyLogging
@pytest.fixture(autouse=True)
def _clear_caps_cache():
    """Empty ``ProxyLogging._callback_capabilities_cache`` before and after every test."""

    def _reset():
        # Look the attribute up on each call, exactly as a direct
        # ``ProxyLogging._callback_capabilities_cache.clear()`` would.
        ProxyLogging._callback_capabilities_cache.clear()

    _reset()
    yield
    _reset()
# ---------------------------------------------------------------------------
# is_a2a_streaming_response
# ---------------------------------------------------------------------------
def test_is_a2a_streaming_response_truth_matrix(proxy_logging):
    """Snapshot ``is_a2a_streaming_response`` over four dict shapes.

    Only the dict carrying all of ``jsonrpc``, ``id`` and ``result`` is
    expected to be classified as ``True``.
    """
    is_a2a = proxy_logging.is_a2a_streaming_response

    all_three = is_a2a({"jsonrpc": "2.0", "id": "1", "result": {"x": 1}, "extra": "y"})
    missing_result = is_a2a({"jsonrpc": "2.0", "id": "1"})
    missing_jsonrpc = is_a2a({"id": "1", "result": {}})
    empty = is_a2a({})

    snapshot = {
        "all_three_keys_present": all_three,
        "missing_result": missing_result,
        "missing_jsonrpc": missing_jsonrpc,
        "empty_dict": empty,
    }
    expected = {
        "all_three_keys_present": True,
        "missing_result": False,
        "missing_jsonrpc": False,
        "empty_dict": False,
    }
    assert snapshot == expected
def test_is_a2a_streaming_response_invalid_input_raises(proxy_logging):
    """Passing ``None`` is expected to raise ``TypeError``."""
    not_a_dict = None
    with pytest.raises(TypeError):
        proxy_logging.is_a2a_streaming_response(not_a_dict)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _build_litellm_call_info
# ---------------------------------------------------------------------------
def test_build_litellm_call_info_pulls_from_hidden_params_and_metadata(proxy_logging):
    """Pin the dict built from ``response._hidden_params`` plus ``data["metadata"]["model_info"]``."""
    fake_response = MagicMock()
    fake_response._hidden_params = {
        "custom_llm_provider": "openai",
        "api_base": "https://api.openai.com",
        "model_id": "model-1",
    }
    request_payload = {"metadata": {"model_info": {"name": "gpt-4o-mini"}}}

    call_info = proxy_logging._build_litellm_call_info(
        data=request_payload,
        response=fake_response,
    )

    expected = {
        "custom_llm_provider": "openai",
        "model_info": {"name": "gpt-4o-mini"},
        "api_base": "https://api.openai.com",
        "model_id": "model-1",
    }
    assert call_info == expected
def test_build_litellm_call_info_fallbacks_to_litellm_metadata(proxy_logging):
    """``model_info`` is expected to come from ``data["litellm_metadata"]`` when ``metadata`` is absent.

    ``api_base`` and ``model_id`` are missing from ``_hidden_params`` here and
    are expected to come back as ``None``.
    """
    fake_response = MagicMock()
    fake_response._hidden_params = {"custom_llm_provider": "azure"}
    request_payload = {"litellm_metadata": {"model_info": {"alias": "azure-gpt"}}}

    call_info = proxy_logging._build_litellm_call_info(
        data=request_payload,
        response=fake_response,
    )

    snapshot = {
        "custom_llm_provider": call_info["custom_llm_provider"],
        "model_info": call_info["model_info"],
        "api_base": call_info["api_base"],
        "model_id": call_info["model_id"],
    }
    expected = {
        "custom_llm_provider": "azure",
        "model_info": {"alias": "azure-gpt"},
        "api_base": None,
        "model_id": None,
    }
    assert snapshot == expected
def test_build_litellm_call_info_invalid_data_raises(proxy_logging):
    """``data=None`` is expected to raise ``AttributeError``."""
    fake_response = MagicMock()
    with pytest.raises(AttributeError):
        proxy_logging._build_litellm_call_info(data=None, response=fake_response)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# _init_response_taking_too_long_task
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_response_taking_too_long_task_runs_when_alerting(proxy_logging):
    """With ``alerting`` set, ``response_taking_too_long`` is expected to receive the payload.

    The slack instance's coroutine is replaced with a recorder. After yielding
    to the event loop once, the recorder must have captured the payload.
    """
    slack = MagicMock()
    proxy_logging.slack_alerting_instance = slack
    proxy_logging.slack_alerting_instance.alerting = ["slack"]
    captured: Dict[str, Any] = {}

    async def fake_resp_too_long(request_data):
        captured["request_data"] = request_data

    proxy_logging.slack_alerting_instance.response_taking_too_long = fake_resp_too_long
    payload = {"req": "y", "litellm_call_id": "c1", "model": "m"}

    proxy_logging._init_response_taking_too_long_task(data=payload)
    # Yield to the event loop once before inspecting what the recorder captured.
    await asyncio.sleep(0)

    snapshot = {
        "received_payload": captured["request_data"],
        "fired_once": len(captured) == 1,
        "alerting_was_truthy": bool(proxy_logging.slack_alerting_instance.alerting),
    }
    expected = {
        "received_payload": payload,
        "fired_once": True,
        "alerting_was_truthy": True,
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_init_response_taking_too_long_task_no_op_when_alerting_off(proxy_logging):
    """With ``alerting`` set to ``None``, ``response_taking_too_long`` must not be called."""
    proxy_logging.slack_alerting_instance = MagicMock()
    proxy_logging.slack_alerting_instance.alerting = None
    too_long_mock = AsyncMock()
    proxy_logging.slack_alerting_instance.response_taking_too_long = too_long_mock

    proxy_logging._init_response_taking_too_long_task(data=None)
    await asyncio.sleep(0)

    proxy_logging.slack_alerting_instance.response_taking_too_long.assert_not_called()
def test_init_response_taking_too_long_task_no_slack_instance_no_error_raises(
    proxy_logging,
):
    """With ``slack_alerting_instance`` set to ``None`` the call must complete without raising."""
    proxy_logging.slack_alerting_instance = None
    # No assertion: the test passes as long as this call does not raise.
    proxy_logging._init_response_taking_too_long_task(data=None)
# ---------------------------------------------------------------------------
# _wrap_streaming_iterator_with_enrichment
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_wrap_streaming_iterator_with_enrichment_passes_through_chunks(
    proxy_logging,
):
    """Chunks from a well-behaved generator are expected to pass through unchanged and in order."""

    async def gen():
        for ch in ("a", "b", "c"):
            yield ch

    callback = MagicMock(guardrail_name="g", event_hook="pre_call")
    wrapped = proxy_logging._wrap_streaming_iterator_with_enrichment(
        callback=callback, gen=gen()
    )

    received = []
    async for chunk in wrapped:
        received.append(chunk)

    snapshot = {
        "chunks": received,
        "count": len(received),
        "first": received[0],
        "last": received[-1],
    }
    expected = {
        "chunks": ["a", "b", "c"],
        "count": 3,
        "first": "a",
        "last": "c",
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_wrap_streaming_iterator_with_enrichment_enriches_http_exception_raises(
    proxy_logging,
):
    """An ``HTTPException`` from the wrapped generator is re-raised with guardrail context.

    The exception's ``detail`` dict is shared with the test. After the raise it
    is expected to carry the callback's ``guardrail_name`` and its
    ``event_hook`` under ``guardrail_mode``.
    """
    error_detail = {"error": "blocked"}

    async def boom_gen():
        # The unreachable ``yield`` makes this an async generator.
        if False:
            yield  # pragma: no cover
        raise HTTPException(status_code=400, detail=error_detail)

    callback = MagicMock(guardrail_name="presidio", event_hook="post_call")
    wrapped = proxy_logging._wrap_streaming_iterator_with_enrichment(
        callback=callback, gen=boom_gen()
    )

    with pytest.raises(HTTPException):
        async for _ in wrapped:
            pass

    assert error_detail["guardrail_name"] == "presidio"
    assert error_detail["guardrail_mode"] == "post_call"
# ---------------------------------------------------------------------------
# async_post_call_streaming_hook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_async_post_call_streaming_hook_fast_path_returns_response(
    proxy_logging, mock_callbacks_disabled, make_user_api_key_auth
):
    """With callbacks disabled, the chunk is expected back unchanged (same object)."""
    chunk = "chunk-1"

    returned = await proxy_logging.async_post_call_streaming_hook(
        data={},
        response=chunk,
        user_api_key_dict=make_user_api_key_auth(),
    )

    snapshot = {
        "out_is_input": returned is chunk,
        "out_value": returned,
        "type": type(returned).__name__,
        "callbacks_empty": len(litellm.callbacks) == 0,
    }
    expected = {
        "out_is_input": True,
        "out_value": "chunk-1",
        "type": "str",
        "callbacks_empty": True,
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_async_post_call_streaming_hook_invokes_per_chunk_callback(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A callback's per-chunk hook is expected to run and its return value to be used.

    The callback returns ``"modified-"`` plus the stringified response; the
    test asserts the result is a ``str`` with that prefix.
    """

    class _Per(CustomLogger):
        async def async_post_call_streaming_hook(self, **kwargs):  # type: ignore[override]
            return "modified-" + str(kwargs.get("response", ""))

    per_chunk_cb = _Per()
    monkeypatch.setattr(litellm, "callbacks", [per_chunk_cb])

    from litellm import ModelResponse

    delta_choice = {
        "index": 0,
        "delta": {"role": "assistant", "content": "hi"},
        "finish_reason": None,
    }
    stream_chunk = ModelResponse(
        id="rid",
        choices=[delta_choice],
        created=0,
        model="gpt-4o-mini",
        object="chat.completion.chunk",
    )

    returned = await proxy_logging.async_post_call_streaming_hook(
        data={},
        response=stream_chunk,
        user_api_key_dict=make_user_api_key_auth(),
    )

    assert isinstance(returned, str)
    assert returned.startswith("modified-")
@pytest.mark.asyncio
async def test_async_post_call_streaming_hook_callback_error_raises(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` raised by the per-chunk callback is expected to propagate."""

    class _Per(CustomLogger):
        async def async_post_call_streaming_hook(self, **kwargs):  # type: ignore[override]
            raise RuntimeError("hook-fail")

    monkeypatch.setattr(litellm, "callbacks", [_Per()])

    from litellm import ModelResponse

    delta_choice = {
        "index": 0,
        "delta": {"role": "assistant", "content": "hi"},
        "finish_reason": None,
    }
    stream_chunk = ModelResponse(
        id="rid",
        choices=[delta_choice],
        created=0,
        model="gpt-4o-mini",
        object="chat.completion.chunk",
    )

    with pytest.raises(RuntimeError):
        await proxy_logging.async_post_call_streaming_hook(
            data={},
            response=stream_chunk,
            user_api_key_dict=make_user_api_key_auth(),
        )
# ---------------------------------------------------------------------------
# async_post_call_streaming_iterator_hook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_async_post_call_streaming_iterator_hook_no_overrides_passes_through(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """With callbacks disabled, the iterator hook is expected to yield the source chunks in order."""

    async def gen():
        for ch in ("a", "b"):
            yield ch

    hooked_stream = proxy_logging.async_post_call_streaming_iterator_hook(
        response=gen(),
        user_api_key_dict=make_user_api_key_auth(),
        request_data={},
    )
    chunks = [ch async for ch in hooked_stream]

    snapshot = {
        "chunks": chunks,
        "count": len(chunks),
        "passthrough_preserved_order": chunks == ["a", "b"],
    }
    expected = {
        "chunks": ["a", "b"],
        "count": 2,
        "passthrough_preserved_order": True,
    }
    assert snapshot == expected
@pytest.mark.asyncio
async def test_async_post_call_streaming_iterator_hook_with_override_chains_callback(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A callback overriding the iterator hook is expected to wrap the stream.

    The override appends ``"*"`` to every chunk it reads from
    ``kwargs["response"]``; the test asserts the transformed chunks come out.
    """

    class _IterOverride(CustomLogger):
        async def async_post_call_streaming_iterator_hook(self, **kwargs):  # type: ignore[override]
            async for ch in kwargs["response"]:
                yield ch + "*"

    monkeypatch.setattr(litellm, "callbacks", [_IterOverride()])

    async def gen():
        for ch in ("a", "b"):
            yield ch

    hooked_stream = proxy_logging.async_post_call_streaming_iterator_hook(
        response=gen(),
        user_api_key_dict=make_user_api_key_auth(),
        request_data={},
    )
    received: List[str] = [ch async for ch in hooked_stream]

    assert received == ["a*", "b*"]
@pytest.mark.asyncio
async def test_async_post_call_streaming_iterator_hook_upstream_error_raises(
    proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
):
    """A ``RuntimeError`` raised by the source generator is expected to propagate to the consumer."""

    async def gen():
        # The unreachable ``yield`` makes this an async generator.
        if False:
            yield  # pragma: no cover
        raise RuntimeError("upstream")

    with pytest.raises(RuntimeError):
        hooked_stream = proxy_logging.async_post_call_streaming_iterator_hook(
            response=gen(),
            user_api_key_dict=make_user_api_key_auth(),
            request_data={},
        )
        async for _ in hooked_stream:
            pass
# ---------------------------------------------------------------------------
# _fire_deferred_stream_logging
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_fire_deferred_stream_logging_fires_callback():
    """The deferred callback on the logging object is expected to fire with its stored args.

    After one event-loop yield, the callback must have received ``"payload"``,
    and both ``_on_deferred_stream_complete`` and
    ``_deferred_stream_complete_args`` must have been reset to ``None``.
    """
    logging_obj = MagicMock()
    captured: Dict[str, Any] = {}

    async def deferred(arg):
        captured["arg"] = arg

    logging_obj._on_deferred_stream_complete = deferred
    logging_obj._deferred_stream_complete_args = ("payload",)

    ProxyLogging._fire_deferred_stream_logging(
        request_data={"litellm_logging_obj": logging_obj}
    )
    await asyncio.sleep(0)

    snapshot = {
        "arg": captured["arg"],
        "callback_cleared": logging_obj._on_deferred_stream_complete is None,
        "args_cleared": logging_obj._deferred_stream_complete_args is None,
    }
    expected = {"arg": "payload", "callback_cleared": True, "args_cleared": True}
    assert snapshot == expected
def test_fire_deferred_stream_logging_no_logging_obj_no_error():
    """An empty ``request_data`` dict must be accepted without raising."""
    empty_request = {}
    # No assertion: the test passes as long as this call does not raise.
    ProxyLogging._fire_deferred_stream_logging(request_data=empty_request)
def test_fire_deferred_stream_logging_missing_obj_raises_on_invalid_dict():
    """``request_data=None`` is expected to raise ``AttributeError``."""
    fire = ProxyLogging._fire_deferred_stream_logging
    with pytest.raises(AttributeError):
        fire(request_data=None)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# post_call_response_headers_hook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_post_call_response_headers_hook_returns_empty_when_no_callbacks(
    proxy_logging, mock_callbacks_disabled, make_user_api_key_auth
):
    """With callbacks disabled, the headers hook is expected to return an empty dict."""
    headers = await proxy_logging.post_call_response_headers_hook(
        data={},
        user_api_key_dict=make_user_api_key_auth(),
        response=MagicMock(_hidden_params={}),
    )

    assert headers == {}
@pytest.mark.asyncio
async def test_post_call_response_headers_hook_merges_callback_headers(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """Headers from several callbacks are expected to be merged, later callbacks winning.

    Both callbacks return ``X-Common``; the expected result carries the value
    from the second callback in the list.
    """

    class _Cb(CustomLogger):
        async def async_post_call_response_headers_hook(self, **kwargs):  # type: ignore[override]
            return {"X-One": "1", "X-Two": "2", "X-Common": "first"}

    class _Cb2(CustomLogger):
        async def async_post_call_response_headers_hook(self, **kwargs):  # type: ignore[override]
            return {"X-Common": "second", "X-Three": "3"}

    monkeypatch.setattr(litellm, "callbacks", [_Cb(), _Cb2()])
    fake_response = MagicMock()
    fake_response._hidden_params = {}

    headers = await proxy_logging.post_call_response_headers_hook(
        data={},
        user_api_key_dict=make_user_api_key_auth(),
        response=fake_response,
    )

    expected = {"X-One": "1", "X-Two": "2", "X-Common": "second", "X-Three": "3"}
    assert headers == expected
@pytest.mark.asyncio
async def test_post_call_response_headers_hook_swallows_callback_error(
    proxy_logging, make_user_api_key_auth, monkeypatch
):
    """A ``RuntimeError`` inside a callback's headers hook must not escape.

    With the only callback failing, the hook is expected to return an empty dict.
    """

    class _Cb(CustomLogger):
        async def async_post_call_response_headers_hook(self, **kwargs):  # type: ignore[override]
            raise RuntimeError("bad header")

    monkeypatch.setattr(litellm, "callbacks", [_Cb()])
    fake_response = MagicMock()
    fake_response._hidden_params = {}

    headers = await proxy_logging.post_call_response_headers_hook(
        data={},
        user_api_key_dict=make_user_api_key_auth(),
        response=fake_response,
    )

    assert headers == {}