[internal copy of #29511] feat(guardrails): add sensitive data routing to on-premise models (#29531)
* feat(guardrails): add sensitive data routing to on-premise models

  When a guardrail detects sensitive data, route to an on-premise model
  instead of blocking or redacting. All subsequent requests in that session
  continue routing to the same model (sticky routing).

  New config options for guardrails:
  - on_sensitive_data: 'block' (default) or 'route'
  - sensitive_data_route_to_model: target model for rerouting
  - sticky_session_routing: persist routing for session (default: true)

  New exception SensitiveDataRouteException triggers rerouting when raised
  by guardrails. The proxy catches it, stores the routing decision in cache,
  and modifies the request's model field.

  New hook _PROXY_SensitiveDataRoutingHandler checks incoming requests
  against cached routing decisions and applies sticky routing.

  https://claude.ai/code/session_01SQd4isBa3UyouRoGVou9dK

* fix: black formatting for custom_guardrail.py

  https://claude.ai/code/session_01SQd4isBa3UyouRoGVou9dK

* test: improve test coverage for sensitive data routing feature

  Add additional tests for:
  - Cache key format and TTL constants
  - Session ID extraction from multiple locations
  - Custom guardrail initialization with routing config
  - Exception string representation and custom messages
  - Redis cache paths including fallback behavior
  - Edge cases in pre-call hook

  https://claude.ai/code/session_01SQd4isBa3UyouRoGVou9dK

* fix: use correct GuardrailRaisedException parameters

  Replace invalid 'source' parameter with 'guardrail_name' to match the
  exception's actual signature.

  https://claude.ai/code/session_01SQd4isBa3UyouRoGVou9dK

* test: move sensitive data routing tests to hooks directory

  Move test file to align with source code structure.

  https://claude.ai/code/session_01SQd4isBa3UyouRoGVou9dK

* fix(guardrails): honor sticky_session_routing flag and scope session routing per API key

  Propagate sticky_session_routing through SensitiveDataRouteException so a
  guardrail configured with sticky_session_routing=False reroutes only the
  triggering request without persisting a session override. Scope the
  routing cache key to the requesting API key so sessions from different
  tenants cannot collide, and warn when sticky routing is requested but the
  hook is not registered.

* refactor(guardrails): dedupe session-id extraction and drop redundant import

  Extract the shared session-id lookup into get_session_id_from_request_data
  so the sensitive-data routing hook and CustomGuardrail no longer keep two
  identical copies of the logic. Remove the redundant local import of
  GuardrailRaisedException in handle_sensitive_data_detection, and document
  that detection_info is surfaced in request metadata and logs so it must
  not carry raw sensitive values.

* fix(guardrails): guard None user_api_key_dict in sensitive data route handler

* fix(responses): send application/json Content-Type on responses DELETE

  OpenAI's responses DELETE endpoint now rejects requests that arrive
  without a Content-Type header, defaulting them to application/octet-stream
  and returning 'Unsupported content type: application/octet-stream'. The
  delete handler sent no body and therefore no Content-Type, so the request
  failed. Declare application/json on the delete request, matching the
  OpenAI SDK.

* fix(guardrails): backfill in-memory cache after redis hit in sensitive data routing

  When _get_routed_model resolves a routing override from Redis it now also
  populates the local in-memory cache. Without the write-back, a non-writing
  instance that only ever reads from Redis would lose the sticky routing
  decision the moment Redis became unavailable, silently reverting sensitive
  sessions to the default model.

* fix(guardrails): scope sticky sensitive-data routing to JWT principal

  Keyless auth (JWT and similar) has no api_key, so every such caller shared
  the "default" cache namespace. One authenticated user could reuse another
  user's session_id, trip the guardrail, and silently force the other user's
  subsequent requests onto the cached on-prem model for the TTL. Resolve the
  routing tenant from the api_key when present, otherwise from a stable
  principal built from the user/team/org identity, before reading or writing
  the session route.

* fix(guardrails): require route target model when on_sensitive_data='route'

* fix(guardrails): mark user_api_key_dict Optional in sensitive-data route handler

* fix(guardrails): use remaining redis ttl for local backfill and str env default

* fix(guardrails): graceful block when routing configured but no session_id

  handle_sensitive_data_detection promised to raise only
  SensitiveDataRouteException or GuardrailRaisedException, but when routing
  was configured and the request had no session_id it let a ValueError from
  raise_sensitive_data_route_exception propagate, surfacing as an HTTP 500
  instead of a block. Fall back to a graceful block in that case so the
  documented contract holds.

* fix(guardrails): run remaining guardrails after sensitive-data reroute

  Defer the SensitiveDataRouteException until every guardrail in the
  pre-call loop has run, so downstream security guardrails are no longer
  skipped when an earlier guardrail triggers routing. The first reroute wins
  and a later guardrail that blocks still propagates. Also normalize
  on_sensitive_data to lowercase like sibling on_* config fields so
  case-insensitive values are accepted.

* fix(guardrails): classify sensitive-data reroute as guardrail intervention

* fix(guardrails): record sensitive-data reroute as prometheus intervention not error

* fix(guardrails): record service span for routing guardrail and move case-normalizer to base params

  Drop the early continue so a guardrail that signals sensitive-data routing
  still emits its PROXY_PRE_CALL service span like every other callback.
  Move the lowercase normalizer onto BaseLitellmParams so on_sensitive_data
  is normalized consistently when BaseLitellmParams is constructed directly,
  matching the cross-field route->model validator that already lives on the
  base.
parent 56aa55b991
commit 812a2217ca
@@ -1062,3 +1062,37 @@ class GuardrailInterventionNormalStringError(
 
     def __repr__(self):
         return self.__str__()
+
+
+class SensitiveDataRouteException(Exception):
+    """
+    Exception raised when a guardrail detects sensitive data and wants to reroute the request.
+
+    Instead of blocking the request, this exception signals that the request should be
+    routed to a different model (typically an on-premise model for data privacy).
+
+    The proxy catches this exception and:
+    1. Reroutes the current request to the specified model
+    2. When sticky_session_routing is True, stores the routing decision in session
+       cache so all subsequent requests in the same session are routed to the same model
+    """
+
+    def __init__(
+        self,
+        route_to_model: str,
+        session_id: str,
+        guardrail_name: Optional[str] = None,
+        detection_info: Optional[Dict[str, Any]] = None,
+        message: Optional[str] = None,
+        sticky_session_routing: bool = True,
+    ):
+        self.route_to_model = route_to_model
+        self.session_id = session_id
+        self.guardrail_name = guardrail_name
+        self.detection_info = detection_info or {}
+        self.sticky_session_routing = sticky_session_routing
+        self.message = (
+            message
+            or f"Sensitive data detected by {guardrail_name}. Routing to model: {route_to_model}"
+        )
+        super().__init__(self.message)
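The exception above is the signal a guardrail raises to request a reroute. The following is a minimal sketch of raising and catching it, assuming a local stand-in class with the same constructor; the `reroute` helper and the model names (`gpt-4o`, `onprem-llama`) are hypothetical and are not part of the commit.

```python
from typing import Any, Dict, Optional


class SensitiveDataRouteException(Exception):
    """Local stand-in that mirrors the constructor shown in the hunk above."""

    def __init__(
        self,
        route_to_model: str,
        session_id: str,
        guardrail_name: Optional[str] = None,
        detection_info: Optional[Dict[str, Any]] = None,
        message: Optional[str] = None,
        sticky_session_routing: bool = True,
    ):
        self.route_to_model = route_to_model
        self.session_id = session_id
        self.guardrail_name = guardrail_name
        self.detection_info = detection_info or {}
        self.sticky_session_routing = sticky_session_routing
        self.message = (
            message
            or f"Sensitive data detected by {guardrail_name}. Routing to model: {route_to_model}"
        )
        super().__init__(self.message)


def reroute(data: Dict[str, Any], exc: SensitiveDataRouteException) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of the request aimed at the target model."""
    rerouted = dict(data)
    rerouted["model"] = exc.route_to_model
    return rerouted


request = {"model": "gpt-4o", "litellm_session_id": "sess-1"}
try:
    # A guardrail would raise this after detecting sensitive content.
    raise SensitiveDataRouteException(
        route_to_model="onprem-llama",
        session_id="sess-1",
        guardrail_name="pii-check",
    )
except SensitiveDataRouteException as exc:
    request = reroute(request, exc)
    default_message = str(exc)
```

Note that when `guardrail_name` is omitted, the default message interpolates `None`, so callers that surface the message may want to pass a name or a custom `message`.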
@@ -47,9 +47,29 @@ from litellm.exceptions import (
     BlockedPiiEntityError,
     GuardrailRaisedException,
     ModifyResponseException,
+    SensitiveDataRouteException,
 )
 
 
+def get_session_id_from_request_data(request_data: Dict[str, Any]) -> Optional[str]:
+    """Extract session_id from request data (litellm_session_id or metadata)."""
+    session_id = request_data.get("litellm_session_id")
+    if session_id:
+        return str(session_id)
+
+    metadata = request_data.get("metadata") or {}
+    session_id = metadata.get("session_id")
+    if session_id:
+        return str(session_id)
+
+    litellm_metadata = request_data.get("litellm_metadata") or {}
+    session_id = litellm_metadata.get("session_id")
+    if session_id:
+        return str(session_id)
+
+    return None
+
+
 class CustomGuardrail(CustomLogger):
     # If True, during_call runs async_moderation_hook instead of the unified apply_guardrail path.
     use_native_during_call_hook: ClassVar[bool] = False
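The lookup order in `get_session_id_from_request_data` decides which session a reroute is pinned to. The sketch below re-implements the same precedence in a standalone form so it can be exercised without importing litellm; the sample request dictionaries are made up for illustration.

```python
from typing import Any, Dict, Optional


def get_session_id_from_request_data(request_data: Dict[str, Any]) -> Optional[str]:
    """Standalone copy of the lookup order shown in the hunk above."""
    candidates = (
        request_data.get("litellm_session_id"),
        (request_data.get("metadata") or {}).get("session_id"),
        (request_data.get("litellm_metadata") or {}).get("session_id"),
    )
    for candidate in candidates:
        if candidate:
            # Values are always returned as strings, whatever their input type.
            return str(candidate)
    return None


# Top-level litellm_session_id wins over both metadata locations.
top_level = get_session_id_from_request_data(
    {"litellm_session_id": "a", "metadata": {"session_id": "b"}}
)
# metadata is consulted next; an explicit None metadata is tolerated.
from_metadata = get_session_id_from_request_data({"metadata": {"session_id": 42}})
from_litellm_metadata = get_session_id_from_request_data(
    {"metadata": None, "litellm_metadata": {"session_id": "c"}}
)
missing = get_session_id_from_request_data({"model": "gpt-4o"})
```

Because falsy values are skipped, an empty-string session id is treated the same as a missing one.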
@@ -68,6 +88,9 @@ class CustomGuardrail(CustomLogger):
         end_session_after_n_fails: Optional[int] = None,
         on_violation: Optional[str] = None,
         realtime_violation_message: Optional[str] = None,
+        on_sensitive_data: Optional[str] = None,
+        sensitive_data_route_to_model: Optional[str] = None,
+        sticky_session_routing: bool = True,
         **kwargs,
     ):
         """
@@ -83,6 +106,9 @@ class CustomGuardrail(CustomLogger):
             end_session_after_n_fails: For /v1/realtime sessions, end the session after this many violations
             on_violation: For /v1/realtime sessions, 'warn' or 'end_session'
             realtime_violation_message: Message the bot speaks aloud when a /v1/realtime guardrail fires
+            on_sensitive_data: Action when sensitive data is detected. 'block' (default) or 'route'
+            sensitive_data_route_to_model: Model to route to when on_sensitive_data='route'
+            sticky_session_routing: When True, all subsequent requests in the session use the same model
         """
         self.guardrail_name = guardrail_name
         self.supported_event_hooks = supported_event_hooks
@@ -96,6 +122,11 @@ class CustomGuardrail(CustomLogger):
         self.end_session_after_n_fails: Optional[int] = end_session_after_n_fails
         self.on_violation: Optional[str] = on_violation
         self.realtime_violation_message: Optional[str] = realtime_violation_message
+        self.on_sensitive_data: Optional[str] = on_sensitive_data
+        self.sensitive_data_route_to_model: Optional[str] = (
+            sensitive_data_route_to_model
+        )
+        self.sticky_session_routing: bool = sticky_session_routing
 
         if supported_event_hooks:
             ## validate event_hook is in supported_event_hooks
@@ -167,6 +198,108 @@ class CustomGuardrail(CustomLogger):
             detection_info=detection_info,
         )
 
+    def raise_sensitive_data_route_exception(
+        self,
+        route_to_model: str,
+        request_data: Dict[str, Any],
+        detection_info: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Raise an exception to reroute the request to a different model.
+
+        Use this when sensitive data is detected and the guardrail is configured
+        to route to an on-premise model instead of blocking.
+
+        The exception will reroute this request to the specified model. When
+        sticky_session_routing is enabled (the default), it also stores the
+        routing decision so subsequent requests in this session reuse the model.
+
+        Args:
+            route_to_model: The model to route this request (and session) to
+            request_data: The original request data dictionary
+            detection_info: Optional non-sensitive detection metadata (e.g. matched
+                entity types, rule ids, scores). This is surfaced in request metadata
+                and logs, so it must not contain the raw detected sensitive values.
+
+        Raises:
+            SensitiveDataRouteException: Always raises to trigger rerouting
+        """
+        session_id = self._get_session_id_from_request_data(request_data)
+        if not session_id:
+            raise ValueError(
+                "Cannot route sensitive data without a session_id. "
+                "Ensure the request includes a session_id in metadata or headers."
+            )
+
+        raise SensitiveDataRouteException(
+            route_to_model=route_to_model,
+            session_id=session_id,
+            guardrail_name=self.guardrail_name,
+            detection_info=detection_info,
+            sticky_session_routing=self.sticky_session_routing,
+        )
+
+    def _get_session_id_from_request_data(
+        self, request_data: Dict[str, Any]
+    ) -> Optional[str]:
+        """Extract session_id from request data."""
+        return get_session_id_from_request_data(request_data)
+
+    def should_route_on_sensitive_data(self) -> bool:
+        """
+        Returns True if this guardrail is configured to route requests
+        to a different model when sensitive data is detected.
+        """
+        return (
+            self.on_sensitive_data == "route"
+            and self.sensitive_data_route_to_model is not None
+        )
+
+    def handle_sensitive_data_detection(
+        self,
+        request_data: Dict[str, Any],
+        detection_info: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Handle sensitive data detection based on guardrail configuration.
+
+        If on_sensitive_data='route', raises SensitiveDataRouteException to reroute.
+        Otherwise, raises GuardrailRaisedException to block. When routing is
+        configured but the request carries no session_id, routing is not possible
+        so the request falls back to a graceful block.
+
+        Args:
+            request_data: The request data dictionary
+            detection_info: Optional non-sensitive detection metadata. When routing,
+                this is surfaced in request metadata and logs, so it must not contain
+                the raw detected sensitive values.
+
+        Raises:
+            SensitiveDataRouteException: When configured to route and a session_id is present
+            GuardrailRaisedException: When configured to block, or when routing is
+                configured but no session_id is available
+        """
+        if self.should_route_on_sensitive_data():
+            try:
+                self.raise_sensitive_data_route_exception(
+                    route_to_model=self.sensitive_data_route_to_model,  # type: ignore
+                    request_data=request_data,
+                    detection_info=detection_info,
+                )
+            except ValueError:
+                raise GuardrailRaisedException(
+                    message=(
+                        f"Sensitive data detected by {self.guardrail_name} "
+                        "(routing skipped: request has no session_id)"
+                    ),
+                    guardrail_name=self.guardrail_name,
+                )
+        else:
+            raise GuardrailRaisedException(
+                message=f"Sensitive data detected by {self.guardrail_name}",
+                guardrail_name=self.guardrail_name,
+            )
+
     @staticmethod
     def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
         """
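The contract documented on `handle_sensitive_data_detection` reduces to a three-way decision: route when routing is fully configured and a session id exists, otherwise block. The sketch below expresses that decision as a pure function returning an outcome tuple instead of raising; `resolve_action` is a hypothetical name used only for illustration.

```python
from typing import Optional, Tuple


def resolve_action(
    on_sensitive_data: Optional[str],
    route_to_model: Optional[str],
    session_id: Optional[str],
) -> Tuple[str, Optional[str]]:
    """Hypothetical helper: return ("route", model) or ("block", None)."""
    # Mirrors should_route_on_sensitive_data(): both the action and a target
    # model must be configured before routing is considered at all.
    wants_route = on_sensitive_data == "route" and route_to_model is not None
    if wants_route and session_id:
        return ("route", route_to_model)
    # Either blocking is configured, or routing is configured but there is no
    # session_id to pin the decision to: fall back to a block.
    return ("block", None)


routed = resolve_action("route", "onprem-llama", "sess-1")
no_session = resolve_action("route", "onprem-llama", None)
no_target = resolve_action("route", None, "sess-1")
default = resolve_action(None, None, "sess-1")
```

In the actual method each outcome is an exception (`SensitiveDataRouteException` for the first case, `GuardrailRaisedException` for the rest), so a guardrail that calls it never returns normally.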
@@ -753,12 +886,20 @@ class CustomGuardrail(CustomLogger):
         Guardrails signal intentional blocks by raising:
         - GuardrailRaisedException (generic guardrail API, tool permission)
         - BlockedPiiEntityError (Presidio PII detection)
+        - SensitiveDataRouteException (sensitive-data reroute to on-premise model)
         - HTTPException with status 400 (content policy violation)
         - ModifyResponseException (passthrough mode violation)
         """
         if isinstance(e, ModifyResponseException):
             return True
-        if isinstance(e, (GuardrailRaisedException, BlockedPiiEntityError)):
+        if isinstance(
+            e,
+            (
+                GuardrailRaisedException,
+                BlockedPiiEntityError,
+                SensitiveDataRouteException,
+            ),
+        ):
             return True
         if (
             HTTPException is not None
@@ -2690,7 +2690,7 @@ class PrometheusLogger(CustomLogger):
         Args:
             guardrail_name: Name of the guardrail
             latency_seconds: Execution latency in seconds
-            status: "success" or "error"
+            status: "success", "error", or "intervened"
             error_type: Type of error if any, None otherwise
             hook_type: "pre_call", "during_call", or "post_call"
         """
@@ -2586,6 +2586,8 @@ class BaseLLMHTTPHandler:
             headers=headers,
         )
 
+        headers.setdefault("Content-Type", "application/json")
+
         ## LOGGING
         logging_obj.pre_call(
             input=input,
@@ -2676,6 +2678,8 @@ class BaseLLMHTTPHandler:
             headers=headers,
         )
 
+        headers.setdefault("Content-Type", "application/json")
+
         ## LOGGING
         logging_obj.pre_call(
             input=input,
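The two hunks above rely on `dict.setdefault`, which only writes the header when the key is absent. A small illustration, using made-up header values:

```python
# A body-less DELETE request typically carries no Content-Type yet.
headers = {"Authorization": "Bearer sk-example"}
headers.setdefault("Content-Type", "application/json")

# A caller-supplied Content-Type is left untouched.
custom = {"Content-Type": "text/plain"}
custom.setdefault("Content-Type", "application/json")

# Plain dict keys are case-sensitive, so a lowercase key does not count as
# "already set" and both spellings end up in the mapping.
lowercase = {"content-type": "text/plain"}
lowercase.setdefault("Content-Type", "application/json")
```

The last case is a property of plain Python dictionaries; whether it matters here depends on how the surrounding handler normalizes header names, which this diff does not show.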
@@ -10,6 +10,7 @@ from .max_iterations_limiter import _PROXY_MaxIterationsHandler
 from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
 from .parallel_request_limiter_v3 import _PROXY_MaxParallelRequestsHandler_v3
 from .responses_id_security import ResponsesIDSecurity
+from .sensitive_data_routing import _PROXY_SensitiveDataRoutingHandler
 
 # List of all available hooks that can be enabled.
 # Defined before the enterprise import below so that any module re-imported
@@ -23,6 +24,7 @@ PROXY_HOOKS = {
     "litellm_skills": SkillsInjectionHook,
     "max_iterations_limiter": _PROXY_MaxIterationsHandler,
     "max_budget_per_session_limiter": _PROXY_MaxBudgetPerSessionHandler,
+    "sensitive_data_routing": _PROXY_SensitiveDataRoutingHandler,
 }
 
 ## FEATURE FLAG HOOKS ##
206  litellm/proxy/hooks/sensitive_data_routing.py  Normal file
@@ -0,0 +1,206 @@
+"""
+Sensitive Data Routing Hook for LiteLLM Proxy.
+
+When a guardrail detects sensitive data and is configured with on_sensitive_data='route',
+this hook manages:
+1. Storing the routing decision (session_id -> model) in cache
+2. Checking incoming requests for existing routing overrides
+3. Applying sticky routing so all subsequent requests in a session go to the same model
+
+Works across multiple proxy instances via DualCache (in-memory + Redis).
+"""
+
+import os
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_guardrail import get_session_id_from_request_data
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+if TYPE_CHECKING:
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+    InternalUsageCache = _InternalUsageCache
+else:
+    InternalUsageCache = Any
+
+
+SENSITIVE_ROUTING_CACHE_PREFIX = "sensitive_route"
+DEFAULT_SENSITIVE_ROUTING_TTL = 3600
+
+
+class _PROXY_SensitiveDataRoutingHandler(CustomLogger):
+    """
+    Pre-call hook that checks for existing sensitive data routing overrides
+    and applies them to incoming requests.
+
+    This hook runs early in the pre-call chain and modifies the request's
+    model field if a routing override exists for the session.
+    """
+
+    def __init__(self, internal_usage_cache: InternalUsageCache):
+        self.internal_usage_cache = internal_usage_cache
+        self.ttl = int(
+            os.getenv(
+                "LITELLM_SENSITIVE_ROUTING_TTL",
+                str(DEFAULT_SENSITIVE_ROUTING_TTL),
+            )
+        )
+
+    def _make_cache_key(self, session_id: str, tenant: str) -> str:
+        return f"{{{SENSITIVE_ROUTING_CACHE_PREFIX}:{tenant}:{session_id}}}:model"
+
+    @staticmethod
+    def _resolve_tenant(user_api_key_dict: Optional[UserAPIKeyAuth]) -> str:
+        """
+        Identify the authenticated principal the routing override belongs to.
+
+        API-key auth is scoped by the hashed key. JWT (and other keyless) auth
+        has no api_key, so fall back to a stable identity claim. Without this,
+        every keyless caller would share the ``default`` namespace and could read
+        or overwrite another principal's session routing.
+        """
+        if user_api_key_dict is None:
+            return "default"
+        if user_api_key_dict.api_key:
+            return user_api_key_dict.api_key
+        principal = [
+            f"{label}:{value}"
+            for label, value in (
+                ("user", user_api_key_dict.user_id),
+                ("team", user_api_key_dict.team_id),
+                ("org", user_api_key_dict.org_id),
+            )
+            if value
+        ]
+        return "|".join(principal) if principal else "default"
+
+    async def _get_routed_model(
+        self, session_id: str, user_api_key_dict: Optional[UserAPIKeyAuth]
+    ) -> Optional[str]:
+        """Get the model this session should be routed to, if any."""
+        cache_key = self._make_cache_key(
+            session_id, self._resolve_tenant(user_api_key_dict)
+        )
+
+        if self.internal_usage_cache.dual_cache.redis_cache is not None:
+            try:
+                result = await self.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
+                    key=cache_key
+                )
+                if result is not None:
+                    routed_model = str(result)
+                    remaining_ttl = await self.internal_usage_cache.dual_cache.redis_cache.async_get_ttl(
+                        key=cache_key
+                    )
+                    await self.internal_usage_cache.async_set_cache(
+                        key=cache_key,
+                        value=routed_model,
+                        ttl=remaining_ttl if remaining_ttl is not None else self.ttl,
+                        litellm_parent_otel_span=None,
+                        local_only=True,
+                    )
+                    return routed_model
+            except Exception as e:
+                verbose_proxy_logger.warning(
+                    "SensitiveDataRoutingHandler: Redis GET failed, falling back to in-memory: %s",
+                    str(e),
+                )
+
+        result = await self.internal_usage_cache.async_get_cache(
+            key=cache_key,
+            litellm_parent_otel_span=None,
+            local_only=True,
+        )
+        if result is not None:
+            return str(result)
+        return None
+
+    async def set_session_routing(
+        self,
+        session_id: str,
+        model: str,
+        user_api_key_dict: Optional[UserAPIKeyAuth] = None,
+        guardrail_name: Optional[str] = None,
+    ) -> None:
+        """
+        Store a routing override for a session.
+
+        Called by guardrails when they detect sensitive data and want to
+        route the session to a specific model. The override is scoped to the
+        requesting principal so sessions from different tenants cannot collide.
+        """
+        cache_key = self._make_cache_key(
+            session_id, self._resolve_tenant(user_api_key_dict)
+        )
+
+        verbose_proxy_logger.info(
+            "SensitiveDataRoutingHandler: Setting session routing session_id=%s model=%s guardrail=%s ttl=%s",
+            session_id,
+            model,
+            guardrail_name,
+            self.ttl,
+        )
+
+        if self.internal_usage_cache.dual_cache.redis_cache is not None:
+            try:
+                await self.internal_usage_cache.dual_cache.redis_cache.async_set_cache(
+                    key=cache_key,
+                    value=model,
+                    ttl=self.ttl,
+                )
+            except Exception as e:
+                verbose_proxy_logger.warning(
+                    "SensitiveDataRoutingHandler: Redis SET failed, falling back to in-memory: %s",
+                    str(e),
+                )
+
+        await self.internal_usage_cache.async_set_cache(
+            key=cache_key,
+            value=model,
+            ttl=self.ttl,
+            litellm_parent_otel_span=None,
+            local_only=True,
+        )
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ) -> Optional[Union[Exception, str, dict]]:
+        """
+        Before each LLM call, check if this session has a routing override.
+        If so, modify the request's model field.
+        """
+        session_id = get_session_id_from_request_data(data)
+        if session_id is None:
+            return None
+
+        routed_model = await self._get_routed_model(session_id, user_api_key_dict)
+        if routed_model is None:
+            return None
+
+        original_model = data.get("model")
+        if original_model == routed_model:
+            return None
+
+        verbose_proxy_logger.info(
+            "SensitiveDataRoutingHandler: Applying session routing override "
+            "session_id=%s original_model=%s routed_model=%s",
+            session_id,
+            original_model,
+            routed_model,
+        )
+
+        data["model"] = routed_model
+
+        metadata = data.get("metadata") or {}
+        metadata["sensitive_data_routing_applied"] = True
+        metadata["sensitive_data_routing_original_model"] = original_model
+        data["metadata"] = metadata
+
+        return data
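Tenant scoping is what keeps two callers with the same `session_id` from sharing a routing override. The sketch below reproduces `_resolve_tenant` and `_make_cache_key` as standalone functions, assuming a small `Principal` dataclass as a hypothetical stand-in for `UserAPIKeyAuth` (only the four attributes read by the hook are modeled).

```python
from dataclasses import dataclass
from typing import Optional

SENSITIVE_ROUTING_CACHE_PREFIX = "sensitive_route"


@dataclass
class Principal:
    """Hypothetical stand-in for UserAPIKeyAuth; not the real litellm type."""

    api_key: Optional[str] = None
    user_id: Optional[str] = None
    team_id: Optional[str] = None
    org_id: Optional[str] = None


def resolve_tenant(principal: Optional[Principal]) -> str:
    """Prefer the API key; otherwise build a stable identity from claims."""
    if principal is None:
        return "default"
    if principal.api_key:
        return principal.api_key
    parts = [
        f"{label}:{value}"
        for label, value in (
            ("user", principal.user_id),
            ("team", principal.team_id),
            ("org", principal.org_id),
        )
        if value
    ]
    return "|".join(parts) if parts else "default"


def make_cache_key(session_id: str, tenant: str) -> str:
    # Doubled braces in the f-string emit literal "{" and "}" around the key body.
    return f"{{{SENSITIVE_ROUTING_CACHE_PREFIX}:{tenant}:{session_id}}}:model"


# Two keyless (JWT-style) users reusing the same session_id get distinct keys.
alice_key = make_cache_key("sess-1", resolve_tenant(Principal(user_id="alice", team_id="t1")))
bob_key = make_cache_key("sess-1", resolve_tenant(Principal(user_id="bob", team_id="t1")))
api_key_key = make_cache_key("sess-1", resolve_tenant(Principal(api_key="hashed-key")))
anonymous_key = make_cache_key("sess-1", resolve_tenant(None))
```

A caller with no API key and no identity claims still lands in the shared `default` namespace, so the isolation holds only as far as the auth layer populates at least one of those fields.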
@@ -89,11 +89,14 @@ from litellm._logging import _redact_string, verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching.caching import DualCache, RedisCache
 from litellm.caching.dual_cache import LimitedSizeOrderedDict
-from litellm.exceptions import RejectedRequestError
+from litellm.exceptions import RejectedRequestError, SensitiveDataRouteException
 from litellm.integrations.custom_guardrail import (
     CustomGuardrail,
     ModifyResponseException,
 )
+from litellm.proxy.hooks.sensitive_data_routing import (
+    _PROXY_SensitiveDataRoutingHandler,
+)
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
 from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
@@ -1151,6 +1154,9 @@ class ProxyLogging:
                     response=response, data=data, call_type=call_type
                 )
 
+        except SensitiveDataRouteException:
+            status = "intervened"
+            raise
         except Exception as e:
             status = "error"
             error_type = type(e).__name__
@@ -1460,47 +1466,57 @@ class ProxyLogging:
                 self._process_guardrail_metadata(data)
                 return data
 
+            deferred_route_exc: Optional[SensitiveDataRouteException] = None
             for _callback in caps.resolved_callbacks:
                 start_time = time.time()
-                if isinstance(_callback, CustomGuardrail) and data is not None:
-                    # Skip guardrails managed by a pipeline
-                    if (
-                        _callback.guardrail_name
-                        and _callback.guardrail_name in pipeline_managed
-                    ):
-                        continue
+                try:
+                    if isinstance(_callback, CustomGuardrail) and data is not None:
+                        # Skip guardrails managed by a pipeline
+                        if (
+                            _callback.guardrail_name
+                            and _callback.guardrail_name in pipeline_managed
+                        ):
+                            continue
 
-                    result = await self._process_guardrail_callback(
-                        callback=_callback,
-                        data=data,  # type: ignore
-                        user_api_key_dict=user_api_key_dict,
-                        call_type=call_type,
-                        event_type=GuardrailEventHooks.pre_call,
-                    )
-                    if result is None:
-                        continue
-                    data = result
-
-                elif (
-                    _callback is not None
-                    and isinstance(_callback, CustomLogger)
-                    and "async_pre_call_hook" in vars(_callback.__class__)
-                    and _callback.__class__.async_pre_call_hook
-                    != CustomLogger.async_pre_call_hook
-                ):
-                    if call_type == "call_mcp_tool" and user_api_key_dict is None:
-                        continue
-
-                    response = await _callback.async_pre_call_hook(
-                        user_api_key_dict=user_api_key_dict,
-                        cache=self.call_details["user_api_key_cache"],
-                        data=data,  # type: ignore
-                        call_type=call_type,  # type: ignore
-                    )
-                    if response is not None:
-                        data = await self.process_pre_call_hook_response(
-                            response=response, data=data, call_type=call_type
+                        result = await self._process_guardrail_callback(
+                            callback=_callback,
+                            data=data,  # type: ignore
+                            user_api_key_dict=user_api_key_dict,
+                            call_type=call_type,
+                            event_type=GuardrailEventHooks.pre_call,
+                        )
+                        if result is None:
+                            continue
+                        data = result
+
+                    elif (
+                        _callback is not None
+                        and isinstance(_callback, CustomLogger)
+                        and "async_pre_call_hook" in vars(_callback.__class__)
+                        and _callback.__class__.async_pre_call_hook
+                        != CustomLogger.async_pre_call_hook
+                    ):
+                        if call_type == "call_mcp_tool" and user_api_key_dict is None:
+                            continue
+
+                        response = await _callback.async_pre_call_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            cache=self.call_details["user_api_key_cache"],
+                            data=data,  # type: ignore
+                            call_type=call_type,  # type: ignore
                         )
+                        if response is not None:
+                            data = await self.process_pre_call_hook_response(
+                                response=response, data=data, call_type=call_type
+                            )
+                except SensitiveDataRouteException as e:
+                    # Defer the reroute until remaining guardrails have run so later
+                    # security checks are not skipped; the first reroute wins and a
+                    # later guardrail that blocks still propagates. Fall through to the
+                    # service-span recording below so the triggering guardrail is still
+                    # timed like every other callback.
+                    if deferred_route_exc is None:
+                        deferred_route_exc = e
 
                 end_time = time.time()
                 duration = end_time - start_time
@ -1516,13 +1532,76 @@ class ProxyLogging:
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
if deferred_route_exc is not None and data is not None:
|
||||
data = await self._handle_sensitive_data_route_exception(
|
||||
deferred_route_exc, data, user_api_key_dict
|
||||
)
|
||||
|
||||
if data is not None:
|
||||
self._process_guardrail_metadata(data)
|
||||
|
||||
return data
|
||||
except SensitiveDataRouteException as e:
|
||||
data = await self._handle_sensitive_data_route_exception(
|
||||
e, data, user_api_key_dict
|
||||
)
|
||||
if data is not None:
|
||||
self._process_guardrail_metadata(data)
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
+    async def _handle_sensitive_data_route_exception(
+        self,
+        exc: SensitiveDataRouteException,
+        data: Optional[dict],
+        user_api_key_dict: Optional[UserAPIKeyAuth],
+    ) -> Optional[dict]:
+        """
+        Handle SensitiveDataRouteException by rerouting the current request to
+        the target model and, when sticky_session_routing is enabled, persisting
+        the session override so subsequent requests reuse the same model.
+        """
+        if data is None:
+            return None
+
+        verbose_proxy_logger.info(
+            "SensitiveDataRouteException caught: session_id=%s route_to_model=%s guardrail=%s sticky=%s",
+            exc.session_id,
+            exc.route_to_model,
+            exc.guardrail_name,
+            exc.sticky_session_routing,
+        )
+
+        if exc.sticky_session_routing:
+            sensitive_routing_hook = self.get_proxy_hook("sensitive_data_routing")
+            if isinstance(sensitive_routing_hook, _PROXY_SensitiveDataRoutingHandler):
+                await sensitive_routing_hook.set_session_routing(
+                    session_id=exc.session_id,
+                    model=exc.route_to_model,
+                    user_api_key_dict=user_api_key_dict,
+                    guardrail_name=exc.guardrail_name,
+                )
+            else:
+                verbose_proxy_logger.warning(
+                    "SensitiveDataRouteException requested sticky routing for session_id=%s "
+                    "but the 'sensitive_data_routing' hook is not registered. Only this request "
+                    "will be rerouted; subsequent requests will not be sticky.",
+                    exc.session_id,
+                )
+
+        original_model = data.get("model")
+        data["model"] = exc.route_to_model
+
+        metadata = data.get("metadata") or {}
+        metadata["sensitive_data_routing_applied"] = True
+        metadata["sensitive_data_routing_original_model"] = original_model
+        metadata["sensitive_data_routing_guardrail"] = exc.guardrail_name
+        metadata["sensitive_data_routing_detection_info"] = exc.detection_info
+        data["metadata"] = metadata
+
+        return data
+
     @staticmethod
     async def _run_guardrail_task_with_enrichment(
         callback: Any, coro: Awaitable[Any]
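The rerouting step in `_handle_sensitive_data_route_exception` can be exercised in isolation. The following is a minimal sketch, not the proxy's code: `RouteDecision` and `apply_route` are hypothetical stand-ins for `SensitiveDataRouteException` and the method body, the model name `gpt-4o` is only an illustrative value, and the sticky-session cache write is left out.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class RouteDecision:
    """Hypothetical stand-in for the fields read from SensitiveDataRouteException."""

    route_to_model: str
    guardrail_name: Optional[str] = None
    detection_info: Optional[Dict[str, Any]] = None


def apply_route(data: Optional[dict], decision: RouteDecision) -> Optional[dict]:
    """Swap the request's model and record the override in its metadata."""
    if data is None:
        return None

    original_model = data.get("model")
    data["model"] = decision.route_to_model

    # `or {}` also covers an explicit `"metadata": None` in the request body.
    metadata = data.get("metadata") or {}
    metadata["sensitive_data_routing_applied"] = True
    metadata["sensitive_data_routing_original_model"] = original_model
    metadata["sensitive_data_routing_guardrail"] = decision.guardrail_name
    metadata["sensitive_data_routing_detection_info"] = decision.detection_info
    data["metadata"] = metadata
    return data


# Illustrative request; the dict is mutated in place and also returned.
request = {"model": "gpt-4o", "messages": [], "metadata": None}
routed = apply_route(request, RouteDecision("on-prem-model", "pii-rail"))
```

Because the original model is kept under `sensitive_data_routing_original_model`, downstream logging can still tell which model the caller asked for.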
@@ -2,7 +2,7 @@ from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Required, TypedDict

 from litellm.types.proxy.guardrails.guardrail_hooks.akto import (
@@ -771,6 +771,58 @@ class BaseLitellmParams(
         ),
     )

+    on_sensitive_data: Optional[Literal["block", "route"]] = Field(
+        default=None,
+        description=(
+            "Action to take when sensitive data is detected. "
+            "'block' raises an exception (default behavior). "
+            "'route' reroutes the request to the model specified in sensitive_data_route_to_model."
+        ),
+    )
+
+    sensitive_data_route_to_model: Optional[str] = Field(
+        default=None,
+        description=(
+            "Model to route requests to when sensitive data is detected and on_sensitive_data='route'. "
+            "This is typically an on-premise model for data privacy. "
+            "The routing decision persists for the entire session."
+        ),
+    )
+
+    sticky_session_routing: Optional[bool] = Field(
+        default=True,
+        description=(
+            "When True (default), after sensitive data is detected and routed, all subsequent "
+            "requests in the same session will continue routing to the same model."
+        ),
+    )
+
+    @field_validator(
+        "mode",
+        "default_action",
+        "on_disallowed_action",
+        "unreachable_fallback",
+        "on_sensitive_data",
+        mode="before",
+        check_fields=False,
+    )
+    @classmethod
+    def normalize_lowercase(cls, v):
+        """Normalize string and list fields to lowercase for ALL guardrail types."""
+        if isinstance(v, str):
+            return v.lower()
+        if isinstance(v, list):
+            return [x.lower() if isinstance(x, str) else x for x in v]
+        return v
+
+    @model_validator(mode="after")
+    def validate_sensitive_data_routing(self) -> "BaseLitellmParams":
+        if self.on_sensitive_data == "route" and not self.sensitive_data_route_to_model:
+            raise ValueError(
+                "sensitive_data_route_to_model must be set when on_sensitive_data='route'"
+            )
+        return self
+
     model_config = ConfigDict(extra="allow", protected_namespaces=())


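The two validators in this hunk depend on ordering: `normalize_lowercase` runs with `mode="before"`, so by the time the `mode="after"` model validator checks the pair of fields, a value such as `"ROUTE"` has already become `"route"` and the missing-target check still fires. A stdlib-only sketch of that ordering follows. `RoutingParams` is a hypothetical dataclass used for illustration, not the pydantic model, and its explicit allowed-values check stands in for what the `Literal` annotation enforces.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RoutingParams:
    """Hypothetical stand-in for the three routing fields on BaseLitellmParams."""

    on_sensitive_data: Optional[str] = None
    sensitive_data_route_to_model: Optional[str] = None
    sticky_session_routing: bool = True

    def __post_init__(self) -> None:
        # Step 1: normalize case first (the role of the "before" field validator).
        if isinstance(self.on_sensitive_data, str):
            self.on_sensitive_data = self.on_sensitive_data.lower()
        if self.on_sensitive_data not in (None, "block", "route"):
            raise ValueError("on_sensitive_data must be 'block' or 'route'")
        # Step 2: cross-field check (the role of the "after" model validator).
        if self.on_sensitive_data == "route" and not self.sensitive_data_route_to_model:
            raise ValueError(
                "sensitive_data_route_to_model must be set when on_sensitive_data='route'"
            )


# Mixed case is accepted once a target model is supplied.
ok = RoutingParams(on_sensitive_data="Route", sensitive_data_route_to_model="on-prem-model")

# Upper case without a target model is still rejected, because step 1 ran first.
try:
    RoutingParams(on_sensitive_data="ROUTE")
    rejected = False
except ValueError:
    rejected = True
```

If the cross-field check ran before normalization, `"ROUTE"` would not equal `"route"` and the misconfiguration would slip through.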
@@ -811,23 +863,6 @@ class LitellmParams(
         description="When to apply the guardrail (pre_call, post_call, during_call, logging_only)"
     )

-    @field_validator(
-        "mode",
-        "default_action",
-        "on_disallowed_action",
-        "unreachable_fallback",
-        mode="before",
-        check_fields=False,
-    )
-    @classmethod
-    def normalize_lowercase(cls, v):
-        """Normalize string and list fields to lowercase for ALL guardrail types."""
-        if isinstance(v, str):
-            return v.lower()
-        if isinstance(v, list):
-            return [x.lower() if isinstance(x, str) else x for x in v]
-        return v
-
     @field_validator("timeout", mode="before", check_fields=False)
     @classmethod
     def coerce_timeout(cls, v):
@@ -1249,3 +1249,47 @@ class TestCustomGuardrailSpendLogMatchRedaction:
         slg = request_data["metadata"]["standard_logging_guardrail_information"][0]
         assert slg["guardrail_response"]["filters"][0]["regex"] == "[REDACTED]"
         assert raw["filters"][0]["regex"] == r"\d{3}-\d{2}-\d{4}"
+
+
+class TestGuardrailInterventionClassification:
+    """A routing decision is a deliberate guardrail intervention, not a failure."""
+
+    def test_sensitive_data_route_exception_is_intervention(self):
+        from litellm.exceptions import SensitiveDataRouteException
+
+        exc = SensitiveDataRouteException(
+            route_to_model="on-prem-model",
+            session_id="sess-1",
+            guardrail_name="pii-rail",
+        )
+        assert CustomGuardrail._is_guardrail_intervention(exc) is True
+
+    @pytest.mark.asyncio
+    async def test_routing_logged_as_intervened_not_failed(self):
+        from litellm.exceptions import SensitiveDataRouteException
+        from litellm.integrations.custom_guardrail import log_guardrail_information
+        from litellm.types.guardrails import GuardrailEventHooks
+
+        class RoutingGuardrail(CustomGuardrail):
+            def __init__(self):
+                super().__init__(
+                    guardrail_name="pii-rail",
+                    event_hook=GuardrailEventHooks.pre_call,
+                )
+
+            @log_guardrail_information
+            async def async_pre_call_hook(self, data, **kwargs):
+                raise SensitiveDataRouteException(
+                    route_to_model="on-prem-model",
+                    session_id="sess-1",
+                    guardrail_name=self.guardrail_name,
+                )
+
+        guardrail = RoutingGuardrail()
+        request_data: dict = {"metadata": {}}
+
+        with pytest.raises(SensitiveDataRouteException):
+            await guardrail.async_pre_call_hook(data=request_data)
+
+        slg = request_data["metadata"]["standard_logging_guardrail_information"][0]
+        assert slg["guardrail_status"] == "guardrail_intervened"
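The tests above pin down one behavior: an exception raised to reroute a request is logged as a deliberate intervention rather than as a guardrail failure. A minimal sketch of that classification is shown below. It is not the body of `CustomGuardrail._is_guardrail_intervention`; the exception classes, the `classify` helper, and the `"guardrail_failed"` label are all hypothetical, and only the `"guardrail_intervened"` status string comes from the test.

```python
class RouteRequested(Exception):
    """Hypothetical stand-in for SensitiveDataRouteException."""


class Blocked(Exception):
    """Hypothetical stand-in for an exception a guardrail raises to block a request."""


# Exception types that mean "the guardrail acted on purpose".
INTERVENTION_TYPES = (RouteRequested, Blocked)


def classify(exc: BaseException) -> str:
    """Return the status to log for an exception raised inside a guardrail hook."""
    if isinstance(exc, INTERVENTION_TYPES):
        return "guardrail_intervened"
    # Placeholder label for everything else (timeouts, bugs, network errors).
    return "guardrail_failed"


statuses = [classify(RouteRequested()), classify(Blocked()), classify(TimeoutError())]
```

Keeping routing in the intervention bucket matters for dashboards and alerts, since a spike in failures should point at a broken guardrail, not at one doing its job.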
@@ -564,6 +564,46 @@ def test_sync_delete_responses_omits_body_for_azure():
     )


+def _content_type(headers: dict) -> str:
+    for key, value in headers.items():
+        if key.lower() == "content-type":
+            return value
+    return ""
+
+
+def test_async_delete_responses_sets_json_content_type():
+    """OpenAI rejects a responses DELETE with no Content-Type by treating it as
+    application/octet-stream. The handler must declare application/json."""
+    captured: dict = {}
+    fake_async_delete, _ = _build_delete_response_mock(captured)
+
+    async def run():
+        with patch.object(AsyncHTTPHandler, "delete", new=fake_async_delete):
+            await litellm.adelete_responses(
+                response_id="resp_xyz",
+                custom_llm_provider="openai",
+                api_key="test-key",
+            )
+
+    asyncio.run(run())
+
+    assert _content_type(captured["headers"]) == "application/json"
+
+
+def test_sync_delete_responses_sets_json_content_type():
+    captured: dict = {}
+    _, fake_sync_delete = _build_delete_response_mock(captured)
+
+    with patch.object(HTTPHandler, "delete", new=fake_sync_delete):
+        litellm.delete_responses(
+            response_id="resp_xyz",
+            custom_llm_provider="openai",
+            api_key="test-key",
+        )
+
+    assert _content_type(captured["headers"]) == "application/json"
+
+
 # ---------------------------------------------------------------------------
 # Parity tests: request-body is serialized once and reused for the wire.
 # (_async_post_anthropic_messages_with_http_error_retry)
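The two delete tests above check the outgoing headers with a case-insensitive lookup. The same idea applies on the sending side: add the header only when the caller has not already supplied one under any casing. The helper below is a hypothetical sketch of that rule, not the handler code changed by this commit.

```python
from typing import Dict, Optional


def with_json_content_type(headers: Optional[Dict[str, str]]) -> Dict[str, str]:
    """Return a copy of `headers` that declares a JSON body unless one is already declared."""
    merged = dict(headers or {})
    # HTTP header names are case-insensitive, so compare lowercased keys.
    if not any(key.lower() == "content-type" for key in merged):
        merged["Content-Type"] = "application/json"
    return merged


added = with_json_content_type({"Authorization": "Bearer test-key"})
kept = with_json_content_type({"content-type": "text/plain"})
```

Copying the input keeps the caller's dict untouched, which avoids leaking the extra header into unrelated requests that reuse it.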
1036
tests/test_litellm/proxy/hooks/test_sensitive_data_routing.py
Normal file
File diff suppressed because it is too large
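Since the diff for this test file is not shown, the behavior it targets can only be inferred from the commit message: routing decisions are cached per session, the cache key is scoped to the requesting API key, and a Redis hit is written back to the in-memory cache. The sketch below illustrates those three points under stated assumptions. `StickyRouteStore`, its key layout, and the plain dict standing in for Redis are all hypothetical, and TTL handling is omitted.

```python
from typing import Dict, Optional


class StickyRouteStore:
    """Hypothetical two-tier store: a local dict in front of a shared cache."""

    def __init__(self, shared: Dict[str, str]) -> None:
        self.local: Dict[str, str] = {}
        self.shared = shared  # stands in for Redis; a plain dict here

    @staticmethod
    def key(api_key_hash: str, session_id: str) -> str:
        # Assumed key layout; scoping by API key keeps tenants from colliding.
        return f"sensitive_route:{api_key_hash}:{session_id}"

    def set(self, api_key_hash: str, session_id: str, model: str) -> None:
        k = self.key(api_key_hash, session_id)
        self.local[k] = model
        self.shared[k] = model

    def get(self, api_key_hash: str, session_id: str) -> Optional[str]:
        k = self.key(api_key_hash, session_id)
        if k in self.local:
            return self.local[k]
        model = self.shared.get(k)
        if model is not None:
            # Backfill so the decision survives a later shared-cache outage.
            self.local[k] = model
        return model


shared: Dict[str, str] = {}
writer, reader = StickyRouteStore(shared), StickyRouteStore(shared)
writer.set("key-a", "sess-1", "on-prem-model")
first = reader.get("key-a", "sess-1")  # read from the shared cache, then backfilled
shared.clear()  # simulate the shared cache becoming unavailable
second = reader.get("key-a", "sess-1")  # served from the reader's local copy
other_tenant = reader.get("key-b", "sess-1")  # same session id, different API key
```

Without the backfill, the reader instance would return `None` on the second lookup and the session would quietly fall back to the default model.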
@@ -3,7 +3,9 @@ Test case normalization in LitellmParams for all guardrail types
 """

 import pytest
-from litellm.types.guardrails import LitellmParams
+from pydantic import ValidationError
+
+from litellm.types.guardrails import BaseLitellmParams, LitellmParams


 class TestLitellmParamsCaseNormalization:
@@ -89,3 +91,66 @@ class TestLitellmParamsCaseNormalization:
         )
         assert params.on_disallowed_action in ["block", "rewrite"]
         assert params.on_disallowed_action.islower()
+
+
+class TestSensitiveDataRoutingValidation:
+    """on_sensitive_data='route' requires a target model to be set"""
+
+    def test_route_with_target_model_is_valid(self):
+        params = LitellmParams(
+            guardrail="presidio",
+            mode="pre_call",
+            on_sensitive_data="route",
+            sensitive_data_route_to_model="on-prem-model",
+        )
+        assert params.on_sensitive_data == "route"
+        assert params.sensitive_data_route_to_model == "on-prem-model"
+
+    def test_route_without_target_model_raises(self):
+        with pytest.raises(ValidationError, match="sensitive_data_route_to_model"):
+            LitellmParams(
+                guardrail="presidio",
+                mode="pre_call",
+                on_sensitive_data="route",
+            )
+
+    def test_base_params_route_without_target_model_raises(self):
+        with pytest.raises(ValidationError, match="sensitive_data_route_to_model"):
+            BaseLitellmParams(on_sensitive_data="route")
+
+    def test_base_params_normalize_on_sensitive_data_case(self):
+        params = BaseLitellmParams(
+            on_sensitive_data="Route",
+            sensitive_data_route_to_model="on-prem-model",
+        )
+        assert params.on_sensitive_data == "route"
+
+    def test_base_params_capitalized_route_without_target_model_raises(self):
+        with pytest.raises(ValidationError, match="sensitive_data_route_to_model"):
+            BaseLitellmParams(on_sensitive_data="ROUTE")
+
+    def test_block_without_target_model_is_valid(self):
+        params = LitellmParams(
+            guardrail="presidio",
+            mode="pre_call",
+            on_sensitive_data="block",
+        )
+        assert params.on_sensitive_data == "block"
+        assert params.sensitive_data_route_to_model is None
+
+    def test_on_sensitive_data_is_case_normalized(self):
+        params = LitellmParams(
+            guardrail="presidio",
+            mode="pre_call",
+            on_sensitive_data="Route",
+            sensitive_data_route_to_model="on-prem-model",
+        )
+        assert params.on_sensitive_data == "route"
+
+    def test_on_sensitive_data_uppercase_block_normalized(self):
+        params = LitellmParams(
+            guardrail="presidio",
+            mode="pre_call",
+            on_sensitive_data="BLOCK",
+        )
+        assert params.on_sensitive_data == "block"