fix(unified_guardrail.py): correctly map a v1/messages call to the anthropic unified guardrail (#17424)

* fix(unified_guardrail.py): correctly map a v1/messages call to the anthropic unified guardrail

* fix: add more rigorous call type checks

* fix(anthropic_endpoints/endpoints.py): initialize logging object at the beginning of endpoint

Ensures the call ID and trace ID are emitted to the guardrail API.

* feat(anthropic/chat/guardrail_translation): support streaming guardrails

Sample every 5 chunks.

* fix(openai/chat/guardrail_translation): support openai streaming guardrails

* fix: initial commit fixing output guardrails for responses api

* feat(openai/responses/guardrail_translation): handler.py - fix output checks on responses api

* fix(openai/responses/guardrail_translation/handler.py): ensure responses api guardrails work on streaming

* test: update tests

* test: update tests

* test: update tests

* fix(bedrock_guardrails.py): fix post call streaming iterator logic

* fix: fix return

* fix(bedrock_guardrails.py): fix
This commit is contained in:
Krish Dholakia 2025-12-03 20:54:56 -08:00 committed by GitHub
parent a711b63b06
commit be0530a6b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1211 additions and 897 deletions

View File

@ -429,6 +429,8 @@ class LitellmBasicGuardrailRequest(BaseModel):
request_data: Dict[str, Any] = Field(default_factory=dict)
additional_provider_specific_params: Dict[str, Any] = Field(default_factory=dict)
input_type: Literal["request", "response"]
litellm_call_id: Optional[str] = None
litellm_trace_id: Optional[str] = None
class LitellmBasicGuardrailResponse(BaseModel):

View File

@ -95,6 +95,7 @@ from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
ProviderConfigManager,
TextCompletionResponse,
TranscriptionResponse,
@ -654,7 +655,9 @@ def _infer_call_type(
if completion_response is None:
return None
if isinstance(completion_response, ModelResponse):
if isinstance(completion_response, ModelResponse) or isinstance(
completion_response, ModelResponseStream
):
return "completion"
elif isinstance(completion_response, EmbeddingResponse):
return "embedding"
@ -1057,29 +1060,33 @@ def completion_cost( # noqa: PLR0915
number_of_queries = len(query)
elif query is not None:
number_of_queries = 1
search_model = model or ""
if custom_llm_provider and "/" not in search_model:
# If model is like "tavily-search", construct "tavily/search" for cost lookup
search_model = f"{custom_llm_provider}/search"
prompt_cost, completion_cost_result = search_provider_cost_per_query(
model=search_model,
custom_llm_provider=custom_llm_provider,
number_of_queries=number_of_queries,
optional_params=optional_params,
prompt_cost, completion_cost_result = (
search_provider_cost_per_query(
model=search_model,
custom_llm_provider=custom_llm_provider,
number_of_queries=number_of_queries,
optional_params=optional_params,
)
)
# Return the total cost (prompt_cost + completion_cost, but for search it's just prompt_cost)
_final_cost = prompt_cost + completion_cost_result
# Apply discount
original_cost = _final_cost
_final_cost, discount_percent, discount_amount = _apply_cost_discount(
base_cost=_final_cost,
custom_llm_provider=custom_llm_provider,
_final_cost, discount_percent, discount_amount = (
_apply_cost_discount(
base_cost=_final_cost,
custom_llm_provider=custom_llm_provider,
)
)
# Store cost breakdown in logging object if available
_store_cost_breakdown_in_logging_obj(
litellm_logging_obj=litellm_logging_obj,
@ -1091,7 +1098,7 @@ def completion_cost( # noqa: PLR0915
discount_percent=discount_percent,
discount_amount=discount_amount,
)
return _final_cost
elif call_type == CallTypes.arealtime.value and isinstance(
completion_response, LiteLLMRealtimeStreamLoggingObject

View File

@ -12,6 +12,7 @@ Pattern Overview:
4. Apply guardrail responses back to the original structure
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from litellm._logging import verbose_proxy_logger
@ -246,6 +247,115 @@ class AnthropicMessagesHandler(BaseTranslation):
return response
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List[Any]:
    """
    Run the output guardrail over the text streamed so far.

    The text accumulated from ``responses_so_far`` is handed to
    ``guardrail_to_apply.apply_guardrail`` as a single "response" text.
    Whatever the guardrail returns is discarded and the chunks are passed
    back untouched; the call is made so the guardrail gets the chance to
    reject the response (e.g. by raising).

    Args:
        responses_so_far: Streaming chunks received so far.
        guardrail_to_apply: The guardrail instance to apply.
        litellm_logging_obj: Optional logging object forwarded to the guardrail.
        user_api_key_dict: Accepted for interface compatibility; not used here.

    Returns:
        The same ``responses_so_far`` list that was passed in.
    """
    accumulated_text = self.get_streaming_string_so_far(responses_so_far)

    # The return value is unpacked but intentionally ignored.
    _ignored_texts, _ignored_images = await guardrail_to_apply.apply_guardrail(
        texts=[accumulated_text],
        request_data={},
        input_type="response",
        logging_obj=litellm_logging_obj,
        images=None,
    )

    return responses_so_far
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
    """
    Collect the text streamed so far from a list of Anthropic stream chunks.

    Two chunk shapes are understood; anything else is ignored:

    1. Raw ``bytes`` in SSE (Server-Sent Events) format, e.g.
       b'event: content_block_delta\\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" curious"}}\\n\\n'
       These are delegated to ``_extract_text_from_sse``.
    2. Already-parsed ``dict`` events, e.g.
       {"type": "content_block_delta", "index": 0,
        "delta": {"type": "text_delta", "text": " curious"}}
       Only a ``delta`` of type ``text_delta`` contributes text.

    Returns:
        The concatenation, in chunk order, of all text fragments found.
    """
    fragments = []

    for chunk in responses_so_far:
        if isinstance(chunk, bytes):
            # Raw SSE payload: parsing lives in the dedicated helper.
            fragments.append(self._extract_text_from_sse(chunk))
            continue

        if not isinstance(chunk, dict):
            continue

        # A missing or falsy delta means there is nothing to read.
        delta = chunk.get("delta") or None
        if not delta or delta.get("type") != "text_delta":
            continue

        piece = delta.get("text", "")
        if piece:
            fragments.append(piece)

    return "".join(fragments)
def _extract_text_from_sse(self, sse_bytes: bytes) -> str:
"""
Extract text content from Server-Sent Events (SSE) format.
Args:
sse_bytes: Raw bytes in SSE format
Returns:
Accumulated text from all content_block_delta events
"""
text = ""
try:
# Decode bytes to string
sse_string = sse_bytes.decode("utf-8")
# Split by double newline to get individual events
events = sse_string.split("\n\n")
for event in events:
if not event.strip():
continue
# Parse event lines
lines = event.strip().split("\n")
event_type = None
data_line = None
for line in lines:
if line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_line = line[5:].strip()
# Only process content_block_delta events
if event_type == "content_block_delta" and data_line:
try:
data = json.loads(data_line)
delta = data.get("delta", {})
if delta.get("type") == "text_delta":
text += delta.get("text", "")
except json.JSONDecodeError:
verbose_proxy_logger.warning(
f"Failed to parse JSON from SSE data: {data_line}"
)
except Exception as e:
verbose_proxy_logger.error(f"Error extracting text from SSE: {e}")
return text
def _has_text_content(self, response: "AnthropicMessagesResponse") -> bool:
"""
Check if response has any text content to process.

View File

@ -436,9 +436,7 @@ class AnthropicChatCompletion(BaseLLM):
else:
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client(
params={"timeout": timeout}
)
client = _get_httpx_client(params={"timeout": timeout})
else:
client = client
@ -528,9 +526,7 @@ class ModelResponseIterator:
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
)
def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[
def _content_block_delta_helper(self, chunk: dict) -> Tuple[
str,
Optional[ChatCompletionToolCallChunk],
List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]],

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
@ -87,7 +87,7 @@ class BaseTranslation(ABC):
async def process_output_streaming_response(
self,
response: Any,
responses_so_far: List[Any],
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
@ -97,4 +97,4 @@ class BaseTranslation(ABC):
Optional to override in subclasses.
"""
return response
return responses_so_far

View File

@ -243,48 +243,122 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
async def process_output_streaming_response(
self,
response: "ModelResponseStream",
responses_so_far: List["ModelResponseStream"],
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
) -> Any:
) -> List["ModelResponseStream"]:
"""
Process output streaming response by applying guardrails to text content.
Process output streaming responses by applying guardrails to text content.
Args:
response: LiteLLM ModelResponseStream object
responses_so_far: List of LiteLLM ModelResponseStream objects
guardrail_to_apply: The guardrail instance to apply
litellm_logging_obj: Optional logging object
user_api_key_dict: User API key metadata to pass to guardrails
Returns:
Modified response with guardrail applied to content
Modified list of responses with guardrail applied to content
Response Format Support:
- String content: choice.message.content = "text here"
- List content: choice.message.content = [{"type": "text", "text": "text here"}, ...]
"""
# Step 0: Check if response has any text content to process
if not self._has_text_content(response):
return response
# Step 0: Check if any response has text content to process
has_any_text_content = False
for response in responses_so_far:
if self._has_text_content(response):
has_any_text_content = True
break
if not has_any_text_content:
verbose_proxy_logger.warning(
"OpenAI Chat Completions: No text content in streaming responses, skipping guardrail"
)
return responses_so_far
# Step 1: Combine all streaming chunks into complete text per choice
# For streaming, we need to concatenate all delta.content across all chunks
# Key: (choice_idx, content_idx), Value: combined text
combined_texts: Dict[Tuple[int, Optional[int]], str] = {}
for response_idx, response in enumerate(responses_so_far):
for choice_idx, choice in enumerate(response.choices):
if isinstance(choice, litellm.StreamingChoices):
content = choice.delta.content
elif isinstance(choice, litellm.Choices):
content = choice.message.content
else:
continue
if content is None:
continue
if isinstance(content, str):
# String content - accumulate for this choice
key = (choice_idx, None)
if key not in combined_texts:
combined_texts[key] = ""
combined_texts[key] += content
elif isinstance(content, list):
# List content - accumulate for each content item
for content_idx, content_item in enumerate(content):
text_str = content_item.get("text")
if text_str:
key = (choice_idx, content_idx)
if key not in combined_texts:
combined_texts[key] = ""
combined_texts[key] += text_str
# Step 2: Create lists for guardrail processing
texts_to_check: List[str] = []
images_to_check: List[str] = []
task_mappings: List[Tuple[int, Optional[int]]] = []
# Track (choice_index, content_index) for each text
# Track (choice_index, content_index) for each combined text
# Step 1: Extract all text content and images from response choices
for choice_idx, choice in enumerate(response.choices):
for (choice_idx, content_idx), combined_text in combined_texts.items():
texts_to_check.append(combined_text)
task_mappings.append((choice_idx, content_idx))
self._extract_output_text_and_images(
choice=choice,
choice_idx=choice_idx,
texts_to_check=texts_to_check,
images_to_check=images_to_check,
# Step 3: Apply guardrail to all combined texts in batch
if texts_to_check:
# Create a request_data dict with response info and user API key metadata
request_data: dict = {"responses": responses_so_far}
# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
request_data["litellm_metadata"] = user_metadata
guardrailed_texts, guardrailed_images = (
await guardrail_to_apply.apply_guardrail(
texts=texts_to_check,
request_data=request_data,
input_type="response",
images=images_to_check if images_to_check else None,
logging_obj=litellm_logging_obj,
)
)
# Step 4: Apply guardrailed text back to all streaming chunks
# For each choice, replace the combined text across all chunks
await self._apply_guardrail_responses_to_output_streaming(
responses=responses_so_far,
guardrailed_texts=guardrailed_texts,
task_mappings=task_mappings,
)
verbose_proxy_logger.debug(
"OpenAI Chat Completions: Processed output streaming responses: %s",
responses_so_far,
)
return responses_so_far
def _has_text_content(
self, response: Union["ModelResponse", "ModelResponseStream"]
) -> bool:
@ -304,10 +378,8 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
return True
elif isinstance(response, ModelResponseStream):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
if choice.message.content and isinstance(
choice.message.content, str
):
if isinstance(choice, litellm.StreamingChoices):
if choice.delta.content and isinstance(choice.delta.content, str):
return True
return False
@ -394,3 +466,79 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
][
"text"
] = guardrail_response
async def _apply_guardrail_responses_to_output_streaming(
    self,
    responses: List["ModelResponseStream"],
    guardrailed_texts: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Write guardrailed text back into a list of streaming chunks, in place.

    Each guardrailed text is the combined text of one
    (choice_idx, content_idx) slot across all chunks. The first chunk that
    carries content for a slot receives the whole guardrailed text; every
    later chunk carrying content for the same slot is set to "".

    Args:
        responses: List of ModelResponseStream objects to modify
        guardrailed_texts: List of guardrailed text responses (combined from all chunks)
        task_mappings: List of tuples (choice_idx, content_idx), parallel to
            ``guardrailed_texts``

    Override this method to customize how responses are applied to streaming responses.
    """
    # slot -> guardrailed text for that slot
    replacement_for: Dict[Tuple[int, Optional[int]], str] = {}
    for position, guardrailed in enumerate(guardrailed_texts):
        slot_spec = task_mappings[position]
        replacement_for[(slot_spec[0], slot_spec[1])] = guardrailed

    # Slots whose full guardrailed text has already been written to a chunk.
    written = set()

    for chunk in responses:
        for choice_pos, choice in enumerate(chunk.choices):
            # Streaming choices carry text on .delta, regular ones on .message.
            if isinstance(choice, litellm.StreamingChoices):
                holder = choice.delta
            elif isinstance(choice, litellm.Choices):
                holder = choice.message
            else:
                continue

            content = holder.content
            if content is None:
                continue

            if isinstance(content, str):
                slot = (choice_pos, None)
                if slot not in replacement_for:
                    continue
                if slot in written:
                    # Subsequent chunks - clear the content
                    holder.content = ""
                else:
                    # First chunk - set the complete guardrailed text
                    holder.content = replacement_for[slot]
                    written.add(slot)
            elif isinstance(content, list):
                for item_pos, item in enumerate(content):
                    if "text" not in item:
                        continue
                    slot = (choice_pos, item_pos)
                    if slot not in replacement_for:
                        continue
                    if slot in written:
                        item["text"] = ""
                    else:
                        item["text"] = replacement_for[slot]
                        written.add(slot)

View File

@ -30,6 +30,8 @@ Output: response.output is List[GenericResponseOutputItem] where each has:
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast
from openai import BaseModel
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.responses.main import GenericResponseOutputItem, OutputText
@ -275,6 +277,32 @@ class OpenAIResponsesHandler(BaseTranslation):
return response
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List[Any]:
    """
    Run the output guardrail over the text streamed so far.

    The text gathered by ``get_streaming_string_so_far`` is passed to the
    guardrail as one "response" text. The guardrail's return value is not
    applied to the chunks; they are returned exactly as received.

    Args:
        responses_so_far: Streaming chunks received so far.
        guardrail_to_apply: The guardrail instance to apply.
        litellm_logging_obj: Optional logging object forwarded to the guardrail.
        user_api_key_dict: Accepted for interface compatibility; not used here.

    Returns:
        The same ``responses_so_far`` list that was passed in.
    """
    accumulated_text = self.get_streaming_string_so_far(responses_so_far)

    # The return value is unpacked but intentionally unused.
    _unused_texts, _unused_images = await guardrail_to_apply.apply_guardrail(
        texts=[accumulated_text],
        request_data={},
        input_type="response",
        logging_obj=litellm_logging_obj,
        images=None,
    )

    return responses_so_far
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
    """
    Get the string so far from the responses so far.

    Each chunk is read through ``.get("text")``. A chunk with no "text" key,
    or whose "text" is None or empty, contributes nothing; previously a
    ``None`` value made the join raise ``TypeError``.

    Returns:
        The "text" values of all chunks concatenated in order.
    """
    return "".join(response.get("text") or "" for response in responses_so_far)
def _has_text_content(self, response: "ResponsesAPIResponse") -> bool:
"""
Check if response has any text content to process.
@ -285,6 +313,17 @@ class OpenAIResponsesHandler(BaseTranslation):
return False
for output_item in response.output:
if isinstance(output_item, BaseModel):
try:
generic_response_output_item = (
GenericResponseOutputItem.model_validate(
output_item.model_dump()
)
)
if generic_response_output_item.content:
output_item = generic_response_output_item
except Exception:
continue
if isinstance(output_item, (GenericResponseOutputItem, dict)):
content = (
output_item.content
@ -296,9 +335,11 @@ class OpenAIResponsesHandler(BaseTranslation):
# Check if it's an OutputText with text
if isinstance(content_item, OutputText):
if content_item.text:
return True
elif isinstance(content_item, dict):
if content_item.get("text"):
return True
return False
@ -316,8 +357,16 @@ class OpenAIResponsesHandler(BaseTranslation):
Override this method to customize text/image extraction logic.
"""
# Handle both GenericResponseOutputItem and dict
if isinstance(output_item, GenericResponseOutputItem):
content = output_item.content
content: Optional[Union[List[OutputText], List[dict]]] = None
if isinstance(output_item, BaseModel):
try:
generic_response_output_item = GenericResponseOutputItem.model_validate(
output_item.model_dump()
)
if generic_response_output_item.content:
content = generic_response_output_item.content
except Exception:
return
elif isinstance(output_item, dict):
content = output_item.get("content", [])
else:

View File

@ -3,23 +3,19 @@ model_list:
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-sonnet-4-5-20250929
litellm_params:
model: anthropic/claude-sonnet-4-5-20250929
- model_name: gpt-4.1-mini
litellm_params:
model: openai/gpt-4.1-mini
guardrails:
# - guardrail_name: model-armor-shield
# litellm_params:
# guardrail: model_armor
# mode: "post_call" # Run on both input and output
# template_id: "test-prompt-template" # Required: Your Model Armor template ID
# project_id: "test-vector-store-db" # Your GCP project ID
# location: "us" # GCP location (default: us-central1)
# mask_request_content: true # Enable request content masking
# mask_response_content: true # Enable response content masking
# fail_on_error: true # Fail request if Model Armor errors (default: true)
# default_on: true # Run by default for all requests
- guardrail_name: generic-guardrail
litellm_params:
guardrail: generic_guardrail_api
mode: [pre_call, post_call]
mode: ["pre_call", "post_call", "during_call"]
headers:
Authorization: Bearer mock-bedrock-token-12345
api_base: http://localhost:8080

View File

@ -2,20 +2,14 @@
Unified /v1/messages endpoint - (Anthropic Spec)
"""
import asyncio
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
create_streaming_response,
)
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.types.utils import TokenCountResponse
router = APIRouter()
@ -49,169 +43,28 @@ async def anthropic_response( # noqa: PLR0915
version,
)
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
data = await _read_request_body(request=request)
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data.get("model", None) # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
result = await base_llm_response_processor.base_process_llm_request(
request=request,
general_settings=general_settings,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
version=version,
route_type="anthropic_messages",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=None,
model=None,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type=CallTypes.anthropic_messages.value
)
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=data,
user_api_key_dict=user_api_key_dict,
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
route_type="anthropic_messages" # type: ignore
),
)
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_coro = llm_router.aanthropic_messages(**data)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_coro = llm_router.aanthropic_messages(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_coro = llm_router.aanthropic_messages(**data, specific_deployment=True)
elif (
llm_router is not None and llm_router.has_model_id(data["model"])
): # model in router model list
llm_coro = llm_router.aanthropic_messages(**data)
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.pattern_router.patterns) > 0
)
): # model in router deployments, calling a specific deployment on the router
llm_coro = llm_router.aanthropic_messages(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
llm_coro = litellm.anthropic_messages(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
tasks.append(llm_coro)
# wait for call to end
llm_responses = asyncio.gather(
*tasks
) # run the moderation check in parallel to the actual llm api call
responses = await llm_responses
response = responses[1]
# Extract model_id from request metadata (set by router during routing)
litellm_metadata = data.get("litellm_metadata", {}) or {}
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", "") or ""
# Get other metadata from hidden_params
hidden_params = getattr(response, "_hidden_params", {}) or {}
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
selected_data_generator = (
ProxyBaseLLMRequestProcessing.async_sse_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
proxy_logging_obj=proxy_logging_obj,
)
)
return await create_streaming_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers=dict(fastapi_response.headers),
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
)
verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response))
return response
return result
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
@ -280,35 +133,30 @@ async def count_tokens(
Returns: {"input_tokens": <number>}
"""
from litellm.proxy.proxy_server import token_counter as internal_token_counter
try:
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
# Extract required fields
model_name = data.get("model")
messages = data.get("messages", [])
if not model_name:
raise HTTPException(
status_code=400,
detail={"error": "model parameter is required"}
status_code=400, detail={"error": "model parameter is required"}
)
if not messages:
raise HTTPException(
status_code=400,
detail={"error": "messages parameter is required"}
status_code=400, detail={"error": "messages parameter is required"}
)
# Create TokenCountRequest for the internal endpoint
from litellm.proxy._types import TokenCountRequest
token_request = TokenCountRequest(
model=model_name,
messages=messages
)
token_request = TokenCountRequest(model=model_name, messages=messages)
# Call the internal token counter function with direct request flag set to False
token_response = await internal_token_counter(
request=token_request,
@ -319,17 +167,18 @@ async def count_tokens(
_token_response_dict = token_response.model_dump()
elif isinstance(token_response, dict):
_token_response_dict = token_response
# Convert the internal response to Anthropic API format
return {"input_tokens": _token_response_dict.get("total_tokens", 0)}
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(str(e))
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
str(e)
)
)
raise HTTPException(
status_code=500,
detail={"error": f"Internal server error: {str(e)}"}
status_code=500, detail={"error": f"Internal server error: {str(e)}"}
)

View File

@ -337,6 +337,7 @@ class ProxyBaseLLMRequestProcessing:
"alist_skills",
"aget_skill",
"adelete_skill",
"anthropic_messages",
],
version: Optional[str] = None,
user_model: Optional[str] = None,
@ -461,11 +462,12 @@ class ProxyBaseLLMRequestProcessing:
"alist_skills",
"aget_skill",
"adelete_skill",
"anthropic_messages",
],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
proxy_config: ProxyConfig,
select_data_generator: Callable,
select_data_generator: Optional[Callable] = None,
llm_router: Optional[Router] = None,
model: Optional[str] = None,
user_model: Optional[str] = None,
@ -605,7 +607,21 @@ class ProxyBaseLLMRequestProcessing:
status_code=response.status_code,
headers=custom_headers,
)
else:
elif route_type == "anthropic_messages":
selected_data_generator = (
ProxyBaseLLMRequestProcessing.async_sse_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=self.data,
proxy_logging_obj=proxy_logging_obj,
)
)
return await create_streaming_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
elif select_data_generator:
selected_data_generator = select_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
@ -804,21 +820,30 @@ class ProxyBaseLLMRequestProcessing:
# This matches the original behavior before the refactor in commit 511d435f6f
error_body = await e.response.aread()
error_text = error_body.decode("utf-8")
raise HTTPException(
status_code=e.response.status_code,
detail={"error": error_text},
)
error_msg = f"{str(e)}"
error_msg = f"{str(e)}"
# Check for AttributeError in various places:
# 1. Direct AttributeError (already handled above)
# 2. In underlying exception (__cause__, __context__, original_exception)
has_attribute_error = (
(isinstance(e, Exception) and isinstance(getattr(e, "__cause__", None), AttributeError))
or (isinstance(e, Exception) and isinstance(getattr(e, "__context__", None), AttributeError))
or (isinstance(e, Exception) and isinstance(getattr(e, "original_exception", None), AttributeError))
(
isinstance(e, Exception)
and isinstance(getattr(e, "__cause__", None), AttributeError)
)
or (
isinstance(e, Exception)
and isinstance(getattr(e, "__context__", None), AttributeError)
)
or (
isinstance(e, Exception)
and isinstance(getattr(e, "original_exception", None), AttributeError)
)
)
if has_attribute_error:
raise ProxyException(
message=f"Invalid request format: {error_msg}",
@ -1110,7 +1135,9 @@ class ProxyBaseLLMRequestProcessing:
return obj
return None
def maybe_get_model_id(self, _logging_obj: Optional[LiteLLMLoggingObj]) -> Optional[str]:
def maybe_get_model_id(
self, _logging_obj: Optional[LiteLLMLoggingObj]
) -> Optional[str]:
"""
Get model_id from logging object or request metadata.
@ -1120,10 +1147,7 @@ class ProxyBaseLLMRequestProcessing:
model_id = None
if _logging_obj:
# 1. Try getting from litellm_params (updated during call)
if (
hasattr(_logging_obj, "litellm_params")
and _logging_obj.litellm_params
):
if hasattr(_logging_obj, "litellm_params") and _logging_obj.litellm_params:
# First check direct model_info path (set by router.py with selected deployment)
model_info = _logging_obj.litellm_params.get("model_info") or {}
model_id = model_info.get("id", None)

View File

@ -729,9 +729,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
#########################################################
########## 1. Make the Bedrock API request ##########
#########################################################
bedrock_guardrail_response: Optional[
Union[BedrockGuardrailResponse, str]
] = None
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
None
)
try:
bedrock_guardrail_response = await self.make_bedrock_api_request(
source="INPUT", messages=filtered_messages, request_data=data
@ -801,9 +801,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
#########################################################
########## 1. Make the Bedrock API request ##########
#########################################################
bedrock_guardrail_response: Optional[
Union[BedrockGuardrailResponse, str]
] = None
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
None
)
try:
bedrock_guardrail_response = await self.make_bedrock_api_request(
source="INPUT", messages=filtered_messages, request_data=data
@ -1276,10 +1276,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
masked_texts = []
for text in texts:
mock_messages: List[AllMessageValues] = [
ChatCompletionUserMessage(role="user", content=text)
]
mock_messages: List[AllMessageValues] = [
ChatCompletionUserMessage(role="user", content=text) for text in texts
]
request_messages = mock_messages
filter_result = self._prepare_guardrail_messages_for_role(
messages=request_messages
@ -1292,19 +1291,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
request_data=request_data,
)
bedrock_response = await self.make_bedrock_api_request(
source="INPUT",
messages=mock_messages,
request_data=request_data,
)
if bedrock_response.get("action") == "BLOCKED":
raise Exception(
f"Content blocked by Bedrock guardrail: {bedrock_response.get('reason', 'Unknown reason')}"
)
# Apply any masking that was applied by the guardrail
masked_text = text
output_list = bedrock_response.get("output")
if output_list:
# If the guardrail returned modified content, use that
@ -1312,7 +1304,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
text_content = output_item.get("text")
if text_content:
masked_text = str(text_content)
break
masked_texts.append(masked_text)
else:
outputs_list = bedrock_response.get("outputs")
if outputs_list:
@ -1321,15 +1313,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
text_content = output_item.get("text")
if text_content:
masked_text = str(text_content)
break
masked_texts.append(masked_text)
masked_texts.append(masked_text)
verbose_proxy_logger.debug(
"Bedrock Guardrail: Successfully applied guardrail"
)
return masked_texts, images
return masked_texts or texts, images
except Exception as e:
verbose_proxy_logger.error(

View File

@ -6,7 +6,7 @@ Unified Guardrail, leveraging LiteLLM's /applyGuardrail endpoint
3. Implements a way to call /applyGuardrail endpoint for `/chat/completions` + `/v1/messages` requests on async_post_call_streaming_iterator_hook
"""
from typing import Any, AsyncGenerator, Union
from typing import Any, AsyncGenerator, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
@ -126,16 +126,21 @@ class UnifiedLLMGuardrails(CustomLogger):
)
is not True
):
return
verbose_proxy_logger.debug(
"async_post_call_success_hook response: %s", response
)
call_type = _infer_call_type(call_type=None, completion_response=response)
call_type: Optional[CallTypesLiteral] = None
if user_api_key_dict.request_route is not None:
call_types = get_call_types_for_route(user_api_key_dict.request_route)
if call_types is not None:
call_type = call_types[0]
if call_type is None:
call_type = _infer_call_type(call_type=None, completion_response=response)
if call_type is None:
return response
if endpoint_guardrail_translation_mappings is None:
@ -183,9 +188,6 @@ class UnifiedLLMGuardrails(CustomLogger):
"""
global endpoint_guardrail_translation_mappings
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
guardrail_to_apply: CustomGuardrail = request_data.pop(
"guardrail_to_apply", None
@ -234,9 +236,11 @@ class UnifiedLLMGuardrails(CustomLogger):
# Infer call type from first chunk
call_type = None
chunk_counter = 0
responses_so_far: List[Any] = []
async for item in response:
chunk_counter += 1
responses_so_far.append(item)
# Infer call type from first chunk if not already done
if call_type is None and user_api_key_dict.request_route is not None:
@ -244,19 +248,22 @@ class UnifiedLLMGuardrails(CustomLogger):
if call_types is not None:
call_type = call_types[0]
# If call type not supported, just pass through all chunks
if (
call_type is None
or CallTypes(call_type)
not in endpoint_guardrail_translation_mappings
):
yield item
async for remaining_item in response:
yield remaining_item
return
if call_type is None:
call_type = _infer_call_type(call_type=None, completion_response=item)
# If call type not supported, just pass through all chunks
if (
call_type is None
or CallTypes(call_type) not in endpoint_guardrail_translation_mappings
):
yield item
async for remaining_item in response:
yield remaining_item
return
# Process chunk based on sampling rate
if chunk_counter % sampling_rate == 0:
verbose_proxy_logger.debug(
"Processing streaming chunk %s (sampling_rate=%s) with guardrail %s",
chunk_counter,
@ -268,22 +275,17 @@ class UnifiedLLMGuardrails(CustomLogger):
CallTypes(call_type)
]()
processed_item = (
processed_items = (
await endpoint_translation.process_output_streaming_response(
response=item,
responses_so_far=responses_so_far,
guardrail_to_apply=guardrail_to_apply,
litellm_logging_obj=request_data.get("litellm_logging_obj"),
user_api_key_dict=user_api_key_dict,
)
)
# Add guardrail to applied guardrails header (only once, on first processed chunk)
if chunk_counter == sampling_rate:
add_guardrail_to_applied_guardrails_header(
request_data=request_data,
guardrail_name=guardrail_to_apply.guardrail_name,
)
last_item = processed_items[-1]
yield processed_item
yield last_item
else:
yield item

View File

@ -75,12 +75,13 @@ def add_shared_session_to_data(data: dict) -> None:
"""
Add shared aiohttp session for connection reuse (prevents cold starts).
Silently continues without session reuse if import fails or session is unavailable.
Args:
data: Dictionary to add the shared session to
"""
try:
from litellm.proxy.proxy_server import shared_aiohttp_session
if shared_aiohttp_session is not None and not shared_aiohttp_session.closed:
data["shared_session"] = shared_aiohttp_session
except Exception:
@ -136,13 +137,14 @@ async def route_request(
"aget_skill",
"adelete_skill",
"aingest",
"anthropic_messages",
],
):
"""
Common helper to route the request
"""
add_shared_session_to_data(data)
team_id = get_team_id_from_data(data)
router_model_names = llm_router.model_names if llm_router is not None else []
@ -177,7 +179,12 @@ async def route_request(
return llm_router.abatch_completion(models=models, **data)
elif llm_router is not None:
# Skip model-based routing for container operations
if route_type in ["acreate_container", "alist_containers", "aretrieve_container", "adelete_container"]:
if route_type in [
"acreate_container",
"alist_containers",
"aretrieve_container",
"adelete_container",
]:
return getattr(llm_router, f"{route_type}")(**data)
if route_type in [
"avideo_list",
@ -196,7 +203,7 @@ async def route_request(
] and (data.get("model") is None or data.get("model") == ""):
# These endpoints don't need a model, use custom_llm_provider directly
return getattr(litellm, f"{route_type}")(**data)
team_model_name = (
llm_router.map_team_model(data["model"], team_id)
if team_id is not None
@ -206,9 +213,8 @@ async def route_request(
data["model"] = team_model_name
return getattr(llm_router, f"{route_type}")(**data)
elif (
data["model"] in router_model_names
or llm_router.has_model_id(data["model"])
elif data["model"] in router_model_names or llm_router.has_model_id(
data["model"]
):
return getattr(llm_router, f"{route_type}")(**data)

View File

@ -1417,6 +1417,7 @@ class ProxyLogging:
3. /image/generation
4. /files
"""
from litellm.types.guardrails import GuardrailEventHooks
guardrail_callbacks: List[CustomGuardrail] = []
@ -1451,6 +1452,7 @@ class ProxyLogging:
continue
guardrail_response: Optional[Any] = None
if "apply_guardrail" in type(callback).__dict__:
data["guardrail_to_apply"] = callback
guardrail_response = (
@ -1562,6 +1564,7 @@ class ProxyLogging:
"""
for callback in litellm.callbacks:
_callback: Optional[CustomLogger] = None
if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
@ -1575,6 +1578,7 @@ class ProxyLogging:
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call
):
if "apply_guardrail" in type(callback).__dict__:
request_data["guardrail_to_apply"] = callback
response = (

View File

@ -153,11 +153,7 @@ from litellm.types.utils import (
)
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import (
ModelResponseStream,
StandardLoggingPayload,
Usage,
)
from litellm.types.utils import ModelResponseStream, StandardLoggingPayload, Usage
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
@ -779,6 +775,9 @@ class Router:
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.anthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.agenerate_content = self.factory_function(
litellm.agenerate_content, call_type="agenerate_content"
)
@ -886,6 +885,7 @@ class Router:
from litellm.vector_store_files.main import (
update as vector_store_file_update_fn,
)
self.avector_store_file_create = self.factory_function(
avector_store_file_create_fn, call_type="avector_store_file_create"
)
@ -3865,6 +3865,7 @@ class Router:
"retrieve_container",
"delete_container",
):
def sync_wrapper(
custom_llm_provider: Optional[str] = None,
client: Optional[Any] = None,
@ -5944,36 +5945,38 @@ class Router:
"""
Get API credentials and provider info from a model name in model_list.
Useful for passthrough endpoints (files, batches, etc.) that need credentials.
This method tries to find a deployment by model_id first, and if not found,
it tries to find by model_group_name (model_name).
Args:
model_id: Model ID or model name from model_list (e.g., "gpt-4o-litellm")
Returns:
Dictionary containing api_key, api_base, custom_llm_provider, etc.
Returns None if model not found.
Example:
credentials = router.get_deployment_credentials_with_provider("gpt-4o-litellm")
# Returns: {"api_key": "sk-...", "custom_llm_provider": "openai", ...}
"""
# Try to get deployment by model_id first
deployment = self.get_deployment(model_id=model_id)
# If not found, try by model_group_name
if deployment is None:
deployment = self.get_deployment_by_model_group_name(model_group_name=model_id)
deployment = self.get_deployment_by_model_group_name(
model_group_name=model_id
)
if deployment is None:
return None
# Get basic credentials
credentials = CredentialLiteLLMParams(
**deployment.litellm_params.model_dump(exclude_none=True)
).model_dump(exclude_none=True)
# Add custom_llm_provider
if deployment.litellm_params.custom_llm_provider:
credentials["custom_llm_provider"] = (
@ -5986,7 +5989,7 @@ class Router:
)[0]
else:
credentials["custom_llm_provider"] = "openai" # default
return credentials
@overload

View File

@ -731,6 +731,7 @@ API_ROUTE_TO_CALL_TYPES = {
CallTypes.llm_passthrough_route,
CallTypes.allm_passthrough_route,
],
"/v1/messages": [CallTypes.anthropic_messages],
}

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,7 @@ with guardrail transformations.
import os
import sys
from typing import Any, List, Optional, Tuple
from typing import Any, List, Literal, Optional, Tuple
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -16,6 +16,8 @@ sys.path.insert(
0, os.path.abspath("../../../../../..")
) # Adds the parent directory to the system path
from fastapi import HTTPException
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms import get_guardrail_translation_mapping
from litellm.llms.openai.responses.guardrail_translation.handler import (
@ -27,12 +29,27 @@ from litellm.types.utils import CallTypes
class MockGuardrail(CustomGuardrail):
    """Mock guardrail for testing that transforms text for requests and blocks responses"""

    async def apply_guardrail(
        self,
        texts: List[str],
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
        images: Optional[List[str]] = None,
    ) -> Tuple[List[str], Optional[List[str]]]:
        """
        Transform request texts; reject response texts.

        For requests: append " [GUARDRAILED]" to every text.
        For responses: block by raising HTTPException (masking responses is
        no longer supported).

        Args:
            texts: Texts to run through the guardrail.
            request_data: Request payload; not read by this mock.
            input_type: "request" to transform the texts, "response" to block.
            logging_obj: Optional logging object; not read by this mock.
            images: Optional images; not read by this mock.

        Returns:
            A ``(texts, images)`` tuple where every text carries the
            " [GUARDRAILED]" suffix and the images slot is always ``None``.

        Raises:
            HTTPException: With status 400 when ``input_type`` is "response";
                the offending texts are echoed back in ``detail["texts"]``.
        """
        if input_type == "response":
            # Responses should be blocked, not masked
            raise HTTPException(
                status_code=400,
                detail={"error": "Response blocked by guardrail", "texts": texts},
            )

        # For requests, we can still mask/transform
        return ([f"{text} [GUARDRAILED]" for text in texts], None)
@ -169,11 +186,15 @@ class TestOpenAIResponsesHandlerOutputProcessing:
@pytest.mark.asyncio
async def test_process_output_response_simple(self):
"""Test processing simple output response"""
"""Test processing simple output response - should block, not mask
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
This test verifies that the guardrail properly blocks responses.
"""
handler = OpenAIResponsesHandler()
guardrail = MockGuardrail(guardrail_name="test")
# Create a mock response
# Create a mock response with dict format (works with current handler)
response = ResponsesAPIResponse(
id="resp_123",
created_at=1234567890,
@ -181,30 +202,36 @@ class TestOpenAIResponsesHandlerOutputProcessing:
object="response",
status="completed",
output=[
GenericResponseOutputItem(
type="message",
id="msg_123",
status="completed",
role="assistant",
content=[
OutputText(
type="output_text", text="Hello user", annotations=None
),
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Hello user"},
],
)
}
],
)
result = await handler.process_output_response(response, guardrail)
# Response should be blocked, not masked
with pytest.raises(HTTPException) as exc_info:
await handler.process_output_response(response, guardrail)
assert result.output[0].content[0].text == "Hello user [GUARDRAILED]"
assert exc_info.value.status_code == 400
assert "Response blocked by guardrail" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_process_output_response_multiple_items(self):
"""Test processing output response with multiple output items"""
"""Test processing output response with multiple output items - should block, not mask
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
This test verifies that the guardrail properly blocks responses with multiple items.
"""
handler = OpenAIResponsesHandler()
guardrail = MockGuardrail(guardrail_name="test")
# Use dict format (works with current handler)
response = ResponsesAPIResponse(
id="resp_123",
created_at=1234567890,
@ -212,46 +239,45 @@ class TestOpenAIResponsesHandlerOutputProcessing:
object="response",
status="completed",
output=[
GenericResponseOutputItem(
type="message",
id="msg_123",
status="completed",
role="assistant",
content=[
OutputText(
type="output_text",
text="First message",
annotations=None,
),
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "First message"},
],
),
GenericResponseOutputItem(
type="message",
id="msg_124",
status="completed",
role="assistant",
content=[
OutputText(
type="output_text",
text="Second message",
annotations=None,
),
},
{
"type": "message",
"id": "msg_124",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Second message"},
],
),
},
],
)
result = await handler.process_output_response(response, guardrail)
# Response should be blocked, not masked
with pytest.raises(HTTPException) as exc_info:
await handler.process_output_response(response, guardrail)
assert result.output[0].content[0].text == "First message [GUARDRAILED]"
assert result.output[1].content[0].text == "Second message [GUARDRAILED]"
assert exc_info.value.status_code == 400
assert "Response blocked by guardrail" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_process_output_response_multiple_content_items(self):
"""Test processing output response with multiple content items in one output"""
"""Test processing output response with multiple content items - should block, not mask
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
This test verifies that the guardrail properly blocks responses with multiple content items.
"""
handler = OpenAIResponsesHandler()
guardrail = MockGuardrail(guardrail_name="test")
# Use dict format (works with current handler)
response = ResponsesAPIResponse(
id="resp_123",
created_at=1234567890,
@ -259,27 +285,33 @@ class TestOpenAIResponsesHandlerOutputProcessing:
object="response",
status="completed",
output=[
GenericResponseOutputItem(
type="message",
id="msg_123",
status="completed",
role="assistant",
content=[
OutputText(type="output_text", text="Part 1", annotations=None),
OutputText(type="output_text", text="Part 2", annotations=None),
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Part 1"},
{"type": "output_text", "text": "Part 2"},
],
)
}
],
)
result = await handler.process_output_response(response, guardrail)
# Response should be blocked, not masked
with pytest.raises(HTTPException) as exc_info:
await handler.process_output_response(response, guardrail)
assert result.output[0].content[0].text == "Part 1 [GUARDRAILED]"
assert result.output[0].content[1].text == "Part 2 [GUARDRAILED]"
assert exc_info.value.status_code == 400
assert "Response blocked by guardrail" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_process_output_response_with_dict_format(self):
"""Test processing output response where content items are dicts instead of OutputText objects"""
"""Test processing output response with dict format - should block, not mask
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
This test verifies blocking works even when content items are dicts instead of OutputText objects.
"""
handler = OpenAIResponsesHandler()
guardrail = MockGuardrail(guardrail_name="test")
@ -303,9 +335,12 @@ class TestOpenAIResponsesHandlerOutputProcessing:
],
)
result = await handler.process_output_response(response, guardrail)
# Response should be blocked, not masked
with pytest.raises(HTTPException) as exc_info:
await handler.process_output_response(response, guardrail)
assert result.output[0]["content"][0]["text"] == "Hello from dict [GUARDRAILED]"
assert exc_info.value.status_code == 400
assert "Response blocked by guardrail" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_process_output_response_no_text_content(self):

View File

@ -381,10 +381,17 @@ class TestContentFilterGuardrail:
assert result is not None
assert result[1] == "aws_access_key"
@pytest.mark.skip(
reason="Masking in streaming responses is no longer supported after unified_guardrail.py changes. Only blocking/rejecting is supported for responses."
)
@pytest.mark.asyncio
async def test_streaming_hook_mask(self):
"""
Test streaming hook with MASK action
Note: After changes to unified_guardrail.py, masking responses to users
is no longer supported. This test is skipped as the feature is deprecated.
Only BLOCK actions (test_streaming_hook_block) are supported for streaming responses.
"""
from unittest.mock import AsyncMock
@ -431,7 +438,7 @@ class TestContentFilterGuardrail:
user_api_key_dict = MagicMock()
request_data = {}
# Process streaming response
# Process streaming response - no masking expected
result_chunks = []
async for chunk in guardrail.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
@ -440,12 +447,8 @@ class TestContentFilterGuardrail:
):
result_chunks.append(chunk)
# Chunks should pass through unchanged since masking is no longer supported
assert len(result_chunks) == 2
# First chunk should have email masked
assert "[EMAIL_REDACTED]" in result_chunks[0].choices[0].delta.content
assert "test@example.com" not in result_chunks[0].choices[0].delta.content
# Second chunk should be unchanged
assert result_chunks[1].choices[0].delta.content == " for more info"
@pytest.mark.asyncio
async def test_streaming_hook_block(self):