test(proxy/utils): pin ProxyLogging behavior (#29485)
* test(proxy/utils): pin ProxyLogging behavior Add behavior-pinning tests for the ProxyLogging cluster in litellm/proxy/utils.py under tests/test_litellm/proxy/utils/proxy_logging/. Covers InternalUsageCache, _CallbackCapabilities, top-of-file helpers (print_verbose, _get_email_logger_class, _accepts_litellm_call_info, _enrich_http_exception_with_guardrail_context), the full ProxyLogging class (lifecycle, MCP-LLM bridging, capability probes, guardrail pipeline, pre/during/post/streaming hooks, alerting), plus the bottom-of-region helpers (on_backoff, jsonify_object, _lookup_deprecated_key). Each pinned symbol has happy-path and error-path coverage; happy paths use direct dict-equality with three or more keys (or HiddenParams / Pydantic model_validate where the surface is a Pydantic shape). The subdirectory carries a local _pin_check.py and _coverage_check.py that enforce the gate without surfacing numeric thresholds in CI logs. Wires tests/test_litellm/proxy/utils into the existing test-path block in .github/workflows/test-unit-proxy-endpoints.yml. * test(proxy/utils): drop unused mock_httpx_client fixture Declared in conftest.py but never referenced by any test. Removing the dead fixture per Greptile P2 feedback. * test(proxy/utils): drop local-only gate scripts from PR _pin_check.py and _coverage_check.py are local stopping signals (not wired into CI, consume a gitignored .pin_list.txt). They served their purpose telling the engineer when to stop writing tests; the pytest suite is the artifact that belongs in the repo. --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
457f65eff9
commit
b175990b4a
@ -0,0 +1,56 @@
|
||||
"""Sanity tests for the proxy_logging conftest fixtures.
|
||||
|
||||
Excluded from the pin-check by name.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_normalize_replaces_volatile_keys(normalize_fn):
|
||||
raw = {"id": 7, "name": "x", "nested": {"created_at": 1, "value": 2}}
|
||||
expected = {"id": "<VOLATILE>", "name": "x", "nested": {"created_at": "<VOLATILE>", "value": 2}}
|
||||
assert normalize_fn(raw) == expected
|
||||
|
||||
|
||||
def test_normalize_handles_lists(normalize_fn):
|
||||
raw = [{"id": 1}, {"id": 2}]
|
||||
assert normalize_fn(raw) == [{"id": "<VOLATILE>"}, {"id": "<VOLATILE>"}]
|
||||
|
||||
|
||||
def test_mock_dual_cache_is_dual_cache(mock_dual_cache):
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
assert isinstance(mock_dual_cache, DualCache)
|
||||
|
||||
|
||||
def test_make_user_api_key_auth_returns_correct_type(make_user_api_key_auth):
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
auth = make_user_api_key_auth()
|
||||
assert isinstance(auth, UserAPIKeyAuth)
|
||||
assert auth.user_id == "test-user"
|
||||
|
||||
|
||||
def test_make_user_api_key_auth_overrides_apply(make_user_api_key_auth):
|
||||
auth = make_user_api_key_auth(user_id="custom-id")
|
||||
assert auth.user_id == "custom-id"
|
||||
|
||||
|
||||
def test_proxy_logging_fixture_is_initialized(proxy_logging):
|
||||
from litellm.proxy.utils import InternalUsageCache, ProxyLogging
|
||||
|
||||
assert isinstance(proxy_logging, ProxyLogging)
|
||||
assert isinstance(proxy_logging.internal_usage_cache, InternalUsageCache)
|
||||
assert proxy_logging.proxy_hook_mapping == {}
|
||||
|
||||
|
||||
def test_make_mcp_request_obj_default(make_mcp_request_obj):
|
||||
obj = make_mcp_request_obj()
|
||||
assert obj.tool_name == "calculator"
|
||||
assert obj.arguments == {"x": 1, "y": 2}
|
||||
|
||||
|
||||
def test_mock_router_has_guardrail_list(mock_router):
|
||||
assert mock_router.guardrail_list == []
|
||||
136
tests/test_litellm/proxy/utils/proxy_logging/conftest.py
Normal file
136
tests/test_litellm/proxy/utils/proxy_logging/conftest.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""Shared fixtures for tests/test_litellm/proxy/utils/proxy_logging/.
|
||||
|
||||
All fixtures used by PR1 of the proxy/utils.py behavior-pinning project
|
||||
live here. Tests should not declare fixtures inline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[5]))
|
||||
|
||||
|
||||
VOLATILE_KEYS = frozenset(
|
||||
{
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"id",
|
||||
"request_id",
|
||||
"token",
|
||||
"expires",
|
||||
"expires_at",
|
||||
"litellm_call_id",
|
||||
"key_alias",
|
||||
"created",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"duration",
|
||||
"guardrail_start_time",
|
||||
"guardrail_end_time",
|
||||
"guardrail_duration",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def normalize(data: Any, volatile: frozenset = VOLATILE_KEYS) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
k: ("<VOLATILE>" if k in volatile else normalize(v, volatile))
|
||||
for k, v in data.items()
|
||||
}
|
||||
if isinstance(data, list):
|
||||
return [normalize(v, volatile) for v in data]
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dual_cache():
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
cache = DualCache(default_in_memory_ttl=1)
|
||||
return cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_router():
|
||||
router = MagicMock()
|
||||
router.guardrail_list = []
|
||||
router.get_available_guardrail = MagicMock(return_value={"callback": None})
|
||||
return router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_callbacks_disabled(monkeypatch):
|
||||
"""Disable all litellm callbacks for the duration of a test."""
|
||||
import litellm
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [])
|
||||
monkeypatch.setattr(litellm, "success_callback", [])
|
||||
monkeypatch.setattr(litellm, "failure_callback", [])
|
||||
monkeypatch.setattr(litellm, "_async_success_callback", [])
|
||||
monkeypatch.setattr(litellm, "_async_failure_callback", [])
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_user_api_key_auth():
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
def _make(**overrides) -> UserAPIKeyAuth:
|
||||
defaults: Dict[str, Any] = {
|
||||
"api_key": "sk-test-1234",
|
||||
"user_id": "test-user",
|
||||
"team_id": "test-team",
|
||||
"user_role": None,
|
||||
"max_budget": None,
|
||||
"spend": 0.0,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return UserAPIKeyAuth(**defaults)
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proxy_logging(mock_callbacks_disabled):
|
||||
"""A wired-up ProxyLogging instance backed by a fresh DualCache.
|
||||
|
||||
The fixture leaves it un-started; tests that need ``startup_event``
|
||||
should call it explicitly with the deps they want to control.
|
||||
"""
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
return ProxyLogging(user_api_key_cache=UserApiKeyCache())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def normalize_fn():
|
||||
return normalize
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_mcp_request_obj():
|
||||
from litellm.types.llms.base import HiddenParams
|
||||
from litellm.types.mcp import MCPPreCallRequestObject
|
||||
|
||||
def _make(
|
||||
tool_name: str = "calculator",
|
||||
arguments: Optional[dict] = None,
|
||||
server_name: Optional[str] = "math-server",
|
||||
) -> MCPPreCallRequestObject:
|
||||
return MCPPreCallRequestObject(
|
||||
tool_name=tool_name,
|
||||
arguments=arguments if arguments is not None else {"x": 1, "y": 2},
|
||||
server_name=server_name,
|
||||
user_api_key_auth={},
|
||||
hidden_params=HiddenParams(),
|
||||
)
|
||||
|
||||
return _make
|
||||
262
tests/test_litellm/proxy/utils/proxy_logging/test_alerting.py
Normal file
262
tests/test_litellm/proxy/utils/proxy_logging/test_alerting.py
Normal file
@ -0,0 +1,262 @@
|
||||
"""Pin alerting helpers on ``ProxyLogging``.
|
||||
|
||||
Covers ``failed_tracking_alert``, ``budget_alerts``, ``alerting_handler``,
|
||||
``failure_handler``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import AlertType, CallInfo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# failed_tracking_alert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_tracking_alert_no_op_when_alerting_none(proxy_logging):
|
||||
proxy_logging.alerting = None
|
||||
proxy_logging.slack_alerting_instance = MagicMock(failed_tracking_alert=AsyncMock())
|
||||
await proxy_logging.failed_tracking_alert(error_message="x", failing_model="m")
|
||||
proxy_logging.slack_alerting_instance.failed_tracking_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_tracking_alert_forwards_to_slack(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_alert(**kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
proxy_logging.slack_alerting_instance = MagicMock(failed_tracking_alert=fake_alert)
|
||||
await proxy_logging.failed_tracking_alert(error_message="db down", failing_model="gpt-4")
|
||||
snapshot = {
|
||||
"error_message": captured["error_message"],
|
||||
"failing_model": captured["failing_model"],
|
||||
"captured_keys": sorted(captured.keys()),
|
||||
}
|
||||
assert snapshot == {
|
||||
"error_message": "db down",
|
||||
"failing_model": "gpt-4",
|
||||
"captured_keys": ["error_message", "failing_model"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_tracking_alert_slack_error_raises(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
proxy_logging.slack_alerting_instance = MagicMock(
|
||||
failed_tracking_alert=AsyncMock(side_effect=RuntimeError("slack down"))
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging.failed_tracking_alert(error_message="x", failing_model="m")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# budget_alerts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _user_info(alert_emails=None):
|
||||
return CallInfo(
|
||||
spend=0.0,
|
||||
max_budget=1.0,
|
||||
token="tok",
|
||||
user_id="u1",
|
||||
team_id="t1",
|
||||
team_alias=None,
|
||||
user_email=None,
|
||||
key_alias=None,
|
||||
projected_exceeded_date=None,
|
||||
projected_spend=None,
|
||||
event_group="user",
|
||||
event="threshold_crossed",
|
||||
alert_emails=alert_emails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_alerts_no_op_when_alerting_off_and_no_emails(proxy_logging):
|
||||
proxy_logging.alerting = None
|
||||
proxy_logging.slack_alerting_instance = MagicMock(budget_alerts=AsyncMock())
|
||||
proxy_logging.email_logging_instance = MagicMock(budget_alerts=AsyncMock())
|
||||
await proxy_logging.budget_alerts(type="user_budget", user_info=_user_info())
|
||||
proxy_logging.slack_alerting_instance.budget_alerts.assert_not_called()
|
||||
proxy_logging.email_logging_instance.budget_alerts.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_alerts_slack_when_slack_alerting(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_alert(**kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
proxy_logging.slack_alerting_instance = MagicMock(budget_alerts=fake_alert)
|
||||
proxy_logging.email_logging_instance = None
|
||||
user_info = _user_info()
|
||||
await proxy_logging.budget_alerts(type="user_budget", user_info=user_info)
|
||||
snapshot = {
|
||||
"type": captured["type"],
|
||||
"user_info_is_callinfo": isinstance(captured["user_info"], CallInfo),
|
||||
"user_id": captured["user_info"].user_id,
|
||||
}
|
||||
assert snapshot == {"type": "user_budget", "user_info_is_callinfo": True, "user_id": "u1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_alerts_soft_budget_with_alert_emails_bypasses_global(proxy_logging):
|
||||
proxy_logging.alerting = None
|
||||
proxy_logging.slack_alerting_instance = MagicMock(budget_alerts=AsyncMock())
|
||||
proxy_logging.email_logging_instance = MagicMock(budget_alerts=AsyncMock())
|
||||
info = _user_info(alert_emails=["a@b.c"])
|
||||
await proxy_logging.budget_alerts(type="soft_budget", user_info=info)
|
||||
proxy_logging.email_logging_instance.budget_alerts.assert_called_once()
|
||||
proxy_logging.slack_alerting_instance.budget_alerts.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_alerts_slack_failure_raises(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
proxy_logging.slack_alerting_instance = MagicMock(
|
||||
budget_alerts=AsyncMock(side_effect=ConnectionError("slack"))
|
||||
)
|
||||
proxy_logging.email_logging_instance = None
|
||||
with pytest.raises(ConnectionError):
|
||||
await proxy_logging.budget_alerts(type="user_budget", user_info=_user_info())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# alerting_handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerting_handler_no_op_when_alerting_is_none(proxy_logging):
|
||||
proxy_logging.alerting = None
|
||||
proxy_logging.slack_alerting_instance = MagicMock(send_alert=AsyncMock())
|
||||
await proxy_logging.alerting_handler(message="x", level="High", alert_type=AlertType.db_exceptions)
|
||||
proxy_logging.slack_alerting_instance.send_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerting_handler_sends_to_slack(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_send(**kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
proxy_logging.slack_alerting_instance = MagicMock(send_alert=fake_send)
|
||||
await proxy_logging.alerting_handler(
|
||||
message="hi", level="High", alert_type=AlertType.db_exceptions, request_data={"metadata": {}}
|
||||
)
|
||||
snapshot = {
|
||||
"message": captured["message"],
|
||||
"level": captured["level"],
|
||||
"alert_type": captured["alert_type"],
|
||||
"user_info": captured["user_info"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"message": "hi",
|
||||
"level": "High",
|
||||
"alert_type": AlertType.db_exceptions,
|
||||
"user_info": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alerting_handler_sentry_without_sdk_error_raises(proxy_logging, monkeypatch):
|
||||
proxy_logging.alerting = ["sentry"]
|
||||
monkeypatch.setattr(litellm.utils, "sentry_sdk_instance", None)
|
||||
with pytest.raises(Exception, match="SENTRY_DSN"):
|
||||
await proxy_logging.alerting_handler(message="x", level="Low", alert_type=AlertType.db_exceptions)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# failure_handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_handler_skips_when_db_exceptions_not_in_alert_types(proxy_logging):
|
||||
proxy_logging.alert_types = ["llm_too_slow"] # type: ignore[list-item]
|
||||
proxy_logging.alerting_handler = AsyncMock()
|
||||
proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())
|
||||
await proxy_logging.failure_handler(original_exception=Exception("x"), duration=1.0, call_type="db_read")
|
||||
proxy_logging.alerting_handler.assert_not_called()
|
||||
proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_handler_logs_db_error_and_calls_service_logging(proxy_logging, monkeypatch):
|
||||
proxy_logging.alert_types = [AlertType.db_exceptions]
|
||||
proxy_logging.alerting_handler = AsyncMock()
|
||||
proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())
|
||||
monkeypatch.setattr(litellm.utils, "capture_exception", None)
|
||||
await proxy_logging.failure_handler(
|
||||
original_exception=HTTPException(status_code=500, detail="boom"),
|
||||
duration=1.5,
|
||||
call_type="db_write",
|
||||
)
|
||||
call_kwargs = proxy_logging.service_logging_obj.async_service_failure_hook.call_args.kwargs
|
||||
snapshot = {
|
||||
"service": call_kwargs["service"].value if hasattr(call_kwargs["service"], "value") else call_kwargs["service"],
|
||||
"duration": call_kwargs["duration"],
|
||||
"call_type": call_kwargs["call_type"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"service": "postgres",
|
||||
"duration": 1.5,
|
||||
"call_type": "db_write",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_handler_with_capture_exception_invoked(proxy_logging, monkeypatch):
|
||||
proxy_logging.alert_types = [AlertType.db_exceptions]
|
||||
proxy_logging.alerting_handler = AsyncMock()
|
||||
proxy_logging.service_logging_obj = MagicMock(async_service_failure_hook=AsyncMock())
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
def fake_capture(error):
|
||||
captured["error"] = error
|
||||
|
||||
monkeypatch.setattr(litellm.utils, "capture_exception", fake_capture)
|
||||
err = RuntimeError("real")
|
||||
await proxy_logging.failure_handler(original_exception=err, duration=1.0, call_type="db_read")
|
||||
snapshot = {
|
||||
"captured_is_input": captured["error"] is err,
|
||||
"service_failure_called": proxy_logging.service_logging_obj.async_service_failure_hook.called,
|
||||
"alerting_handler_scheduled": proxy_logging.alerting_handler.called,
|
||||
}
|
||||
assert snapshot == {
|
||||
"captured_is_input": True,
|
||||
"service_failure_called": True,
|
||||
"alerting_handler_scheduled": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_handler_propagates_service_logging_error_raises(proxy_logging, monkeypatch):
|
||||
proxy_logging.alert_types = [AlertType.db_exceptions]
|
||||
proxy_logging.alerting_handler = AsyncMock()
|
||||
proxy_logging.service_logging_obj = MagicMock(
|
||||
async_service_failure_hook=AsyncMock(side_effect=RuntimeError("svc"))
|
||||
)
|
||||
monkeypatch.setattr(litellm.utils, "capture_exception", None)
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging.failure_handler(
|
||||
original_exception=Exception("x"), duration=0.0, call_type="db_read"
|
||||
)
|
||||
@ -0,0 +1,338 @@
|
||||
"""Pin the ``ProxyLogging`` capability-probe family.
|
||||
|
||||
Covers ``_callback_capabilities`` (the cached deriver),
|
||||
``has_post_call_response_headers_callbacks``, ``has_streaming_callbacks``,
|
||||
``has_streaming_chunk_hook_overrides``, ``needs_iterator_wrap``,
|
||||
``needs_per_chunk_streaming_hook``, ``has_during_call_guardrails``, and
|
||||
``get_combined_callback_list``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.utils import ProxyLogging, _CallbackCapabilities
|
||||
|
||||
|
||||
class _PlainLogger(CustomLogger):
|
||||
pass
|
||||
|
||||
|
||||
class _OverridesResponseHeaders(CustomLogger):
|
||||
async def async_post_call_response_headers_hook(self, *args, **kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class _OverridesIterator(CustomLogger):
|
||||
async def async_post_call_streaming_iterator_hook(self, *args, **kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class _OverridesPerChunk(CustomLogger):
|
||||
async def async_post_call_streaming_hook(self, *args, **kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class _OverridesPreCall(CustomLogger):
|
||||
async def async_pre_call_hook(self, *args, **kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
def test_callback_capabilities_with_no_callbacks_returns_defaults(mock_callbacks_disabled):
|
||||
caps = ProxyLogging._callback_capabilities()
|
||||
snapshot = {
|
||||
"headers": caps.has_post_call_response_headers,
|
||||
"iterator": caps.has_iterator_override,
|
||||
"chunk": caps.has_streaming_chunk_override,
|
||||
"guardrail": caps.has_guardrail,
|
||||
"pre_call": caps.has_pre_call_override,
|
||||
"callbacks": caps.resolved_callbacks,
|
||||
"overrides": caps.iterator_overrides,
|
||||
}
|
||||
assert snapshot == {
|
||||
"headers": False,
|
||||
"iterator": False,
|
||||
"chunk": False,
|
||||
"guardrail": False,
|
||||
"pre_call": False,
|
||||
"callbacks": (),
|
||||
"overrides": (),
|
||||
}
|
||||
|
||||
|
||||
def test_callback_capabilities_detects_overrides(monkeypatch):
|
||||
cb1 = _OverridesResponseHeaders()
|
||||
cb2 = _OverridesIterator()
|
||||
cb3 = _OverridesPerChunk()
|
||||
cb4 = _OverridesPreCall()
|
||||
monkeypatch.setattr(litellm, "callbacks", [cb1, cb2, cb3, cb4])
|
||||
|
||||
caps = ProxyLogging._callback_capabilities()
|
||||
snapshot = {
|
||||
"headers": caps.has_post_call_response_headers,
|
||||
"iterator": caps.has_iterator_override,
|
||||
"chunk": caps.has_streaming_chunk_override,
|
||||
"pre_call": caps.has_pre_call_override,
|
||||
}
|
||||
assert snapshot == {
|
||||
"headers": True,
|
||||
"iterator": True,
|
||||
"chunk": True,
|
||||
"pre_call": True,
|
||||
}
|
||||
|
||||
|
||||
def test_callback_capabilities_caches_result(monkeypatch):
|
||||
cb = _OverridesResponseHeaders()
|
||||
monkeypatch.setattr(litellm, "callbacks", [cb])
|
||||
first = ProxyLogging._callback_capabilities()
|
||||
second = ProxyLogging._callback_capabilities()
|
||||
assert first is second
|
||||
|
||||
|
||||
def test_callback_capabilities_invalidates_on_change(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesResponseHeaders()])
|
||||
first = ProxyLogging._callback_capabilities()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
|
||||
second = ProxyLogging._callback_capabilities()
|
||||
assert first is not second
|
||||
assert first.has_post_call_response_headers is True
|
||||
assert second.has_post_call_response_headers is False
|
||||
assert second.has_iterator_override is True
|
||||
|
||||
|
||||
def test_callback_capabilities_callback_resolution_error_raises(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["unknown-string"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(RuntimeError("bad")),
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
ProxyLogging._callback_capabilities()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Individual capability probes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_has_post_call_response_headers_callbacks_truth_table(monkeypatch, mock_callbacks_disabled):
|
||||
"""One snapshot covering true + false + cache invalidation."""
|
||||
snapshot = {
|
||||
"empty_returns_false": ProxyLogging.has_post_call_response_headers_callbacks(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesResponseHeaders()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["override_returns_true"] = ProxyLogging.has_post_call_response_headers_callbacks()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_PlainLogger()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["plain_logger_false"] = ProxyLogging.has_post_call_response_headers_callbacks()
|
||||
assert snapshot == {
|
||||
"empty_returns_false": False,
|
||||
"override_returns_true": True,
|
||||
"plain_logger_false": False,
|
||||
}
|
||||
|
||||
|
||||
def test_has_post_call_response_headers_callbacks_error_when_bad_callback(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(RuntimeError("kaboom")),
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
ProxyLogging.has_post_call_response_headers_callbacks()
|
||||
|
||||
|
||||
def test_has_streaming_callbacks_truth_table(monkeypatch, mock_callbacks_disabled):
|
||||
snapshot = {
|
||||
"empty_false": ProxyLogging.has_streaming_callbacks(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["iterator_override_true"] = ProxyLogging.has_streaming_callbacks()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesPerChunk()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["per_chunk_override_true"] = ProxyLogging.has_streaming_callbacks()
|
||||
assert snapshot == {
|
||||
"empty_false": False,
|
||||
"iterator_override_true": True,
|
||||
"per_chunk_override_true": True,
|
||||
}
|
||||
|
||||
|
||||
def test_has_streaming_callbacks_error_when_resolution_fails(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(ValueError("nope")),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ProxyLogging.has_streaming_callbacks()
|
||||
|
||||
|
||||
def test_has_streaming_chunk_hook_overrides_truth_table(monkeypatch, mock_callbacks_disabled):
|
||||
snapshot = {
|
||||
"empty_false": ProxyLogging.has_streaming_chunk_hook_overrides(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesPerChunk()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["per_chunk_override_true"] = ProxyLogging.has_streaming_chunk_hook_overrides()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["only_iterator_false"] = ProxyLogging.has_streaming_chunk_hook_overrides()
|
||||
assert snapshot == {
|
||||
"empty_false": False,
|
||||
"per_chunk_override_true": True,
|
||||
"only_iterator_false": False,
|
||||
}
|
||||
|
||||
|
||||
def test_has_streaming_chunk_hook_overrides_error_raises(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(TypeError("nope")),
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
ProxyLogging.has_streaming_chunk_hook_overrides()
|
||||
|
||||
|
||||
def test_needs_iterator_wrap_truth_table(proxy_logging, monkeypatch, mock_callbacks_disabled):
|
||||
snapshot = {
|
||||
"empty_false": proxy_logging.needs_iterator_wrap(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["with_iter_override_true"] = proxy_logging.needs_iterator_wrap()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesPerChunk()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["only_per_chunk_false"] = proxy_logging.needs_iterator_wrap()
|
||||
assert snapshot == {
|
||||
"empty_false": False,
|
||||
"with_iter_override_true": True,
|
||||
"only_per_chunk_false": False,
|
||||
}
|
||||
|
||||
|
||||
def test_needs_iterator_wrap_error_raises(proxy_logging, monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(RuntimeError("oops")),
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
proxy_logging.needs_iterator_wrap()
|
||||
|
||||
|
||||
def test_needs_per_chunk_streaming_hook_truth_table(proxy_logging, monkeypatch, mock_callbacks_disabled):
|
||||
snapshot = {
|
||||
"empty_false": proxy_logging.needs_per_chunk_streaming_hook(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesPerChunk()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["per_chunk_override_true"] = proxy_logging.needs_per_chunk_streaming_hook()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_OverridesIterator()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["only_iter_override_false"] = proxy_logging.needs_per_chunk_streaming_hook()
|
||||
assert snapshot == {
|
||||
"empty_false": False,
|
||||
"per_chunk_override_true": True,
|
||||
"only_iter_override_false": False,
|
||||
}
|
||||
|
||||
|
||||
def test_needs_per_chunk_streaming_hook_error_raises(proxy_logging, monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(KeyError("oops")),
|
||||
)
|
||||
with pytest.raises(KeyError):
|
||||
proxy_logging.needs_per_chunk_streaming_hook()
|
||||
|
||||
|
||||
def test_has_during_call_guardrails_truth_table(monkeypatch, mock_callbacks_disabled):
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
||||
class _G(CustomGuardrail):
|
||||
def __init__(self):
|
||||
super().__init__(guardrail_name="g", event_hook="pre_call")
|
||||
|
||||
snapshot = {
|
||||
"empty_false": ProxyLogging.has_during_call_guardrails(),
|
||||
}
|
||||
monkeypatch.setattr(litellm, "callbacks", [_G()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["with_guardrail_true"] = ProxyLogging.has_during_call_guardrails()
|
||||
monkeypatch.setattr(litellm, "callbacks", [_PlainLogger()])
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
snapshot["only_plain_logger_false"] = ProxyLogging.has_during_call_guardrails()
|
||||
assert snapshot == {
|
||||
"empty_false": False,
|
||||
"with_guardrail_true": True,
|
||||
"only_plain_logger_false": False,
|
||||
}
|
||||
|
||||
|
||||
def test_has_during_call_guardrails_resolution_error_raises(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "callbacks", ["x"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"get_custom_logger_compatible_class",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(RuntimeError("oops")),
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
ProxyLogging.has_during_call_guardrails()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_combined_callback_list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_combined_callback_list_matrix(proxy_logging):
|
||||
snapshot = {
|
||||
"merge_dedupes_shared": sorted(
|
||||
proxy_logging.get_combined_callback_list(
|
||||
dynamic_success_callbacks=["dyn-1", "shared"],
|
||||
global_callbacks=["glob-1", "shared"],
|
||||
)
|
||||
),
|
||||
"none_dynamic_returns_global_copy": proxy_logging.get_combined_callback_list(
|
||||
dynamic_success_callbacks=None, global_callbacks=["a", "b", "c"]
|
||||
),
|
||||
"empty_both": proxy_logging.get_combined_callback_list(
|
||||
dynamic_success_callbacks=[], global_callbacks=[]
|
||||
),
|
||||
}
|
||||
assert snapshot == {
|
||||
"merge_dedupes_shared": ["dyn-1", "glob-1", "shared"],
|
||||
"none_dynamic_returns_global_copy": ["a", "b", "c"],
|
||||
"empty_both": [],
|
||||
}
|
||||
|
||||
|
||||
def test_get_combined_callback_list_unhashable_dynamic_raises(proxy_logging):
|
||||
with pytest.raises(TypeError):
|
||||
proxy_logging.get_combined_callback_list(
|
||||
dynamic_success_callbacks=[{"unhashable": True}],
|
||||
global_callbacks=[],
|
||||
)
|
||||
@ -0,0 +1,59 @@
|
||||
"""Pin the ``_CallbackCapabilities`` dataclass shape and defaults."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.utils import _CallbackCapabilities
|
||||
|
||||
|
||||
def test_callback_capabilities_default_values():
|
||||
caps = _CallbackCapabilities()
|
||||
snapshot = {
|
||||
"has_post_call_response_headers": caps.has_post_call_response_headers,
|
||||
"has_iterator_override": caps.has_iterator_override,
|
||||
"has_streaming_chunk_override": caps.has_streaming_chunk_override,
|
||||
"has_guardrail": caps.has_guardrail,
|
||||
"has_pre_call_override": caps.has_pre_call_override,
|
||||
"iterator_overrides": caps.iterator_overrides,
|
||||
"resolved_callbacks": caps.resolved_callbacks,
|
||||
}
|
||||
assert snapshot == {
|
||||
"has_post_call_response_headers": False,
|
||||
"has_iterator_override": False,
|
||||
"has_streaming_chunk_override": False,
|
||||
"has_guardrail": False,
|
||||
"has_pre_call_override": False,
|
||||
"iterator_overrides": (),
|
||||
"resolved_callbacks": (),
|
||||
}
|
||||
|
||||
|
||||
def test_callback_capabilities_explicit_values_preserved():
|
||||
cb1 = object()
|
||||
cb2 = object()
|
||||
caps = _CallbackCapabilities(
|
||||
has_post_call_response_headers=True,
|
||||
has_iterator_override=True,
|
||||
has_streaming_chunk_override=False,
|
||||
has_guardrail=True,
|
||||
has_pre_call_override=False,
|
||||
iterator_overrides=((cb1, "override"), (cb2, "apply_guardrail")),
|
||||
resolved_callbacks=(cb1, cb2),
|
||||
)
|
||||
assert caps.has_post_call_response_headers is True
|
||||
assert caps.iterator_overrides == ((cb1, "override"), (cb2, "apply_guardrail"))
|
||||
assert caps.resolved_callbacks == (cb1, cb2)
|
||||
|
||||
|
||||
def test_callback_capabilities_is_frozen_error_on_mutation_raises():
|
||||
caps = _CallbackCapabilities()
|
||||
with pytest.raises(dataclasses.FrozenInstanceError):
|
||||
caps.has_post_call_response_headers = True # type: ignore[misc]
|
||||
|
||||
|
||||
def test_callback_capabilities_invalid_field_error_raises():
|
||||
with pytest.raises(TypeError):
|
||||
_CallbackCapabilities(unknown_field=True) # type: ignore[call-arg]
|
||||
@ -0,0 +1,86 @@
|
||||
"""Pin ``ProxyLogging.during_call_hook``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
def _make_guardrail(name="g1", should_run=True, response=None):
|
||||
cb = MagicMock(spec=CustomGuardrail)
|
||||
cb.__class__ = CustomGuardrail
|
||||
cb.guardrail_name = name
|
||||
cb.event_hook = GuardrailEventHooks.during_call
|
||||
cb.use_native_during_call_hook = False
|
||||
cb.should_run_guardrail = MagicMock(return_value=should_run)
|
||||
cb.async_moderation_hook = AsyncMock(return_value=response)
|
||||
return cb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_during_call_hook_no_guardrail_fast_path_returns_data(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
data = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
out = await proxy_logging.during_call_hook(
|
||||
data=data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
assert out is data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_during_call_hook_runs_guardrails_in_parallel(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
g1 = _make_guardrail("a")
|
||||
g2 = _make_guardrail("b")
|
||||
monkeypatch.setattr(litellm, "callbacks", [g1, g2])
|
||||
data = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
out = await proxy_logging.during_call_hook(
|
||||
data=data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
snapshot = {
|
||||
"out_is_data": out is data,
|
||||
"a_called": g1.async_moderation_hook.called,
|
||||
"b_called": g2.async_moderation_hook.called,
|
||||
}
|
||||
assert snapshot == {"out_is_data": True, "a_called": True, "b_called": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_during_call_hook_guardrail_skipped_when_should_not_run(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
g = _make_guardrail("g", should_run=False)
|
||||
monkeypatch.setattr(litellm, "callbacks", [g])
|
||||
await proxy_logging.during_call_hook(
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
g.async_moderation_hook.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_during_call_hook_guardrail_error_raises(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
g = _make_guardrail("bad")
|
||||
g.async_moderation_hook = AsyncMock(side_effect=RuntimeError("blocked"))
|
||||
monkeypatch.setattr(litellm, "callbacks", [g])
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging.during_call_hook(
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
@ -0,0 +1,559 @@
|
||||
"""Pin ProxyLogging guardrail pipeline helpers.
|
||||
|
||||
Covers ``_should_use_guardrail_load_balancing``, ``_execute_guardrail_hook``,
|
||||
``_execute_guardrail_with_load_balancing``, ``_process_guardrail_callback``,
|
||||
``_process_prompt_template``, ``_process_guardrail_metadata``,
|
||||
``_maybe_execute_pipelines``, ``_handle_pipeline_result``,
|
||||
``_run_guardrail_task_with_enrichment``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
ModifyResponseException,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _should_use_guardrail_load_balancing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_should_use_guardrail_load_balancing_truth_table(proxy_logging):
|
||||
snapshot = {}
|
||||
router = MagicMock()
|
||||
router.guardrail_list = [{"guardrail_name": "g1"}, {"guardrail_name": "g1"}]
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
snapshot["multiple_deployments"] = proxy_logging._should_use_guardrail_load_balancing("g1")
|
||||
router.guardrail_list = [{"guardrail_name": "g1"}]
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
snapshot["single_deployment"] = proxy_logging._should_use_guardrail_load_balancing("g1")
|
||||
with patch("litellm.proxy.proxy_server.llm_router", None):
|
||||
snapshot["no_router"] = proxy_logging._should_use_guardrail_load_balancing("g1")
|
||||
router.guardrail_list = [{"guardrail_name": "other"}, {"guardrail_name": "other"}]
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
snapshot["unmatched_name"] = proxy_logging._should_use_guardrail_load_balancing("g1")
|
||||
assert snapshot == {
|
||||
"multiple_deployments": True,
|
||||
"single_deployment": False,
|
||||
"no_router": False,
|
||||
"unmatched_name": False,
|
||||
}
|
||||
|
||||
|
||||
def test_should_use_guardrail_load_balancing_error_on_bad_guardrail_list(proxy_logging):
|
||||
router = MagicMock()
|
||||
router.guardrail_list = "not a list"
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
proxy_logging._should_use_guardrail_load_balancing("g1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _execute_guardrail_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_guardrail():
|
||||
cb = MagicMock(spec=CustomGuardrail)
|
||||
cb.__class__ = CustomGuardrail
|
||||
cb.guardrail_name = "g"
|
||||
cb.event_hook = GuardrailEventHooks.pre_call
|
||||
cb.use_native_during_call_hook = False
|
||||
cb.async_pre_call_hook = AsyncMock(return_value={"a": 1, "b": 2, "c": 3})
|
||||
cb.async_moderation_hook = AsyncMock(return_value={"x": 1, "y": 2, "z": 3})
|
||||
cb.async_post_call_success_hook = AsyncMock(return_value={"p": 1, "q": 2, "r": 3})
|
||||
return cb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_hook_pre_call(proxy_logging, make_user_api_key_auth):
|
||||
cb = _make_guardrail()
|
||||
out = await proxy_logging._execute_guardrail_hook(
|
||||
callback=cb,
|
||||
hook_type="pre_call",
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_hook_during_call(proxy_logging, make_user_api_key_auth):
|
||||
cb = _make_guardrail()
|
||||
out = await proxy_logging._execute_guardrail_hook(
|
||||
callback=cb,
|
||||
hook_type="during_call",
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
assert out == {"x": 1, "y": 2, "z": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_hook_post_call(proxy_logging, make_user_api_key_auth):
|
||||
cb = _make_guardrail()
|
||||
out = await proxy_logging._execute_guardrail_hook(
|
||||
callback=cb,
|
||||
hook_type="post_call",
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
response={"original": True},
|
||||
)
|
||||
assert out == {"p": 1, "q": 2, "r": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_hook_unknown_hook_type_raises(proxy_logging, make_user_api_key_auth):
|
||||
cb = _make_guardrail()
|
||||
with pytest.raises(ValueError, match="Unknown hook_type"):
|
||||
await proxy_logging._execute_guardrail_hook(
|
||||
callback=cb,
|
||||
hook_type="weird", # type: ignore[arg-type]
|
||||
data={},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _execute_guardrail_with_load_balancing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_with_load_balancing_routes_through_router(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
cb = _make_guardrail()
|
||||
router = MagicMock()
|
||||
router.get_available_guardrail = MagicMock(return_value={"callback": cb})
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
out = await proxy_logging._execute_guardrail_with_load_balancing(
|
||||
guardrail_name="g",
|
||||
hook_type="pre_call",
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_with_load_balancing_router_none_raises(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
with patch("litellm.proxy.proxy_server.llm_router", None):
|
||||
with pytest.raises(ValueError, match="Router not initialized"):
|
||||
await proxy_logging._execute_guardrail_with_load_balancing(
|
||||
guardrail_name="g",
|
||||
hook_type="pre_call",
|
||||
data={},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_guardrail_with_load_balancing_no_callback_raises(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
router = MagicMock()
|
||||
router.get_available_guardrail = MagicMock(return_value={"callback": None})
|
||||
with patch("litellm.proxy.proxy_server.llm_router", router):
|
||||
with pytest.raises(ValueError, match="No callback found"):
|
||||
await proxy_logging._execute_guardrail_with_load_balancing(
|
||||
guardrail_name="g",
|
||||
hook_type="pre_call",
|
||||
data={},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_guardrail_callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_guardrail_callback_skipped_when_should_run_false(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
cb = _make_guardrail()
|
||||
cb.should_run_guardrail = MagicMock(return_value=False)
|
||||
out = await proxy_logging._process_guardrail_callback(
|
||||
callback=cb,
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_type=GuardrailEventHooks.pre_call,
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_guardrail_callback_returns_data_on_success(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
cb = _make_guardrail()
|
||||
cb.should_run_guardrail = MagicMock(return_value=True)
|
||||
proxy_logging._should_use_guardrail_load_balancing = MagicMock(return_value=False)
|
||||
out = await proxy_logging._process_guardrail_callback(
|
||||
callback=cb,
|
||||
data={"model": "m", "messages": [{"role": "user"}], "temperature": 0.1},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_type=GuardrailEventHooks.pre_call,
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_guardrail_callback_enriches_and_reraises_http_exception(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
cb = _make_guardrail()
|
||||
cb.should_run_guardrail = MagicMock(return_value=True)
|
||||
detail = {"error": "blocked"}
|
||||
cb.async_pre_call_hook = AsyncMock(side_effect=HTTPException(status_code=400, detail=detail))
|
||||
cb.event_hook = "pre_call"
|
||||
proxy_logging._should_use_guardrail_load_balancing = MagicMock(return_value=False)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await proxy_logging._process_guardrail_callback(
|
||||
callback=cb,
|
||||
data={"model": "m"},
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_type=GuardrailEventHooks.pre_call,
|
||||
)
|
||||
assert detail["guardrail_name"] == "g"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_guardrail_metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_process_guardrail_metadata_calls_header_helper(proxy_logging, monkeypatch):
|
||||
calls: List[Dict[str, Any]] = []
|
||||
|
||||
def fake_add(request_data, guardrail_name):
|
||||
calls.append({"data": request_data, "name": guardrail_name})
|
||||
|
||||
from litellm.proxy.common_utils import callback_utils
|
||||
|
||||
monkeypatch.setattr(callback_utils, "add_guardrail_to_applied_guardrails_header", fake_add)
|
||||
data = {"metadata": {"guardrails": ["g1", "g2"]}}
|
||||
proxy_logging._process_guardrail_metadata(data)
|
||||
snapshot = {
|
||||
"call_count": len(calls),
|
||||
"first_name": calls[0]["name"],
|
||||
"second_name": calls[1]["name"],
|
||||
"data_passed_is_input": all(c["data"] is data for c in calls),
|
||||
}
|
||||
assert snapshot == {
|
||||
"call_count": 2,
|
||||
"first_name": "g1",
|
||||
"second_name": "g2",
|
||||
"data_passed_is_input": True,
|
||||
}
|
||||
|
||||
|
||||
def test_process_guardrail_metadata_skips_already_applied(proxy_logging, monkeypatch):
|
||||
calls: List[str] = []
|
||||
|
||||
def fake_add(request_data, guardrail_name):
|
||||
calls.append(guardrail_name)
|
||||
|
||||
from litellm.proxy.common_utils import callback_utils
|
||||
|
||||
monkeypatch.setattr(callback_utils, "add_guardrail_to_applied_guardrails_header", fake_add)
|
||||
data = {"metadata": {"guardrails": ["g1", "g2"], "applied_guardrails": ["g1"]}}
|
||||
proxy_logging._process_guardrail_metadata(data)
|
||||
assert calls == ["g2"]
|
||||
|
||||
|
||||
def test_process_guardrail_metadata_no_metadata_is_noop(proxy_logging, monkeypatch):
|
||||
from litellm.proxy.common_utils import callback_utils
|
||||
|
||||
monkeypatch.setattr(
|
||||
callback_utils,
|
||||
"add_guardrail_to_applied_guardrails_header",
|
||||
MagicMock(side_effect=AssertionError("should not be called")),
|
||||
)
|
||||
proxy_logging._process_guardrail_metadata({})
|
||||
|
||||
|
||||
def test_process_guardrail_metadata_invalid_data_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._process_guardrail_metadata(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _maybe_execute_pipelines
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_execute_pipelines_no_pipelines_returns_data(proxy_logging, make_user_api_key_auth):
|
||||
data = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
out = await proxy_logging._maybe_execute_pipelines(
|
||||
data=data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_hook="pre_call",
|
||||
)
|
||||
assert out == {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_execute_pipelines_skips_pipelines_with_other_mode(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
pipeline = MagicMock()
|
||||
pipeline.mode = "post_call" # not pre_call
|
||||
data = {"metadata": {"_guardrail_pipelines": [("p1", pipeline)]}, "model": "m", "messages": []}
|
||||
executed = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.policy_engine.pipeline_executor.PipelineExecutor.execute_steps", executed
|
||||
)
|
||||
out = await proxy_logging._maybe_execute_pipelines(
|
||||
data=data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_hook="pre_call",
|
||||
)
|
||||
executed.assert_not_called()
|
||||
assert out is data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_execute_pipelines_blocks_on_block_terminal_action_raises(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
pipeline = MagicMock()
|
||||
pipeline.mode = "pre_call"
|
||||
pipeline.steps = []
|
||||
fake_result = MagicMock()
|
||||
fake_result.terminal_action = "block"
|
||||
fake_result.step_results = []
|
||||
data = {"metadata": {"_guardrail_pipelines": [("policy-1", pipeline)]}, "messages": [], "model": "m"}
|
||||
|
||||
async def fake_execute_steps(**kwargs):
|
||||
return fake_result
|
||||
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.policy_engine.pipeline_executor.PipelineExecutor.execute_steps",
|
||||
fake_execute_steps,
|
||||
)
|
||||
with pytest.raises(HTTPException):
|
||||
await proxy_logging._maybe_execute_pipelines(
|
||||
data=data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
call_type="completion",
|
||||
event_hook="pre_call",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_pipeline_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_handle_pipeline_result_allow_with_modifications():
|
||||
data = {"a": 1}
|
||||
result = MagicMock()
|
||||
result.terminal_action = "allow"
|
||||
result.modified_data = {"b": 2, "c": 3}
|
||||
out = ProxyLogging._handle_pipeline_result(result=result, data=data, policy_name="p")
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def test_handle_pipeline_result_block_raises_http_exception():
|
||||
result = MagicMock()
|
||||
result.terminal_action = "block"
|
||||
result.step_results = []
|
||||
with pytest.raises(HTTPException) as info:
|
||||
ProxyLogging._handle_pipeline_result(result=result, data={"model": "m"}, policy_name="p")
|
||||
detail = info.value.detail
|
||||
snapshot = {
|
||||
"is_dict": isinstance(detail, dict),
|
||||
"error_type": detail["error"]["type"],
|
||||
"policy": detail["error"]["pipeline_context"]["policy"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"is_dict": True,
|
||||
"error_type": "guardrail_pipeline_error",
|
||||
"policy": "p",
|
||||
}
|
||||
|
||||
|
||||
def test_handle_pipeline_result_modify_response_raises_modify_exception():
|
||||
result = MagicMock()
|
||||
result.terminal_action = "modify_response"
|
||||
result.modify_response_message = "filtered"
|
||||
with pytest.raises(ModifyResponseException):
|
||||
ProxyLogging._handle_pipeline_result(result=result, data={"model": "m"}, policy_name="p")
|
||||
|
||||
|
||||
def test_handle_pipeline_result_unknown_action_returns_data():
|
||||
data = {"a": 1, "b": 2, "c": 3}
|
||||
result = MagicMock()
|
||||
result.terminal_action = "something_else"
|
||||
assert ProxyLogging._handle_pipeline_result(result=result, data=data, policy_name="p") is data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_guardrail_task_with_enrichment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_guardrail_task_with_enrichment_passes_result():
|
||||
async def task():
|
||||
return {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
out = await ProxyLogging._run_guardrail_task_with_enrichment(
|
||||
callback=MagicMock(guardrail_name="g"), coro=task()
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_guardrail_task_with_enrichment_enriches_http_exception_raises():
|
||||
detail = {"error": "blocked"}
|
||||
|
||||
async def task():
|
||||
raise HTTPException(status_code=400, detail=detail)
|
||||
|
||||
cb = MagicMock()
|
||||
cb.guardrail_name = "presidio"
|
||||
cb.event_hook = "pre_call"
|
||||
with pytest.raises(HTTPException):
|
||||
await ProxyLogging._run_guardrail_task_with_enrichment(callback=cb, coro=task())
|
||||
assert detail["guardrail_name"] == "presidio"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_prompt_template
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_prompt_template_no_op_when_no_prompt_spec(proxy_logging, monkeypatch):
|
||||
from litellm.proxy.prompts import prompt_registry
|
||||
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY, "get_prompt_callback_by_id", lambda *a, **kw: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY, "get_prompt_by_id", lambda *a, **kw: None
|
||||
)
|
||||
data: Dict[str, Any] = {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
await proxy_logging._process_prompt_template(
|
||||
data=data,
|
||||
litellm_logging_obj=MagicMock(),
|
||||
prompt_id="some-id",
|
||||
prompt_version=1,
|
||||
call_type="completion",
|
||||
)
|
||||
assert data == {"messages": [{"role": "user"}], "model": "m", "temperature": 0.1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_prompt_template_applies_when_spec_resolves(proxy_logging, monkeypatch):
|
||||
from litellm.proxy.prompts import prompt_registry
|
||||
|
||||
custom_logger = MagicMock()
|
||||
prompt_spec = MagicMock()
|
||||
prompt_spec.litellm_params = MagicMock(prompt_id="resolved-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY,
|
||||
"get_prompt_callback_by_id",
|
||||
lambda *a, **kw: custom_logger,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY, "get_prompt_by_id", lambda *a, **kw: prompt_spec
|
||||
)
|
||||
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.async_get_chat_completion_prompt = AsyncMock(
|
||||
return_value=(
|
||||
"model-out",
|
||||
[{"role": "user", "content": "rendered"}],
|
||||
{"temperature": 0.5, "top_p": 1},
|
||||
)
|
||||
)
|
||||
data: Dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": "orig"}],
|
||||
"model": "m",
|
||||
"prompt_id": "x",
|
||||
}
|
||||
await proxy_logging._process_prompt_template(
|
||||
data=data,
|
||||
litellm_logging_obj=logging_obj,
|
||||
prompt_id="x",
|
||||
prompt_version=None,
|
||||
call_type="completion",
|
||||
)
|
||||
snapshot = {
|
||||
"model": data["model"],
|
||||
"messages": data["messages"],
|
||||
"temperature": data["temperature"],
|
||||
"top_p": data["top_p"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"model": "model-out",
|
||||
"messages": [{"role": "user", "content": "rendered"}],
|
||||
"temperature": 0.5,
|
||||
"top_p": 1,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_prompt_template_async_get_prompt_error_raises(proxy_logging, monkeypatch):
|
||||
from litellm.proxy.prompts import prompt_registry
|
||||
|
||||
custom_logger = MagicMock()
|
||||
prompt_spec = MagicMock()
|
||||
prompt_spec.litellm_params = MagicMock(prompt_id="x")
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY,
|
||||
"get_prompt_callback_by_id",
|
||||
lambda *a, **kw: custom_logger,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
prompt_registry.IN_MEMORY_PROMPT_REGISTRY, "get_prompt_by_id", lambda *a, **kw: prompt_spec
|
||||
)
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.async_get_chat_completion_prompt = AsyncMock(side_effect=RuntimeError("bad prompt"))
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging._process_prompt_template(
|
||||
data={"messages": [], "model": "m", "prompt_id": "x"},
|
||||
litellm_logging_obj=logging_obj,
|
||||
prompt_id="x",
|
||||
prompt_version=None,
|
||||
call_type="completion",
|
||||
)
|
||||
@ -0,0 +1,186 @@
|
||||
"""Pin behavior of ``InternalUsageCache``: a thin adapter over ``DualCache``.
|
||||
|
||||
Each method should pass-through to the underlying ``DualCache`` with
|
||||
exactly the same arguments, mapping ``litellm_parent_otel_span`` to the
|
||||
``DualCache`` kw it expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.utils import InternalUsageCache
|
||||
|
||||
|
||||
def _kwargs_snapshot(call):
|
||||
return dict(call.kwargs)
|
||||
|
||||
|
||||
def test_internal_usage_cache_init_stores_dual_cache():
|
||||
inner = DualCache(default_in_memory_ttl=1)
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
snapshot = {
|
||||
"is_internal_usage_cache": isinstance(cache, InternalUsageCache),
|
||||
"dual_cache_is_inner": cache.dual_cache is inner,
|
||||
"ttl_is_one": inner.default_in_memory_ttl == 1,
|
||||
}
|
||||
assert snapshot == {
|
||||
"is_internal_usage_cache": True,
|
||||
"dual_cache_is_inner": True,
|
||||
"ttl_is_one": True,
|
||||
}
|
||||
|
||||
|
||||
def test_internal_usage_cache_init_error_requires_dual_cache():
|
||||
with pytest.raises(TypeError):
|
||||
InternalUsageCache() # type: ignore[call-arg]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_get_cache_forwards_args():
|
||||
inner = MagicMock()
|
||||
inner.async_get_cache = AsyncMock(return_value={"hit": True, "value": 42, "source": "redis"})
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
|
||||
result = await cache.async_get_cache(key="k", litellm_parent_otel_span="span", local_only=True, extra="x")
|
||||
forwarded = _kwargs_snapshot(inner.async_get_cache.call_args)
|
||||
assert forwarded == {"key": "k", "local_only": True, "parent_otel_span": "span", "extra": "x"}
|
||||
assert result == {"hit": True, "value": 42, "source": "redis"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_get_cache_propagates_underlying_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.async_get_cache = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(RuntimeError, match="redis down"):
|
||||
await cache.async_get_cache(key="k", litellm_parent_otel_span=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_set_cache_forwards_args():
|
||||
inner = MagicMock()
|
||||
inner.async_set_cache = AsyncMock()
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
|
||||
await cache.async_set_cache(key="k", value="v", litellm_parent_otel_span="span", local_only=False, ttl=60)
|
||||
forwarded = _kwargs_snapshot(inner.async_set_cache.call_args)
|
||||
assert forwarded == {
|
||||
"key": "k",
|
||||
"value": "v",
|
||||
"local_only": False,
|
||||
"litellm_parent_otel_span": "span",
|
||||
"ttl": 60,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_set_cache_propagates_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.async_set_cache = AsyncMock(side_effect=ValueError("bad value"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(ValueError, match="bad value"):
|
||||
await cache.async_set_cache(key="k", value="v", litellm_parent_otel_span=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_batch_set_cache_forwards_pipeline():
|
||||
inner = MagicMock()
|
||||
inner.async_set_cache_pipeline = AsyncMock()
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
|
||||
pairs = [("a", 1), ("b", 2)]
|
||||
await cache.async_batch_set_cache(cache_list=pairs, litellm_parent_otel_span=None, local_only=True, ttl=10)
|
||||
forwarded = _kwargs_snapshot(inner.async_set_cache_pipeline.call_args)
|
||||
assert forwarded == {
|
||||
"cache_list": pairs,
|
||||
"local_only": True,
|
||||
"litellm_parent_otel_span": None,
|
||||
"ttl": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_batch_set_cache_propagates_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.async_set_cache_pipeline = AsyncMock(side_effect=ConnectionError("network"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(ConnectionError):
|
||||
await cache.async_batch_set_cache(cache_list=[], litellm_parent_otel_span=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_batch_get_cache_forwards_args():
|
||||
inner = MagicMock()
|
||||
inner.async_batch_get_cache = AsyncMock(return_value=[1, 2, 3])
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
result = await cache.async_batch_get_cache(keys=["a", "b", "c"], parent_otel_span="span", local_only=False)
|
||||
forwarded = _kwargs_snapshot(inner.async_batch_get_cache.call_args)
|
||||
assert forwarded == {"keys": ["a", "b", "c"], "parent_otel_span": "span", "local_only": False}
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_batch_get_cache_invalid_input_raises():
|
||||
inner = MagicMock()
|
||||
inner.async_batch_get_cache = AsyncMock(side_effect=TypeError("not a list"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(TypeError):
|
||||
await cache.async_batch_get_cache(keys=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_increment_cache_forwards_args():
|
||||
inner = MagicMock()
|
||||
inner.async_increment_cache = AsyncMock(return_value=5.0)
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
result = await cache.async_increment_cache(key="counter", value=1.5, litellm_parent_otel_span="span")
|
||||
forwarded = _kwargs_snapshot(inner.async_increment_cache.call_args)
|
||||
assert forwarded == {"key": "counter", "value": 1.5, "local_only": False, "parent_otel_span": "span"}
|
||||
assert result == 5.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_increment_cache_propagates_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.async_increment_cache = AsyncMock(side_effect=OverflowError())
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(OverflowError):
|
||||
await cache.async_increment_cache(key="x", value=1.0, litellm_parent_otel_span=None)
|
||||
|
||||
|
||||
def test_set_cache_forwards_args():
|
||||
inner = MagicMock()
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
cache.set_cache(key="k", value="v", local_only=True, ttl=30)
|
||||
forwarded = _kwargs_snapshot(inner.set_cache.call_args)
|
||||
assert forwarded == {"key": "k", "value": "v", "local_only": True, "ttl": 30}
|
||||
|
||||
|
||||
def test_set_cache_propagates_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.set_cache = MagicMock(side_effect=RuntimeError("no redis"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(RuntimeError):
|
||||
cache.set_cache(key="k", value="v")
|
||||
|
||||
|
||||
def test_get_cache_forwards_args_and_returns_inner_result():
|
||||
inner = MagicMock()
|
||||
inner.get_cache = MagicMock(return_value={"k": "v", "ttl": 60, "source": "mem"})
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
result = cache.get_cache(key="k", local_only=False)
|
||||
forwarded = _kwargs_snapshot(inner.get_cache.call_args)
|
||||
assert forwarded == {"key": "k", "local_only": False}
|
||||
assert result == {"k": "v", "ttl": 60, "source": "mem"}
|
||||
|
||||
|
||||
def test_get_cache_propagates_error_raises():
|
||||
inner = MagicMock()
|
||||
inner.get_cache = MagicMock(side_effect=KeyError("missing"))
|
||||
cache = InternalUsageCache(dual_cache=inner)
|
||||
with pytest.raises(KeyError):
|
||||
cache.get_cache(key="missing")
|
||||
403
tests/test_litellm/proxy/utils/proxy_logging/test_lifecycle.py
Normal file
403
tests/test_litellm/proxy/utils/proxy_logging/test_lifecycle.py
Normal file
@ -0,0 +1,403 @@
|
||||
"""Pin ProxyLogging lifecycle: ``__init__``, ``startup_event``,
|
||||
``update_values``, ``_add_proxy_hooks``, ``get_proxy_hook``, and
|
||||
``_init_litellm_callbacks``.
|
||||
|
||||
Also covers ``update_request_status`` and ``_convert_user_api_key_auth_to_dict``
|
||||
because they are direct dependents on the lifecycle state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.utils import (
|
||||
InternalUsageCache,
|
||||
ProxyLogging,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __init__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_proxy_logging_init_sets_default_state(mock_callbacks_disabled):
|
||||
cache = UserApiKeyCache()
|
||||
pl = ProxyLogging(user_api_key_cache=cache)
|
||||
snapshot = {
|
||||
"internal_usage_cache_type": type(pl.internal_usage_cache).__name__,
|
||||
"alerting_is_none": pl.alerting is None,
|
||||
"alerting_threshold": pl.alerting_threshold,
|
||||
"premium_user": pl.premium_user,
|
||||
"proxy_hook_mapping": pl.proxy_hook_mapping,
|
||||
"daily_report_started": pl.daily_report_started,
|
||||
"hanging_requests_check_started": pl.hanging_requests_check_started,
|
||||
}
|
||||
assert snapshot == {
|
||||
"internal_usage_cache_type": "InternalUsageCache",
|
||||
"alerting_is_none": True,
|
||||
"alerting_threshold": 300,
|
||||
"premium_user": False,
|
||||
"proxy_hook_mapping": {},
|
||||
"daily_report_started": False,
|
||||
"hanging_requests_check_started": False,
|
||||
}
|
||||
|
||||
|
||||
def test_proxy_logging_init_premium_user_flag(mock_callbacks_disabled):
|
||||
pl = ProxyLogging(user_api_key_cache=UserApiKeyCache(), premium_user=True)
|
||||
assert pl.premium_user is True
|
||||
|
||||
|
||||
def test_proxy_logging_init_missing_cache_raises():
|
||||
with pytest.raises(TypeError):
|
||||
ProxyLogging() # type: ignore[call-arg]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_values
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_values_stores_alerting_state(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.update_values(
|
||||
alerting=["slack"],
|
||||
alerting_threshold=42.0,
|
||||
alert_types=["llm_too_slow"],
|
||||
alert_to_webhook_url={"key": "value"},
|
||||
)
|
||||
snapshot = {
|
||||
"alerting": proxy_logging.alerting,
|
||||
"threshold": proxy_logging.alerting_threshold,
|
||||
"alert_types": proxy_logging.alert_types,
|
||||
"webhook_url": proxy_logging.alert_to_webhook_url,
|
||||
}
|
||||
assert snapshot == {
|
||||
"alerting": ["slack"],
|
||||
"threshold": 42.0,
|
||||
"alert_types": ["llm_too_slow"],
|
||||
"webhook_url": {"key": "value"},
|
||||
}
|
||||
|
||||
|
||||
def test_update_values_with_only_redis_cache_does_not_touch_slack(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
redis = MagicMock()
|
||||
proxy_logging.update_values(redis_cache=redis)
|
||||
proxy_logging.slack_alerting_instance.update_values.assert_not_called()
|
||||
assert proxy_logging.internal_usage_cache.dual_cache.redis_cache is redis
|
||||
|
||||
|
||||
def test_update_values_with_no_args_is_noop(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.update_values()
|
||||
proxy_logging.slack_alerting_instance.update_values.assert_not_called()
|
||||
|
||||
|
||||
def test_update_values_invalid_type_for_alerting_raises(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock(
|
||||
update_values=MagicMock(side_effect=TypeError("bad type"))
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
proxy_logging.update_values(alerting={"not": "a list"}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# startup_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_startup_event_initializes_slack_and_callbacks(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.slack_alerting_instance.alert_types = []
|
||||
proxy_logging._init_litellm_callbacks = MagicMock()
|
||||
proxy_logging.update_values = MagicMock()
|
||||
|
||||
proxy_logging.startup_event(llm_router=None, redis_usage_cache=None)
|
||||
snapshot = {
|
||||
"update_called": proxy_logging.update_values.called,
|
||||
"init_called": proxy_logging._init_litellm_callbacks.called,
|
||||
"slack_update_called": proxy_logging.slack_alerting_instance.update_values.called,
|
||||
}
|
||||
assert snapshot == {
|
||||
"update_called": True,
|
||||
"init_called": True,
|
||||
"slack_update_called": True,
|
||||
}
|
||||
|
||||
|
||||
def test_startup_event_propagates_init_callbacks_failure_raises(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.slack_alerting_instance.alert_types = []
|
||||
proxy_logging._init_litellm_callbacks = MagicMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
proxy_logging.startup_event(llm_router=None, redis_usage_cache=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _add_proxy_hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_add_proxy_hooks_registers_callbacks(proxy_logging, monkeypatch):
|
||||
"""Patch ``PROXY_HOOKS`` and the resolver so we control exactly
|
||||
what gets registered. Verifies that the resulting instances land in
|
||||
``proxy_logging.proxy_hook_mapping`` keyed by hook name.
|
||||
"""
|
||||
hook_keys = ["cache_control_check", "max_budget_limiter"]
|
||||
registered: List[Any] = []
|
||||
|
||||
from litellm.proxy import utils as utils_mod
|
||||
|
||||
def fake_get_proxy_hook(hook_name):
|
||||
class _Stub:
|
||||
__name__ = hook_name
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.hook_name = hook_name
|
||||
|
||||
return _Stub
|
||||
|
||||
monkeypatch.setattr(utils_mod, "PROXY_HOOKS", hook_keys)
|
||||
monkeypatch.setattr(utils_mod, "get_proxy_hook", fake_get_proxy_hook)
|
||||
monkeypatch.setattr(
|
||||
litellm.logging_callback_manager,
|
||||
"add_litellm_callback",
|
||||
lambda cb: registered.append(cb),
|
||||
)
|
||||
|
||||
with patch("litellm.proxy.proxy_server.prisma_client", None):
|
||||
proxy_logging._add_proxy_hooks(llm_router=None)
|
||||
|
||||
keys = list(proxy_logging.proxy_hook_mapping.keys())
|
||||
snapshot = {
|
||||
"mapping_keys": keys,
|
||||
"registered_count": len(registered),
|
||||
"registered_hook_names": [getattr(r, "hook_name", None) for r in registered],
|
||||
}
|
||||
assert snapshot == {
|
||||
"mapping_keys": hook_keys,
|
||||
"registered_count": len(hook_keys),
|
||||
"registered_hook_names": hook_keys,
|
||||
}
|
||||
|
||||
|
||||
def test_add_proxy_hooks_unknown_hook_raises(proxy_logging, monkeypatch):
|
||||
from litellm.proxy import utils as utils_mod
|
||||
|
||||
monkeypatch.setattr(utils_mod, "PROXY_HOOKS", ["bogus_hook"])
|
||||
|
||||
def bad_resolver(name):
|
||||
raise KeyError(name)
|
||||
|
||||
monkeypatch.setattr(utils_mod, "get_proxy_hook", bad_resolver)
|
||||
with pytest.raises(KeyError):
|
||||
proxy_logging._add_proxy_hooks(llm_router=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_proxy_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_proxy_hook_returns_registered_instance(proxy_logging):
|
||||
s_cache = MagicMock()
|
||||
s_budget = MagicMock()
|
||||
s_parallel = MagicMock()
|
||||
proxy_logging.proxy_hook_mapping = {
|
||||
"cache_control_check": s_cache,
|
||||
"max_budget_limiter": s_budget,
|
||||
"max_parallel_request_limiter": s_parallel,
|
||||
}
|
||||
snapshot = {
|
||||
"cache_control_check": proxy_logging.get_proxy_hook("cache_control_check") is s_cache,
|
||||
"max_budget_limiter": proxy_logging.get_proxy_hook("max_budget_limiter") is s_budget,
|
||||
"max_parallel_request_limiter": proxy_logging.get_proxy_hook("max_parallel_request_limiter") is s_parallel,
|
||||
"unknown_returns_none": proxy_logging.get_proxy_hook("unknown") is None,
|
||||
}
|
||||
assert snapshot == {
|
||||
"cache_control_check": True,
|
||||
"max_budget_limiter": True,
|
||||
"max_parallel_request_limiter": True,
|
||||
"unknown_returns_none": True,
|
||||
}
|
||||
|
||||
|
||||
def test_get_proxy_hook_unknown_returns_none(proxy_logging):
|
||||
proxy_logging.proxy_hook_mapping = {}
|
||||
assert proxy_logging.get_proxy_hook("does-not-exist") is None
|
||||
|
||||
|
||||
def test_get_proxy_hook_non_string_key_raises(proxy_logging):
|
||||
# ``dict.get`` doesn't raise on unhashable types — but ``None`` returns None.
|
||||
# The pin: passing an unhashable key blows up like dict access does.
|
||||
proxy_logging.proxy_hook_mapping = {"k": object()}
|
||||
with pytest.raises(TypeError):
|
||||
proxy_logging.get_proxy_hook({"unhashable": True}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _init_litellm_callbacks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_init_litellm_callbacks_replaces_string_with_instance(proxy_logging, monkeypatch):
|
||||
from litellm.proxy import utils as utils_mod
|
||||
|
||||
sentinel_instance = MagicMock(spec=litellm.integrations.custom_logger.CustomLogger)
|
||||
sentinel_instance.__class__ = litellm.integrations.custom_logger.CustomLogger
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", ["some-string-logger"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"_init_custom_logger_compatible_class",
|
||||
lambda *a, **kw: sentinel_instance,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])
|
||||
proxy_logging._init_litellm_callbacks(llm_router=None)
|
||||
snapshot = {
|
||||
"replaced_first_item": litellm.callbacks[0] is sentinel_instance,
|
||||
"callbacks_grew_with_service": len(litellm.callbacks) >= 2,
|
||||
"service_logging_appended": any(
|
||||
"ServiceLogging" in type(c).__name__ for c in litellm.callbacks
|
||||
),
|
||||
}
|
||||
assert snapshot == {
|
||||
"replaced_first_item": True,
|
||||
"callbacks_grew_with_service": True,
|
||||
"service_logging_appended": True,
|
||||
}
|
||||
|
||||
|
||||
def test_init_litellm_callbacks_string_resolution_failure_keeps_string(proxy_logging, monkeypatch):
|
||||
from litellm.proxy import utils as utils_mod
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", ["unknown-logger"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"_init_custom_logger_compatible_class",
|
||||
lambda *a, **kw: None,
|
||||
)
|
||||
monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])
|
||||
proxy_logging._init_litellm_callbacks(llm_router=None)
|
||||
# Resolver returned None — original string remains in place at idx 0.
|
||||
assert litellm.callbacks[0] == "unknown-logger"
|
||||
|
||||
|
||||
def test_init_litellm_callbacks_propagates_resolver_error_raises(proxy_logging, monkeypatch):
|
||||
from litellm.proxy import utils as utils_mod
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", ["raises-on-init"])
|
||||
monkeypatch.setattr(
|
||||
litellm.litellm_core_utils.litellm_logging,
|
||||
"_init_custom_logger_compatible_class",
|
||||
MagicMock(side_effect=RuntimeError("bad init")),
|
||||
)
|
||||
monkeypatch.setattr(utils_mod, "PROXY_HOOKS", [])
|
||||
with pytest.raises(RuntimeError):
|
||||
proxy_logging._init_litellm_callbacks(llm_router=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_request_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_request_status_when_alerting_set_writes_cache(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
proxy_logging.alerting_threshold = 5.0
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_set_cache(**kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
proxy_logging.internal_usage_cache.async_set_cache = fake_set_cache # type: ignore[assignment]
|
||||
await proxy_logging.update_request_status(litellm_call_id="call-1", status="success")
|
||||
snapshot = {
|
||||
"key": captured["key"],
|
||||
"value": captured["value"],
|
||||
"local_only": captured["local_only"],
|
||||
"ttl": captured["ttl"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"key": "request_status:call-1",
|
||||
"value": "success",
|
||||
"local_only": True,
|
||||
"ttl": 105.0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_request_status_no_alerting_skips_cache(proxy_logging):
|
||||
proxy_logging.alerting = None
|
||||
proxy_logging.internal_usage_cache.async_set_cache = AsyncMock()
|
||||
await proxy_logging.update_request_status(litellm_call_id="call-1", status="success")
|
||||
proxy_logging.internal_usage_cache.async_set_cache.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_request_status_cache_error_raises(proxy_logging):
|
||||
proxy_logging.alerting = ["slack"]
|
||||
proxy_logging.internal_usage_cache.async_set_cache = AsyncMock(side_effect=ConnectionError("redis"))
|
||||
with pytest.raises(ConnectionError):
|
||||
await proxy_logging.update_request_status(litellm_call_id="x", status="fail")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _convert_user_api_key_auth_to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_user_api_key_auth_to_dict_pydantic_uses_model_dump(proxy_logging, make_user_api_key_auth):
|
||||
auth = make_user_api_key_auth(user_id="u-1", team_id="t-1")
|
||||
result = proxy_logging._convert_user_api_key_auth_to_dict(auth)
|
||||
snapshot = {
|
||||
"user_id": result["user_id"],
|
||||
"team_id": result["team_id"],
|
||||
"is_dict": isinstance(result, dict),
|
||||
}
|
||||
assert snapshot == {"user_id": "u-1", "team_id": "t-1", "is_dict": True}
|
||||
|
||||
|
||||
def test_convert_user_api_key_auth_to_dict_plain_object_uses_dict(proxy_logging):
|
||||
class Obj:
|
||||
pass
|
||||
|
||||
obj = Obj()
|
||||
obj.a = 1
|
||||
obj.b = 2
|
||||
obj.c = 3
|
||||
result = proxy_logging._convert_user_api_key_auth_to_dict(obj)
|
||||
assert result == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def test_convert_user_api_key_auth_to_dict_none_returns_empty_dict(proxy_logging):
|
||||
assert proxy_logging._convert_user_api_key_auth_to_dict(None) == {}
|
||||
|
||||
|
||||
def test_convert_user_api_key_auth_to_dict_unconvertible_object_returns_empty(proxy_logging):
|
||||
class NoDict:
|
||||
__slots__ = ()
|
||||
|
||||
assert proxy_logging._convert_user_api_key_auth_to_dict(NoDict()) == {}
|
||||
|
||||
|
||||
def test_convert_user_api_key_auth_to_dict_pydantic_error_raises(proxy_logging):
|
||||
"""A ``model_dump`` that raises propagates."""
|
||||
|
||||
class _Boom:
|
||||
def model_dump(self):
|
||||
raise RuntimeError("model_dump failure")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
proxy_logging._convert_user_api_key_auth_to_dict(_Boom())
|
||||
@ -0,0 +1,426 @@
|
||||
"""Pin ProxyLogging's MCP-LLM bridging helpers.
|
||||
|
||||
Covers:
|
||||
- ``_convert_mcp_to_llm_format``
|
||||
- ``_convert_llm_result_to_mcp_response``
|
||||
- ``_extract_modified_arguments_from_content``
|
||||
- ``_parse_arguments_manually``
|
||||
- ``_convert_llm_result_to_mcp_during_response``
|
||||
- ``_parse_pre_mcp_call_hook_response``
|
||||
- ``_create_mcp_request_object_from_kwargs``
|
||||
- ``_convert_mcp_hook_response_to_kwargs``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.types.mcp import (
|
||||
MCPDuringCallResponseObject,
|
||||
MCPPreCallRequestObject,
|
||||
MCPPreCallResponseObject,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _convert_mcp_to_llm_format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_mcp_to_llm_format_returns_synthetic_data(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="search", arguments={"q": "hello"})
|
||||
out = proxy_logging._convert_mcp_to_llm_format(
|
||||
request_obj=req,
|
||||
kwargs={
|
||||
"model": "gpt-4o-mini",
|
||||
"user_api_key_user_id": "u-1",
|
||||
"user_api_key_team_id": "t-1",
|
||||
"user_api_key_end_user_id": "eu-1",
|
||||
"user_api_key_hash": "hash",
|
||||
"user_api_key_request_route": "/mcp",
|
||||
"incoming_bearer_token": "tok",
|
||||
},
|
||||
)
|
||||
snapshot = {
|
||||
"model": out["model"],
|
||||
"user_id": out["user_api_key_user_id"],
|
||||
"mcp_tool_name": out["mcp_tool_name"],
|
||||
"mcp_arguments": out["mcp_arguments"],
|
||||
"incoming_bearer_token": out["incoming_bearer_token"],
|
||||
"message_role": out["messages"][0]["role"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"model": "gpt-4o-mini",
|
||||
"user_id": "u-1",
|
||||
"mcp_tool_name": "search",
|
||||
"mcp_arguments": {"q": "hello"},
|
||||
"incoming_bearer_token": "tok",
|
||||
"message_role": "user",
|
||||
}
|
||||
|
||||
|
||||
def test_convert_mcp_to_llm_format_defaults_model(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
out = proxy_logging._convert_mcp_to_llm_format(request_obj=req, kwargs={})
|
||||
snapshot = {
|
||||
"model": out["model"],
|
||||
"mcp_tool_name": out["mcp_tool_name"],
|
||||
"incoming_bearer_token": out["incoming_bearer_token"],
|
||||
"user_id": out["user_api_key_user_id"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"model": "mcp-tool-call",
|
||||
"mcp_tool_name": "calculator",
|
||||
"incoming_bearer_token": None,
|
||||
"user_id": None,
|
||||
}
|
||||
|
||||
|
||||
def test_convert_mcp_to_llm_format_missing_request_obj_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._convert_mcp_to_llm_format(request_obj=None, kwargs={})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _convert_llm_result_to_mcp_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_exception_blocks(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
result = proxy_logging._convert_llm_result_to_mcp_response(
|
||||
llm_result=ValueError("boom"),
|
||||
request_obj=req,
|
||||
)
|
||||
assert isinstance(result, MCPPreCallResponseObject)
|
||||
snapshot = {
|
||||
"should_proceed": result.should_proceed,
|
||||
"error_message": result.error_message,
|
||||
"modified_arguments": result.modified_arguments,
|
||||
}
|
||||
assert snapshot == {"should_proceed": False, "error_message": "boom", "modified_arguments": None}
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_blocked_content(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
|
||||
llm_result = {"messages": [{"content": "this is blocked"}]}
|
||||
result = proxy_logging._convert_llm_result_to_mcp_response(llm_result=llm_result, request_obj=req)
|
||||
assert isinstance(result, MCPPreCallResponseObject)
|
||||
assert result.should_proceed is False
|
||||
assert "blocked" in (result.error_message or "").lower()
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_modified_content_redacted(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="search", arguments={"q": "ssn 123"})
|
||||
llm_result = {"messages": [{"content": "Tool: search\nArguments: {\"q\": \"[REDACTED]\"}"}]}
|
||||
result = proxy_logging._convert_llm_result_to_mcp_response(llm_result=llm_result, request_obj=req)
|
||||
assert isinstance(result, MCPPreCallResponseObject)
|
||||
snapshot = {
|
||||
"should_proceed": result.should_proceed,
|
||||
"modified_q": (result.modified_arguments or {}).get("q"),
|
||||
"error": result.error_message,
|
||||
}
|
||||
assert snapshot == {"should_proceed": True, "modified_q": "[REDACTED]", "error": None}
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_string_blocks(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
result = proxy_logging._convert_llm_result_to_mcp_response(llm_result="bad input", request_obj=req)
|
||||
assert isinstance(result, MCPPreCallResponseObject)
|
||||
snapshot = {
|
||||
"should_proceed": result.should_proceed,
|
||||
"error_message": result.error_message,
|
||||
"modified_arguments": result.modified_arguments,
|
||||
}
|
||||
assert snapshot == {"should_proceed": False, "error_message": "bad input", "modified_arguments": None}
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_unmodified_returns_none(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="x", arguments={"a": 1})
|
||||
same_content = "Tool: x\nArguments: {'a': 1}"
|
||||
result = proxy_logging._convert_llm_result_to_mcp_response(
|
||||
llm_result={"messages": [{"content": same_content}]},
|
||||
request_obj=req,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_response_no_request_obj_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._convert_llm_result_to_mcp_response(llm_result={"messages": [{"content": "x"}]}, request_obj=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_modified_arguments_from_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_extract_modified_arguments_from_content_parses_json(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
out = proxy_logging._extract_modified_arguments_from_content(
|
||||
masked_content="Tool: x\nArguments: {\"a\": 1, \"b\": 2, \"c\": 3}",
|
||||
request_obj=req,
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def test_extract_modified_arguments_from_content_no_arguments_line_returns_none(proxy_logging, make_mcp_request_obj):
|
||||
out = proxy_logging._extract_modified_arguments_from_content(
|
||||
masked_content="random content with no arguments",
|
||||
request_obj=make_mcp_request_obj(),
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
def test_extract_modified_arguments_from_content_empty_string_returns_none(proxy_logging, make_mcp_request_obj):
|
||||
out = proxy_logging._extract_modified_arguments_from_content(
|
||||
masked_content="",
|
||||
request_obj=make_mcp_request_obj(),
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
def test_extract_modified_arguments_from_content_invalid_json_falls_back(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(arguments={"name": "alice"})
|
||||
out = proxy_logging._extract_modified_arguments_from_content(
|
||||
masked_content="Tool: x\nArguments: {name: REDACTED}",
|
||||
request_obj=req,
|
||||
)
|
||||
assert isinstance(out, dict)
|
||||
assert "name" in out
|
||||
|
||||
|
||||
def test_extract_modified_arguments_from_content_error_swallowed_returns_none(proxy_logging):
|
||||
"""Internal try/except swallows any unexpected error and returns None."""
|
||||
out = proxy_logging._extract_modified_arguments_from_content(masked_content=None, request_obj=None)
|
||||
assert out is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_arguments_manually
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_arguments_manually_applies_overrides(proxy_logging):
|
||||
original = {"name": "alice", "ssn": "123-45-6789"}
|
||||
out = proxy_logging._parse_arguments_manually(
|
||||
args_text='"name": "[REDACTED]", "ssn": "[REDACTED]"',
|
||||
original_args=original,
|
||||
)
|
||||
snapshot = {"name": out["name"], "ssn": out["ssn"], "original_unchanged": original["name"]}
|
||||
assert snapshot == {"name": "[REDACTED]", "ssn": "[REDACTED]", "original_unchanged": "alice"}
|
||||
|
||||
|
||||
def test_parse_arguments_manually_returns_original_if_no_match(proxy_logging):
|
||||
original = {"foo": "bar"}
|
||||
out = proxy_logging._parse_arguments_manually(args_text="nothing here", original_args=original)
|
||||
assert out == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_parse_arguments_manually_error_swallowed_returns_none(proxy_logging):
|
||||
# Defensive: function catches any exception internally and returns None.
|
||||
assert proxy_logging._parse_arguments_manually(args_text="x", original_args=None) is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _convert_llm_result_to_mcp_during_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_exception(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
result = proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result=ValueError("during boom"), request_obj=req
|
||||
)
|
||||
assert isinstance(result, MCPDuringCallResponseObject)
|
||||
snapshot = {
|
||||
"should_continue": result.should_continue,
|
||||
"error_message": result.error_message,
|
||||
"type": type(result).__name__,
|
||||
}
|
||||
assert snapshot == {
|
||||
"should_continue": False,
|
||||
"error_message": "during boom",
|
||||
"type": "MCPDuringCallResponseObject",
|
||||
}
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_blocked_content(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
|
||||
result = proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result={"messages": [{"content": "blocked content"}]},
|
||||
request_obj=req,
|
||||
)
|
||||
assert isinstance(result, MCPDuringCallResponseObject)
|
||||
assert result.should_continue is False
|
||||
assert "blocked" in (result.error_message or "").lower()
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_modified_stops(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
|
||||
result = proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result={"messages": [{"content": "Tool: t\nArguments: {\"a\": \"[REDACTED]\"}"}]},
|
||||
request_obj=req,
|
||||
)
|
||||
assert isinstance(result, MCPDuringCallResponseObject)
|
||||
assert result.should_continue is False
|
||||
assert "modified" in (result.error_message or "").lower()
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_string_blocks(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj()
|
||||
result = proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result="kill switch", request_obj=req
|
||||
)
|
||||
assert isinstance(result, MCPDuringCallResponseObject)
|
||||
snapshot = {"should_continue": result.should_continue, "error_message": result.error_message}
|
||||
assert snapshot == {"should_continue": False, "error_message": "kill switch"}
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_unmodified_returns_none(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(tool_name="t", arguments={"a": 1})
|
||||
same = "Tool: t\nArguments: {'a': 1}"
|
||||
assert (
|
||||
proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result={"messages": [{"content": same}]},
|
||||
request_obj=req,
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_convert_llm_result_to_mcp_during_response_no_request_obj_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._convert_llm_result_to_mcp_during_response(
|
||||
llm_result={"messages": [{"content": "x"}]}, request_obj=None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_pre_mcp_call_hook_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_pre_mcp_call_hook_response_with_modified_args(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(arguments={"a": 1})
|
||||
resp = MCPPreCallResponseObject(
|
||||
should_proceed=True,
|
||||
modified_arguments={"a": "x", "b": "y"},
|
||||
error_message=None,
|
||||
)
|
||||
out = proxy_logging._parse_pre_mcp_call_hook_response(response=resp, original_request=req)
|
||||
snapshot = {
|
||||
"should_proceed": out["should_proceed"],
|
||||
"modified_arguments": out["modified_arguments"],
|
||||
"error_message": out["error_message"],
|
||||
"hidden_params_type": type(out["hidden_params"]).__name__,
|
||||
}
|
||||
assert snapshot == {
|
||||
"should_proceed": True,
|
||||
"modified_arguments": {"a": "x", "b": "y"},
|
||||
"error_message": None,
|
||||
"hidden_params_type": "HiddenParams",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_pre_mcp_call_hook_response_no_modifications_uses_original(proxy_logging, make_mcp_request_obj):
|
||||
req = make_mcp_request_obj(arguments={"original": True})
|
||||
resp = MCPPreCallResponseObject(
|
||||
should_proceed=True, modified_arguments=None, error_message=None
|
||||
)
|
||||
out = proxy_logging._parse_pre_mcp_call_hook_response(response=resp, original_request=req)
|
||||
assert out["modified_arguments"] == {"original": True}
|
||||
|
||||
|
||||
def test_parse_pre_mcp_call_hook_response_invalid_response_raises(proxy_logging, make_mcp_request_obj):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._parse_pre_mcp_call_hook_response(
|
||||
response=None, original_request=make_mcp_request_obj()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _create_mcp_request_object_from_kwargs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_mcp_request_object_from_kwargs_full(proxy_logging, make_user_api_key_auth):
|
||||
auth = make_user_api_key_auth(user_id="u-1")
|
||||
obj = proxy_logging._create_mcp_request_object_from_kwargs(
|
||||
kwargs={
|
||||
"name": "calc",
|
||||
"arguments": {"x": 1},
|
||||
"server_name": "math",
|
||||
"user_api_key_auth": auth,
|
||||
}
|
||||
)
|
||||
assert isinstance(obj, MCPPreCallRequestObject)
|
||||
snapshot = {
|
||||
"tool_name": obj.tool_name,
|
||||
"arguments": obj.arguments,
|
||||
"server_name": obj.server_name,
|
||||
"auth_user_id": obj.user_api_key_auth.get("user_id"),
|
||||
}
|
||||
assert snapshot == {"tool_name": "calc", "arguments": {"x": 1}, "server_name": "math", "auth_user_id": "u-1"}
|
||||
|
||||
|
||||
def test_create_mcp_request_object_from_kwargs_empty(proxy_logging):
|
||||
obj = proxy_logging._create_mcp_request_object_from_kwargs(kwargs={})
|
||||
snapshot = {
|
||||
"tool_name": obj.tool_name,
|
||||
"arguments": obj.arguments,
|
||||
"server_name": obj.server_name,
|
||||
}
|
||||
assert snapshot == {"tool_name": "", "arguments": {}, "server_name": None}
|
||||
|
||||
|
||||
def test_create_mcp_request_object_from_kwargs_non_dict_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._create_mcp_request_object_from_kwargs(kwargs=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _convert_mcp_hook_response_to_kwargs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_mcp_hook_response_to_kwargs_applies_modified_args(proxy_logging):
|
||||
original = {"arguments": {"a": 1}, "name": "old"}
|
||||
out = proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response_data={"modified_arguments": {"a": 2}, "extra_headers": {"H": "1"}},
|
||||
original_kwargs=original,
|
||||
)
|
||||
snapshot = {
|
||||
"arguments": out["arguments"],
|
||||
"extra_headers": out["extra_headers"],
|
||||
"name": out["name"],
|
||||
"original_unmodified": original["arguments"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"arguments": {"a": 2},
|
||||
"extra_headers": {"H": "1"},
|
||||
"name": "old",
|
||||
"original_unmodified": {"a": 1},
|
||||
}
|
||||
|
||||
|
||||
def test_convert_mcp_hook_response_to_kwargs_merges_headers(proxy_logging):
|
||||
original = {"extra_headers": {"keep": "yes", "overwrite": "old"}}
|
||||
out = proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response_data={"extra_headers": {"overwrite": "new", "added": "1"}},
|
||||
original_kwargs=original,
|
||||
)
|
||||
assert out["extra_headers"] == {"keep": "yes", "overwrite": "new", "added": "1"}
|
||||
|
||||
|
||||
def test_convert_mcp_hook_response_to_kwargs_no_response_data_returns_original(proxy_logging):
|
||||
original = {"a": 1}
|
||||
out = proxy_logging._convert_mcp_hook_response_to_kwargs(response_data=None, original_kwargs=original)
|
||||
assert out is original
|
||||
|
||||
|
||||
def test_convert_mcp_hook_response_to_kwargs_invalid_original_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response_data={"modified_arguments": {"a": 1}}, original_kwargs=None # type: ignore[arg-type]
|
||||
)
|
||||
@ -0,0 +1,353 @@
|
||||
"""Pin behavior of top-of-file and bottom-of-region helpers.
|
||||
|
||||
Covers ``print_verbose``, ``_get_email_logger_class``,
|
||||
``_accepts_litellm_call_info``, ``_enrich_http_exception_with_guardrail_context``,
|
||||
``on_backoff``, ``jsonify_object``, ``_lookup_deprecated_key``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.proxy import utils as utils_mod
|
||||
from litellm.proxy.utils import (
|
||||
_accepts_litellm_call_info,
|
||||
_enrich_http_exception_with_guardrail_context,
|
||||
_get_email_logger_class,
|
||||
_lookup_deprecated_key,
|
||||
jsonify_object,
|
||||
on_backoff,
|
||||
print_verbose,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# print_verbose
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_print_verbose_when_set_verbose_true_prints_redacted(monkeypatch, capsys):
|
||||
monkeypatch.setattr(litellm, "set_verbose", True)
|
||||
print_verbose("hello world")
|
||||
captured = capsys.readouterr()
|
||||
snapshot = {
|
||||
"out_has_prefix": "LiteLLM Proxy:" in captured.out,
|
||||
"out_has_payload": "hello world" in captured.out,
|
||||
"no_stderr": captured.err == "",
|
||||
}
|
||||
assert snapshot == {"out_has_prefix": True, "out_has_payload": True, "no_stderr": True}
|
||||
|
||||
|
||||
def test_print_verbose_when_set_verbose_false_no_stdout(monkeypatch, capsys):
|
||||
monkeypatch.setattr(litellm, "set_verbose", False)
|
||||
print_verbose("quiet")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == ""
|
||||
|
||||
|
||||
def test_print_verbose_handles_unprintable_object_raises(monkeypatch):
|
||||
monkeypatch.setattr(litellm, "set_verbose", True)
|
||||
|
||||
class Bomb:
|
||||
def __str__(self):
|
||||
raise RuntimeError("bad str")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
print_verbose(Bomb())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_email_logger_class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_email_logger_class_priority_matrix(monkeypatch):
|
||||
"""Truth table for ``_get_email_logger_class`` priority: SendGrid >
|
||||
Resend > SMTP > Base."""
|
||||
sg = object()
|
||||
rs = object()
|
||||
smtp = object()
|
||||
base = object()
|
||||
monkeypatch.setattr(utils_mod, "BaseEmailLogger", base)
|
||||
monkeypatch.setattr(utils_mod, "SendGridEmailLogger", sg)
|
||||
monkeypatch.setattr(utils_mod, "ResendEmailLogger", rs)
|
||||
monkeypatch.setattr(utils_mod, "SMTPEmailLogger", smtp)
|
||||
for k in ("SENDGRID_API_KEY", "RESEND_API_KEY", "SMTP_HOST"):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
|
||||
fallback = _get_email_logger_class() is base
|
||||
monkeypatch.setenv("SMTP_HOST", "smtp.example")
|
||||
smtp_choice = _get_email_logger_class() is smtp
|
||||
monkeypatch.setenv("RESEND_API_KEY", "rs-x")
|
||||
resend_choice = _get_email_logger_class() is rs
|
||||
monkeypatch.setenv("SENDGRID_API_KEY", "sg-x")
|
||||
sendgrid_choice = _get_email_logger_class() is sg
|
||||
snapshot = {
|
||||
"fallback_to_base": fallback,
|
||||
"smtp_when_smtp_only": smtp_choice,
|
||||
"resend_beats_smtp": resend_choice,
|
||||
"sendgrid_wins": sendgrid_choice,
|
||||
}
|
||||
assert snapshot == {
|
||||
"fallback_to_base": True,
|
||||
"smtp_when_smtp_only": True,
|
||||
"resend_beats_smtp": True,
|
||||
"sendgrid_wins": True,
|
||||
}
|
||||
|
||||
|
||||
def test_get_email_logger_class_error_when_no_enterprise_module(monkeypatch):
|
||||
monkeypatch.setattr(utils_mod, "BaseEmailLogger", None)
|
||||
# Returns ``None`` rather than raising; this is the documented failure
|
||||
# mode when the optional enterprise package is missing.
|
||||
assert _get_email_logger_class() is None
|
||||
# Sentinel: monkey-patch SendGrid env but keep BaseEmailLogger None;
|
||||
# function still must return None and not blow up on the optional path.
|
||||
monkeypatch.setenv("SENDGRID_API_KEY", "sg-x")
|
||||
assert _get_email_logger_class() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _accepts_litellm_call_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CbAcceptsInfo:
|
||||
async def async_post_call_response_headers_hook(self, *, litellm_call_info=None):
|
||||
return None
|
||||
|
||||
|
||||
class _CbRejectsInfo:
|
||||
async def async_post_call_response_headers_hook(self, *, response):
|
||||
return None
|
||||
|
||||
|
||||
def test_accepts_litellm_call_info_matrix(monkeypatch):
|
||||
monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", {})
|
||||
cache = {id(_CbAcceptsInfo): True}
|
||||
monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", cache)
|
||||
snapshot = {
|
||||
"cache_hit_returns_true": _accepts_litellm_call_info(_CbAcceptsInfo()),
|
||||
"cache_size_after_hit": len(cache),
|
||||
"cache_keyed_by_type_id": id(_CbAcceptsInfo) in cache,
|
||||
}
|
||||
assert snapshot == {
|
||||
"cache_hit_returns_true": True,
|
||||
"cache_size_after_hit": 1,
|
||||
"cache_keyed_by_type_id": True,
|
||||
}
|
||||
|
||||
|
||||
def test_accepts_litellm_call_info_signature_inspection(monkeypatch):
|
||||
monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", {})
|
||||
snapshot = {
|
||||
"accepts_param_true": _accepts_litellm_call_info(_CbAcceptsInfo()),
|
||||
"rejects_param_false": _accepts_litellm_call_info(_CbRejectsInfo()),
|
||||
"cache_populated": len(utils_mod._CALLBACK_ACCEPTS_CALL_INFO) == 2,
|
||||
}
|
||||
assert snapshot == {
|
||||
"accepts_param_true": True,
|
||||
"rejects_param_false": False,
|
||||
"cache_populated": True,
|
||||
}
|
||||
|
||||
|
||||
def test_accepts_litellm_call_info_error_on_callback_without_hook_raises(monkeypatch):
|
||||
monkeypatch.setattr(utils_mod, "_CALLBACK_ACCEPTS_CALL_INFO", {})
|
||||
|
||||
class _Bad:
|
||||
pass
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_accepts_litellm_call_info(_Bad()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _enrich_http_exception_with_guardrail_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_enrich_http_exception_adds_guardrail_name_and_mode():
|
||||
detail = {"error": "blocked"}
|
||||
exc = HTTPException(status_code=400, detail=detail)
|
||||
cb = MagicMock()
|
||||
cb.guardrail_name = "presidio"
|
||||
cb.event_hook = "pre_call"
|
||||
|
||||
_enrich_http_exception_with_guardrail_context(exc, cb)
|
||||
snapshot = {
|
||||
"error": detail["error"],
|
||||
"guardrail_name": detail["guardrail_name"],
|
||||
"guardrail_mode": detail["guardrail_mode"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"error": "blocked",
|
||||
"guardrail_name": "presidio",
|
||||
"guardrail_mode": "pre_call",
|
||||
}
|
||||
|
||||
|
||||
def test_enrich_http_exception_does_not_overwrite_existing_keys():
|
||||
detail = {"error": "blocked", "guardrail_name": "explicit", "guardrail_mode": "during_call"}
|
||||
exc = HTTPException(status_code=400, detail=detail)
|
||||
cb = MagicMock()
|
||||
cb.guardrail_name = "should-not-overwrite"
|
||||
cb.event_hook = "should-not-overwrite"
|
||||
_enrich_http_exception_with_guardrail_context(exc, cb)
|
||||
assert detail == {"error": "blocked", "guardrail_name": "explicit", "guardrail_mode": "during_call"}
|
||||
|
||||
|
||||
def test_enrich_http_exception_no_op_for_non_http_exception():
|
||||
other = ValueError("not http")
|
||||
_enrich_http_exception_with_guardrail_context(other, MagicMock(guardrail_name="g"))
|
||||
|
||||
|
||||
def test_enrich_http_exception_no_op_for_non_dict_detail():
|
||||
exc = HTTPException(status_code=400, detail="just a string")
|
||||
_enrich_http_exception_with_guardrail_context(exc, MagicMock(guardrail_name="g"))
|
||||
assert exc.detail == "just a string"
|
||||
|
||||
|
||||
def test_enrich_http_exception_error_handling_does_not_raise():
|
||||
"""``_enrich_http_exception_with_guardrail_context`` swallows mismatched
|
||||
inputs (non-HTTPException, non-dict detail, no guardrail_name) and never
|
||||
raises — verified by passing each pathological input in turn."""
|
||||
# Bare exception with no detail at all should not blow up.
|
||||
bare = Exception("bare")
|
||||
_enrich_http_exception_with_guardrail_context(bare, MagicMock(guardrail_name=None))
|
||||
# HTTPException with non-dict detail.
|
||||
s = HTTPException(status_code=500, detail="str-detail")
|
||||
_enrich_http_exception_with_guardrail_context(s, MagicMock(guardrail_name="g"))
|
||||
assert s.detail == "str-detail"
|
||||
|
||||
|
||||
def test_enrich_http_exception_with_falsy_attrs_does_not_set():
|
||||
detail = {"error": "blocked"}
|
||||
exc = HTTPException(status_code=400, detail=detail)
|
||||
cb = MagicMock()
|
||||
cb.guardrail_name = None
|
||||
cb.event_hook = None
|
||||
_enrich_http_exception_with_guardrail_context(exc, cb)
|
||||
assert detail == {"error": "blocked"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_backoff
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_on_backoff_invokes_print_verbose(monkeypatch):
|
||||
captured = []
|
||||
monkeypatch.setattr(utils_mod, "print_verbose", lambda s: captured.append(s))
|
||||
on_backoff({"tries": 3})
|
||||
snapshot = {"len": len(captured), "first_has_attempt": "attempt" in captured[0], "first_has_3": "3" in captured[0]}
|
||||
assert snapshot == {"len": 1, "first_has_attempt": True, "first_has_3": True}
|
||||
|
||||
|
||||
def test_on_backoff_missing_tries_key_raises():
|
||||
with pytest.raises(KeyError):
|
||||
on_backoff({})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# jsonify_object
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_jsonify_object_serializes_nested_dicts():
|
||||
src = {"plain": "x", "nested": {"a": 1, "b": 2}, "n": 42}
|
||||
out = jsonify_object(src)
|
||||
expected = {"plain": "x", "nested": '{"a": 1, "b": 2}', "n": 42}
|
||||
assert out == expected
|
||||
# Source is not mutated.
|
||||
assert src == {"plain": "x", "nested": {"a": 1, "b": 2}, "n": 42}
|
||||
|
||||
|
||||
def test_jsonify_object_failed_serialization_marks_value(monkeypatch):
|
||||
class Unserialiseable:
|
||||
pass
|
||||
|
||||
src = {"name": "x", "bad": {"obj": Unserialiseable()}, "count": 1}
|
||||
out = jsonify_object(src)
|
||||
assert out == {"name": "x", "bad": "failed-to-serialize-json", "count": 1}
|
||||
|
||||
|
||||
def test_jsonify_object_non_dict_input_raises():
|
||||
with pytest.raises(AttributeError):
|
||||
jsonify_object("not a dict") # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _lookup_deprecated_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_deprecated_key_returns_active_token_id_and_caches(monkeypatch):
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
|
||||
fresh = LimitedSizeOrderedDict(max_size=1000)
|
||||
monkeypatch.setattr(utils_mod, "_deprecated_key_cache", fresh)
|
||||
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
deprecated_row = MagicMock()
|
||||
deprecated_row.active_token_id = "active-123"
|
||||
deprecated_row.revoke_at = future
|
||||
|
||||
db = MagicMock()
|
||||
db.litellm_deprecatedverificationtoken.find_first = AsyncMock(return_value=deprecated_row)
|
||||
|
||||
result = await _lookup_deprecated_key(db=db, hashed_token="hash-abc")
|
||||
cached_value = fresh.get("hash-abc")
|
||||
snapshot = {
|
||||
"result": result,
|
||||
"cache_active_token_id": cached_value[0],
|
||||
"cache_has_3_tuple": isinstance(cached_value, tuple) and len(cached_value) == 3,
|
||||
}
|
||||
assert snapshot == {
|
||||
"result": "active-123",
|
||||
"cache_active_token_id": "active-123",
|
||||
"cache_has_3_tuple": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_deprecated_key_returns_none_when_not_found(monkeypatch):
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
|
||||
monkeypatch.setattr(utils_mod, "_deprecated_key_cache", LimitedSizeOrderedDict(max_size=10))
|
||||
db = MagicMock()
|
||||
db.litellm_deprecatedverificationtoken.find_first = AsyncMock(return_value=None)
|
||||
assert await _lookup_deprecated_key(db=db, hashed_token="missing") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_deprecated_key_db_error_returns_none(monkeypatch):
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
|
||||
monkeypatch.setattr(utils_mod, "_deprecated_key_cache", LimitedSizeOrderedDict(max_size=10))
|
||||
db = MagicMock()
|
||||
db.litellm_deprecatedverificationtoken.find_first = AsyncMock(side_effect=RuntimeError("db down"))
|
||||
result = await _lookup_deprecated_key(db=db, hashed_token="x")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_deprecated_key_uses_cache_within_ttl(monkeypatch):
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
|
||||
cache = LimitedSizeOrderedDict(max_size=10)
|
||||
now_ts = datetime.now(timezone.utc).timestamp()
|
||||
cache["hashY"] = ("active-from-cache", now_ts + 100, now_ts + 1000)
|
||||
monkeypatch.setattr(utils_mod, "_deprecated_key_cache", cache)
|
||||
|
||||
db = MagicMock()
|
||||
db.litellm_deprecatedverificationtoken.find_first = AsyncMock(return_value=None)
|
||||
result = await _lookup_deprecated_key(db=db, hashed_token="hashY")
|
||||
assert result == "active-from-cache"
|
||||
db.litellm_deprecatedverificationtoken.find_first.assert_not_called()
|
||||
@ -0,0 +1,269 @@
|
||||
"""Pin ``ProxyLogging.post_call_failure_hook``, ``_is_proxy_only_llm_api_error``,
|
||||
and ``_handle_logging_proxy_only_error``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import AlertType, ProxyErrorTypes
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_proxy_only_llm_api_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_proxy_only_llm_api_truth_table(proxy_logging):
|
||||
"""Pin the truth table of ``_is_proxy_only_llm_api_error`` in a single
|
||||
snapshot. Covers no-route, non-LLM route, HTTPException on LLM route,
|
||||
and auth-error short-circuit."""
|
||||
snapshot = {
|
||||
"no_route": proxy_logging._is_proxy_only_llm_api_error(
|
||||
original_exception=Exception(), route=None
|
||||
),
|
||||
"non_llm_route": proxy_logging._is_proxy_only_llm_api_error(
|
||||
original_exception=HTTPException(status_code=429, detail="rate"),
|
||||
route="/random/path",
|
||||
),
|
||||
"http_on_llm_route": proxy_logging._is_proxy_only_llm_api_error(
|
||||
original_exception=HTTPException(status_code=429, detail="rate"),
|
||||
route="/chat/completions",
|
||||
),
|
||||
"auth_short_circuit": proxy_logging._is_proxy_only_llm_api_error(
|
||||
original_exception=Exception("auth"),
|
||||
error_type=ProxyErrorTypes.auth_error,
|
||||
route="/chat/completions",
|
||||
),
|
||||
}
|
||||
assert snapshot == {
|
||||
"no_route": False,
|
||||
"non_llm_route": False,
|
||||
"http_on_llm_route": True,
|
||||
"auth_short_circuit": True,
|
||||
}
|
||||
|
||||
|
||||
def test_is_proxy_only_llm_api_missing_exception_raises(proxy_logging):
|
||||
"""Passing nothing should TypeError on the missing positional kwarg."""
|
||||
with pytest.raises(TypeError):
|
||||
proxy_logging._is_proxy_only_llm_api_error() # type: ignore[call-arg]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# post_call_failure_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_failure_hook_no_callbacks_returns_none(
|
||||
proxy_logging, make_user_api_key_auth, mock_callbacks_disabled
|
||||
):
|
||||
proxy_logging.alert_types = []
|
||||
request_data = {"litellm_call_id": "abc", "model": "m", "messages": []}
|
||||
out = await proxy_logging.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=ValueError("oops"),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
snapshot = {
|
||||
"out_is_none": out is None,
|
||||
"litellm_logging_obj_popped": "litellm_logging_obj" not in request_data,
|
||||
"call_id_preserved": request_data["litellm_call_id"] == "abc",
|
||||
"first_api_call_start_time_present": "first_api_call_start_time" in request_data,
|
||||
}
|
||||
assert snapshot == {
|
||||
"out_is_none": True,
|
||||
"litellm_logging_obj_popped": True,
|
||||
"call_id_preserved": True,
|
||||
"first_api_call_start_time_present": False,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_failure_hook_callback_returns_http_exception(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
transformed = HTTPException(status_code=418, detail="teapot")
|
||||
|
||||
class _Cb(CustomLogger):
|
||||
async def async_post_call_failure_hook(self, **kwargs): # type: ignore[override]
|
||||
return transformed
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb()])
|
||||
proxy_logging.alert_types = []
|
||||
out = await proxy_logging.post_call_failure_hook(
|
||||
request_data={"litellm_call_id": "abc"},
|
||||
original_exception=ValueError("oops"),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
assert out is transformed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_failure_hook_callback_raises_http_exception_first_wins(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
err = HTTPException(status_code=418, detail="raised teapot")
|
||||
|
||||
class _Cb(CustomLogger):
|
||||
async def async_post_call_failure_hook(self, **kwargs): # type: ignore[override]
|
||||
raise err
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb()])
|
||||
proxy_logging.alert_types = []
|
||||
out = await proxy_logging.post_call_failure_hook(
|
||||
request_data={"litellm_call_id": "abc"},
|
||||
original_exception=ValueError("oops"),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
assert out is err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_failure_hook_non_http_exception_in_callback_swallowed(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
class _Cb(CustomLogger):
|
||||
async def async_post_call_failure_hook(self, **kwargs): # type: ignore[override]
|
||||
raise RuntimeError("non-http inside cb")
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb()])
|
||||
proxy_logging.alert_types = []
|
||||
out = await proxy_logging.post_call_failure_hook(
|
||||
request_data={"litellm_call_id": "abc"},
|
||||
original_exception=ValueError("oops"),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_logging_proxy_only_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_logging_proxy_only_path_uses_existing_logging_obj(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.call_type = "acompletion"
|
||||
logging_obj.model_call_details = {}
|
||||
logging_obj.async_failure_handler = AsyncMock()
|
||||
|
||||
request_data = {
|
||||
"litellm_logging_obj": logging_obj,
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"model": "m",
|
||||
"metadata": {},
|
||||
}
|
||||
await proxy_logging._handle_logging_proxy_only_error(
|
||||
request_data=request_data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
route="/chat/completions",
|
||||
original_exception=HTTPException(status_code=429, detail="rate"),
|
||||
)
|
||||
from litellm.constants import LITELLM_LOGGING_NO_UPSTREAM_LLM_CALL
|
||||
|
||||
snapshot = {
|
||||
"input_logged": "messages" in logging_obj.model_call_details,
|
||||
"call_type_normalized": logging_obj.call_type,
|
||||
"marker_present": logging_obj.model_call_details.get(
|
||||
LITELLM_LOGGING_NO_UPSTREAM_LLM_CALL
|
||||
)
|
||||
is True,
|
||||
"async_failure_called": logging_obj.async_failure_handler.called,
|
||||
}
|
||||
assert snapshot == {
|
||||
"input_logged": True,
|
||||
"call_type_normalized": "acompletion",
|
||||
"marker_present": True,
|
||||
"async_failure_called": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_logging_proxy_only_path_skips_for_pass_through(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.call_type = CallTypes.pass_through.value
|
||||
logging_obj.model_call_details = {}
|
||||
logging_obj.async_failure_handler = AsyncMock()
|
||||
logging_obj.pre_call = MagicMock()
|
||||
request_data = {
|
||||
"litellm_logging_obj": logging_obj,
|
||||
"messages": [{"role": "user"}],
|
||||
"model": "m",
|
||||
}
|
||||
await proxy_logging._handle_logging_proxy_only_error(
|
||||
request_data=request_data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
route="/chat/completions",
|
||||
original_exception=HTTPException(status_code=429, detail="rate"),
|
||||
)
|
||||
logging_obj.pre_call.assert_not_called()
|
||||
logging_obj.async_failure_handler.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_logging_proxy_only_path_no_logging_obj_creates_one(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
fake_logging_obj = MagicMock()
|
||||
fake_logging_obj.call_type = "acompletion"
|
||||
fake_logging_obj.model_call_details = {}
|
||||
fake_logging_obj.async_failure_handler = AsyncMock()
|
||||
|
||||
def fake_function_setup(**kwargs):
|
||||
return fake_logging_obj, {}
|
||||
|
||||
monkeypatch.setattr(litellm.utils, "function_setup", fake_function_setup)
|
||||
request_data = {"messages": [{"role": "user"}], "model": "m"}
|
||||
await proxy_logging._handle_logging_proxy_only_error(
|
||||
request_data=request_data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
route="/chat/completions",
|
||||
original_exception=HTTPException(status_code=429, detail="rate"),
|
||||
)
|
||||
assert "litellm_call_id" in request_data
|
||||
fake_logging_obj.async_failure_handler.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_logging_proxy_only_path_propagates_async_failure_raises(
|
||||
proxy_logging, make_user_api_key_auth
|
||||
):
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.call_type = "acompletion"
|
||||
logging_obj.model_call_details = {}
|
||||
logging_obj.async_failure_handler = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
request_data = {
|
||||
"litellm_logging_obj": logging_obj,
|
||||
"messages": [{"role": "user"}],
|
||||
"model": "m",
|
||||
}
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging._handle_logging_proxy_only_error(
|
||||
request_data=request_data,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
route="/chat/completions",
|
||||
original_exception=Exception("x"),
|
||||
)
|
||||
@ -0,0 +1,97 @@
|
||||
"""Pin ``ProxyLogging.post_call_success_hook``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
def _make_guardrail(name="g", should_run=True, override=None):
|
||||
cb = MagicMock(spec=CustomGuardrail)
|
||||
cb.__class__ = CustomGuardrail
|
||||
cb.guardrail_name = name
|
||||
cb.event_hook = GuardrailEventHooks.post_call
|
||||
cb.should_run_guardrail = MagicMock(return_value=should_run)
|
||||
cb.async_post_call_success_hook = AsyncMock(return_value=override)
|
||||
return cb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_returns_response_when_no_callbacks(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
response = {"original": True, "model": "m", "choices": []}
|
||||
out = await proxy_logging.post_call_success_hook(
|
||||
data={}, response=response, user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
assert out == {"original": True, "model": "m", "choices": []}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_runs_other_callback_and_replaces_response(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
new_response = {"modified": True, "kept": "yes", "final": "v"}
|
||||
|
||||
class _CL(CustomLogger):
|
||||
async def async_post_call_success_hook(self, **kwargs): # type: ignore[override]
|
||||
return new_response
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_CL()])
|
||||
out = await proxy_logging.post_call_success_hook(
|
||||
data={}, response={"original": True}, user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
assert out == new_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_guardrail_should_not_run_skipped(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
g = _make_guardrail(should_run=False)
|
||||
monkeypatch.setattr(litellm, "callbacks", [g])
|
||||
response = MagicMock()
|
||||
out = await proxy_logging.post_call_success_hook(
|
||||
data={}, response=response, user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
g.async_post_call_success_hook.assert_not_called()
|
||||
assert out is response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_guardrail_error_raises(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
g = _make_guardrail()
|
||||
g.async_post_call_success_hook = AsyncMock(side_effect=RuntimeError("blocked"))
|
||||
monkeypatch.setattr(litellm, "callbacks", [g])
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging.post_call_success_hook(
|
||||
data={}, response=MagicMock(), user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_guardrail_returns_modified_response(
|
||||
proxy_logging, make_user_api_key_auth, monkeypatch
|
||||
):
|
||||
modified = {"a": 1, "b": 2, "c": 3}
|
||||
g = _make_guardrail(override=modified)
|
||||
monkeypatch.setattr(litellm, "callbacks", [g])
|
||||
out = await proxy_logging.post_call_success_hook(
|
||||
data={}, response={"orig": True}, user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
assert out == modified
|
||||
@ -0,0 +1,168 @@
|
||||
"""Pin ``ProxyLogging.pre_call_hook`` and ``process_pre_call_hook_response``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_pre_call_hook_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_pre_call_hook_response_dict_returns_response(proxy_logging):
|
||||
out = await proxy_logging.process_pre_call_hook_response(
|
||||
response={"messages": [{"x": 1}], "model": "m", "temperature": 0.5},
|
||||
data={"original": True},
|
||||
call_type="completion",
|
||||
)
|
||||
assert out == {"messages": [{"x": 1}], "model": "m", "temperature": 0.5}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_pre_call_hook_response_string_completion_raises_rejected(proxy_logging):
|
||||
with pytest.raises(RejectedRequestError):
|
||||
await proxy_logging.process_pre_call_hook_response(
|
||||
response="rejected",
|
||||
data={"model": "m"},
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_pre_call_hook_response_string_other_call_type_raises_http(proxy_logging):
|
||||
with pytest.raises(HTTPException) as info:
|
||||
await proxy_logging.process_pre_call_hook_response(
|
||||
response="bad",
|
||||
data={},
|
||||
call_type="embeddings",
|
||||
)
|
||||
assert info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_pre_call_hook_response_exception_reraises(proxy_logging):
|
||||
err = RuntimeError("hook said no")
|
||||
with pytest.raises(RuntimeError, match="hook said no"):
|
||||
await proxy_logging.process_pre_call_hook_response(
|
||||
response=err, data={}, call_type="completion"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_pre_call_hook_response_other_type_returns_data(proxy_logging):
|
||||
out = await proxy_logging.process_pre_call_hook_response(
|
||||
response=12345, data={"a": 1, "b": 2, "c": 3}, call_type="completion"
|
||||
)
|
||||
assert out == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pre_call_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_returns_data_when_no_callbacks(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
data = {"messages": [{"role": "user", "content": "hi"}], "model": "m", "temperature": 0.7}
|
||||
proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
|
||||
out = await proxy_logging.pre_call_hook(
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
data=data,
|
||||
call_type="completion",
|
||||
)
|
||||
assert out is data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_returns_none_for_none_data(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
|
||||
out = await proxy_logging.pre_call_hook(
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
data=None,
|
||||
call_type="completion",
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_invokes_pre_call_override(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
class _Cb(CustomLogger):
|
||||
async def async_pre_call_hook(self, **kwargs): # type: ignore[override]
|
||||
captured.update(kwargs)
|
||||
return {"messages": [{"x": "modified"}], "model": "m", "temperature": 0.1}
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb()])
|
||||
proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
|
||||
out = await proxy_logging.pre_call_hook(
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
data={"messages": [{"x": "input"}], "model": "m", "temperature": 0.1},
|
||||
call_type="completion",
|
||||
)
|
||||
snapshot = {
|
||||
"out_messages": out["messages"],
|
||||
"out_model": out["model"],
|
||||
"out_temp": out["temperature"],
|
||||
"cb_received_call_type": captured.get("call_type"),
|
||||
}
|
||||
assert snapshot == {
|
||||
"out_messages": [{"x": "modified"}],
|
||||
"out_model": "m",
|
||||
"out_temp": 0.1,
|
||||
"cb_received_call_type": "completion",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_propagates_callback_error_raises(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
class _BadCb(CustomLogger):
|
||||
async def async_pre_call_hook(self, **kwargs): # type: ignore[override]
|
||||
raise RuntimeError("rejected")
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_BadCb()])
|
||||
proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
|
||||
with pytest.raises(RuntimeError, match="rejected"):
|
||||
await proxy_logging.pre_call_hook(
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
data={"model": "m"},
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_processes_guardrail_metadata_when_no_overrides(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
"""Even when no callback overrides exist, ``_process_guardrail_metadata`` runs."""
|
||||
data = {"messages": [{"role": "user"}], "model": "m", "metadata": {"guardrails": ["g1"]}}
|
||||
proxy_logging.slack_alerting_instance = MagicMock(alerting=None)
|
||||
invoked = {}
|
||||
|
||||
def fake_process(d):
|
||||
invoked["data"] = d
|
||||
|
||||
proxy_logging._process_guardrail_metadata = fake_process # type: ignore[assignment]
|
||||
out = await proxy_logging.pre_call_hook(
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
data=data,
|
||||
call_type="completion",
|
||||
)
|
||||
assert out is data
|
||||
assert invoked["data"] is data
|
||||
@ -0,0 +1,432 @@
|
||||
"""Pin ProxyLogging streaming + response-headers helpers.
|
||||
|
||||
Covers ``_wrap_streaming_iterator_with_enrichment``,
|
||||
``async_post_call_streaming_hook``,
|
||||
``async_post_call_streaming_iterator_hook``, ``_fire_deferred_stream_logging``,
|
||||
``is_a2a_streaming_response``, ``_init_response_taking_too_long_task``,
|
||||
``post_call_response_headers_hook``, ``_build_litellm_call_info``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caps_cache():
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
yield
|
||||
ProxyLogging._callback_capabilities_cache.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_a2a_streaming_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_a2a_streaming_response_truth_matrix(proxy_logging):
|
||||
snapshot = {
|
||||
"all_three_keys_present": proxy_logging.is_a2a_streaming_response(
|
||||
{"jsonrpc": "2.0", "id": "1", "result": {"x": 1}, "extra": "y"}
|
||||
),
|
||||
"missing_result": proxy_logging.is_a2a_streaming_response(
|
||||
{"jsonrpc": "2.0", "id": "1"}
|
||||
),
|
||||
"missing_jsonrpc": proxy_logging.is_a2a_streaming_response(
|
||||
{"id": "1", "result": {}}
|
||||
),
|
||||
"empty_dict": proxy_logging.is_a2a_streaming_response({}),
|
||||
}
|
||||
assert snapshot == {
|
||||
"all_three_keys_present": True,
|
||||
"missing_result": False,
|
||||
"missing_jsonrpc": False,
|
||||
"empty_dict": False,
|
||||
}
|
||||
|
||||
|
||||
def test_is_a2a_streaming_response_invalid_input_raises(proxy_logging):
|
||||
with pytest.raises(TypeError):
|
||||
proxy_logging.is_a2a_streaming_response(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_litellm_call_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_litellm_call_info_pulls_from_hidden_params_and_metadata(proxy_logging):
|
||||
response = MagicMock()
|
||||
response._hidden_params = {
|
||||
"custom_llm_provider": "openai",
|
||||
"api_base": "https://api.openai.com",
|
||||
"model_id": "model-1",
|
||||
}
|
||||
info = proxy_logging._build_litellm_call_info(
|
||||
data={"metadata": {"model_info": {"name": "gpt-4o-mini"}}},
|
||||
response=response,
|
||||
)
|
||||
assert info == {
|
||||
"custom_llm_provider": "openai",
|
||||
"model_info": {"name": "gpt-4o-mini"},
|
||||
"api_base": "https://api.openai.com",
|
||||
"model_id": "model-1",
|
||||
}
|
||||
|
||||
|
||||
def test_build_litellm_call_info_fallbacks_to_litellm_metadata(proxy_logging):
|
||||
response = MagicMock()
|
||||
response._hidden_params = {"custom_llm_provider": "azure"}
|
||||
info = proxy_logging._build_litellm_call_info(
|
||||
data={"litellm_metadata": {"model_info": {"alias": "azure-gpt"}}},
|
||||
response=response,
|
||||
)
|
||||
snapshot = {
|
||||
"custom_llm_provider": info["custom_llm_provider"],
|
||||
"model_info": info["model_info"],
|
||||
"api_base": info["api_base"],
|
||||
"model_id": info["model_id"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"custom_llm_provider": "azure",
|
||||
"model_info": {"alias": "azure-gpt"},
|
||||
"api_base": None,
|
||||
"model_id": None,
|
||||
}
|
||||
|
||||
|
||||
def test_build_litellm_call_info_invalid_data_raises(proxy_logging):
|
||||
with pytest.raises(AttributeError):
|
||||
proxy_logging._build_litellm_call_info(data=None, response=MagicMock()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _init_response_taking_too_long_task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_response_taking_too_long_task_runs_when_alerting(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.slack_alerting_instance.alerting = ["slack"]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_resp_too_long(request_data):
|
||||
captured["request_data"] = request_data
|
||||
|
||||
proxy_logging.slack_alerting_instance.response_taking_too_long = fake_resp_too_long
|
||||
payload = {"req": "y", "litellm_call_id": "c1", "model": "m"}
|
||||
proxy_logging._init_response_taking_too_long_task(data=payload)
|
||||
await asyncio.sleep(0)
|
||||
snapshot = {
|
||||
"received_payload": captured["request_data"],
|
||||
"fired_once": len(captured) == 1,
|
||||
"alerting_was_truthy": bool(proxy_logging.slack_alerting_instance.alerting),
|
||||
}
|
||||
assert snapshot == {
|
||||
"received_payload": payload,
|
||||
"fired_once": True,
|
||||
"alerting_was_truthy": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_response_taking_too_long_task_no_op_when_alerting_off(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = MagicMock()
|
||||
proxy_logging.slack_alerting_instance.alerting = None
|
||||
proxy_logging.slack_alerting_instance.response_taking_too_long = AsyncMock()
|
||||
proxy_logging._init_response_taking_too_long_task(data=None)
|
||||
await asyncio.sleep(0)
|
||||
proxy_logging.slack_alerting_instance.response_taking_too_long.assert_not_called()
|
||||
|
||||
|
||||
def test_init_response_taking_too_long_task_no_slack_instance_no_error_raises(proxy_logging):
|
||||
proxy_logging.slack_alerting_instance = None
|
||||
proxy_logging._init_response_taking_too_long_task(data=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _wrap_streaming_iterator_with_enrichment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrap_streaming_iterator_with_enrichment_passes_through_chunks(proxy_logging):
|
||||
async def gen():
|
||||
for ch in ("a", "b", "c"):
|
||||
yield ch
|
||||
|
||||
cb = MagicMock(guardrail_name="g", event_hook="pre_call")
|
||||
wrapped = proxy_logging._wrap_streaming_iterator_with_enrichment(callback=cb, gen=gen())
|
||||
out = [ch async for ch in wrapped]
|
||||
snapshot = {
|
||||
"chunks": out,
|
||||
"count": len(out),
|
||||
"first": out[0],
|
||||
"last": out[-1],
|
||||
}
|
||||
assert snapshot == {
|
||||
"chunks": ["a", "b", "c"],
|
||||
"count": 3,
|
||||
"first": "a",
|
||||
"last": "c",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrap_streaming_iterator_with_enrichment_enriches_http_exception_raises(proxy_logging):
|
||||
detail = {"error": "blocked"}
|
||||
|
||||
async def boom_gen():
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
raise HTTPException(status_code=400, detail=detail)
|
||||
|
||||
cb = MagicMock(guardrail_name="presidio", event_hook="post_call")
|
||||
wrapped = proxy_logging._wrap_streaming_iterator_with_enrichment(callback=cb, gen=boom_gen())
|
||||
with pytest.raises(HTTPException):
|
||||
async for _ in wrapped:
|
||||
pass
|
||||
assert detail["guardrail_name"] == "presidio"
|
||||
assert detail["guardrail_mode"] == "post_call"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_post_call_streaming_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_hook_fast_path_returns_response(proxy_logging, mock_callbacks_disabled, make_user_api_key_auth):
|
||||
resp = "chunk-1"
|
||||
out = await proxy_logging.async_post_call_streaming_hook(
|
||||
data={}, response=resp, user_api_key_dict=make_user_api_key_auth()
|
||||
)
|
||||
snapshot = {
|
||||
"out_is_input": out is resp,
|
||||
"out_value": out,
|
||||
"type": type(out).__name__,
|
||||
"callbacks_empty": len(litellm.callbacks) == 0,
|
||||
}
|
||||
assert snapshot == {
|
||||
"out_is_input": True,
|
||||
"out_value": "chunk-1",
|
||||
"type": "str",
|
||||
"callbacks_empty": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_hook_invokes_per_chunk_callback(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
class _Per(CustomLogger):
|
||||
async def async_post_call_streaming_hook(self, **kwargs): # type: ignore[override]
|
||||
return "modified-" + str(kwargs.get("response", ""))
|
||||
|
||||
cb = _Per()
|
||||
monkeypatch.setattr(litellm, "callbacks", [cb])
|
||||
|
||||
from litellm import ModelResponse
|
||||
|
||||
fake_resp = ModelResponse(
|
||||
id="rid",
|
||||
choices=[{"index": 0, "delta": {"role": "assistant", "content": "hi"}, "finish_reason": None}],
|
||||
created=0,
|
||||
model="gpt-4o-mini",
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
out = await proxy_logging.async_post_call_streaming_hook(
|
||||
data={},
|
||||
response=fake_resp,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
assert isinstance(out, str)
|
||||
assert out.startswith("modified-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_hook_callback_error_raises(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
class _Per(CustomLogger):
|
||||
async def async_post_call_streaming_hook(self, **kwargs): # type: ignore[override]
|
||||
raise RuntimeError("hook-fail")
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Per()])
|
||||
|
||||
from litellm import ModelResponse
|
||||
|
||||
fake_resp = ModelResponse(
|
||||
id="rid",
|
||||
choices=[{"index": 0, "delta": {"role": "assistant", "content": "hi"}, "finish_reason": None}],
|
||||
created=0,
|
||||
model="gpt-4o-mini",
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
await proxy_logging.async_post_call_streaming_hook(
|
||||
data={},
|
||||
response=fake_resp,
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_post_call_streaming_iterator_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_iterator_hook_no_overrides_passes_through(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
async def gen():
|
||||
for ch in ("a", "b"):
|
||||
yield ch
|
||||
|
||||
chunks = []
|
||||
async for ch in proxy_logging.async_post_call_streaming_iterator_hook(
|
||||
response=gen(),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
request_data={},
|
||||
):
|
||||
chunks.append(ch)
|
||||
snapshot = {
|
||||
"chunks": chunks,
|
||||
"count": len(chunks),
|
||||
"passthrough_preserved_order": chunks == ["a", "b"],
|
||||
}
|
||||
assert snapshot == {
|
||||
"chunks": ["a", "b"],
|
||||
"count": 2,
|
||||
"passthrough_preserved_order": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_iterator_hook_with_override_chains_callback(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
class _IterOverride(CustomLogger):
|
||||
async def async_post_call_streaming_iterator_hook(self, **kwargs): # type: ignore[override]
|
||||
async for ch in kwargs["response"]:
|
||||
yield ch + "*"
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_IterOverride()])
|
||||
|
||||
async def gen():
|
||||
for ch in ("a", "b"):
|
||||
yield ch
|
||||
|
||||
out: List[str] = []
|
||||
async for ch in proxy_logging.async_post_call_streaming_iterator_hook(
|
||||
response=gen(),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
request_data={},
|
||||
):
|
||||
out.append(ch)
|
||||
assert out == ["a*", "b*"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_post_call_streaming_iterator_hook_upstream_error_raises(proxy_logging, make_user_api_key_auth, mock_callbacks_disabled):
|
||||
async def gen():
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
raise RuntimeError("upstream")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in proxy_logging.async_post_call_streaming_iterator_hook(
|
||||
response=gen(),
|
||||
user_api_key_dict=make_user_api_key_auth(),
|
||||
request_data={},
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fire_deferred_stream_logging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fire_deferred_stream_logging_fires_callback():
|
||||
logging_obj = MagicMock()
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def deferred(arg):
|
||||
captured["arg"] = arg
|
||||
|
||||
logging_obj._on_deferred_stream_complete = deferred
|
||||
logging_obj._deferred_stream_complete_args = ("payload",)
|
||||
|
||||
ProxyLogging._fire_deferred_stream_logging(request_data={"litellm_logging_obj": logging_obj})
|
||||
await asyncio.sleep(0)
|
||||
snapshot = {
|
||||
"arg": captured["arg"],
|
||||
"callback_cleared": logging_obj._on_deferred_stream_complete is None,
|
||||
"args_cleared": logging_obj._deferred_stream_complete_args is None,
|
||||
}
|
||||
assert snapshot == {"arg": "payload", "callback_cleared": True, "args_cleared": True}
|
||||
|
||||
|
||||
def test_fire_deferred_stream_logging_no_logging_obj_no_error():
|
||||
ProxyLogging._fire_deferred_stream_logging(request_data={})
|
||||
|
||||
|
||||
def test_fire_deferred_stream_logging_missing_obj_raises_on_invalid_dict():
|
||||
with pytest.raises(AttributeError):
|
||||
ProxyLogging._fire_deferred_stream_logging(request_data=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# post_call_response_headers_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_response_headers_hook_returns_empty_when_no_callbacks(
|
||||
proxy_logging, mock_callbacks_disabled, make_user_api_key_auth
|
||||
):
|
||||
out = await proxy_logging.post_call_response_headers_hook(
|
||||
data={}, user_api_key_dict=make_user_api_key_auth(), response=MagicMock(_hidden_params={})
|
||||
)
|
||||
assert out == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_response_headers_hook_merges_callback_headers(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
class _Cb(CustomLogger):
|
||||
async def async_post_call_response_headers_hook(self, **kwargs): # type: ignore[override]
|
||||
return {"X-One": "1", "X-Two": "2", "X-Common": "first"}
|
||||
|
||||
class _Cb2(CustomLogger):
|
||||
async def async_post_call_response_headers_hook(self, **kwargs): # type: ignore[override]
|
||||
return {"X-Common": "second", "X-Three": "3"}
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb(), _Cb2()])
|
||||
response = MagicMock()
|
||||
response._hidden_params = {}
|
||||
out = await proxy_logging.post_call_response_headers_hook(
|
||||
data={}, user_api_key_dict=make_user_api_key_auth(), response=response
|
||||
)
|
||||
assert out == {"X-One": "1", "X-Two": "2", "X-Common": "second", "X-Three": "3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_response_headers_hook_swallows_callback_error(proxy_logging, make_user_api_key_auth, monkeypatch):
|
||||
"""Errors inside the hook are caught — function returns merged so-far."""
|
||||
|
||||
class _Cb(CustomLogger):
|
||||
async def async_post_call_response_headers_hook(self, **kwargs): # type: ignore[override]
|
||||
raise RuntimeError("bad header")
|
||||
|
||||
monkeypatch.setattr(litellm, "callbacks", [_Cb()])
|
||||
response = MagicMock()
|
||||
response._hidden_params = {}
|
||||
out = await proxy_logging.post_call_response_headers_hook(
|
||||
data={}, user_api_key_dict=make_user_api_key_auth(), response=response
|
||||
)
|
||||
assert out == {}
|
||||
Loading…
Reference in New Issue
Block a user