fix(unified_guardrail.py): correctly map a v1/messages call to the anthropic unified guardrail (#17424)
* fix(unified_guardrail.py): correctly map a v1/messages call to the anthropic unified guardrail * fix: add more rigorous call type checks * fix(anthropic_endpoints/endpoints.py): initialize logging object at the beginning of endpoint ensures call id + trace id are emitted to guardrail api * feat(anthropic/chat/guardrail_translation): support streaming guardrails sample on every 5 chunks * fix(openai/chat/guardrail_translation): support openai streaming guardrails * fix: initial commit fixing output guardrails for responses api * feat(openai/responses/guardrail_translation): handler.py - fix output checks on responses api * fix(openai/responses/guardrail_translation/handler.py): ensure responses api guardrails work on streaming * test: update tests * test: update tests * test: update tests * fix(bedrock_guardrails.py): fix post call streaming iterator logic * fix: fix return * fix(bedrock_guardrails.py): fix
This commit is contained in:
parent
a711b63b06
commit
be0530a6b3
@ -429,6 +429,8 @@ class LitellmBasicGuardrailRequest(BaseModel):
|
||||
request_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
additional_provider_specific_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
input_type: Literal["request", "response"]
|
||||
litellm_call_id: Optional[str] = None
|
||||
litellm_trace_id: Optional[str] = None
|
||||
|
||||
|
||||
class LitellmBasicGuardrailResponse(BaseModel):
|
||||
|
||||
@ -95,6 +95,7 @@ from litellm.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
ProviderConfigManager,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
@ -654,7 +655,9 @@ def _infer_call_type(
|
||||
if completion_response is None:
|
||||
return None
|
||||
|
||||
if isinstance(completion_response, ModelResponse):
|
||||
if isinstance(completion_response, ModelResponse) or isinstance(
|
||||
completion_response, ModelResponseStream
|
||||
):
|
||||
return "completion"
|
||||
elif isinstance(completion_response, EmbeddingResponse):
|
||||
return "embedding"
|
||||
@ -1057,29 +1060,33 @@ def completion_cost( # noqa: PLR0915
|
||||
number_of_queries = len(query)
|
||||
elif query is not None:
|
||||
number_of_queries = 1
|
||||
|
||||
|
||||
search_model = model or ""
|
||||
if custom_llm_provider and "/" not in search_model:
|
||||
# If model is like "tavily-search", construct "tavily/search" for cost lookup
|
||||
search_model = f"{custom_llm_provider}/search"
|
||||
|
||||
prompt_cost, completion_cost_result = search_provider_cost_per_query(
|
||||
model=search_model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
number_of_queries=number_of_queries,
|
||||
optional_params=optional_params,
|
||||
|
||||
prompt_cost, completion_cost_result = (
|
||||
search_provider_cost_per_query(
|
||||
model=search_model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
number_of_queries=number_of_queries,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Return the total cost (prompt_cost + completion_cost, but for search it's just prompt_cost)
|
||||
_final_cost = prompt_cost + completion_cost_result
|
||||
|
||||
|
||||
# Apply discount
|
||||
original_cost = _final_cost
|
||||
_final_cost, discount_percent, discount_amount = _apply_cost_discount(
|
||||
base_cost=_final_cost,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
_final_cost, discount_percent, discount_amount = (
|
||||
_apply_cost_discount(
|
||||
base_cost=_final_cost,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Store cost breakdown in logging object if available
|
||||
_store_cost_breakdown_in_logging_obj(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
@ -1091,7 +1098,7 @@ def completion_cost( # noqa: PLR0915
|
||||
discount_percent=discount_percent,
|
||||
discount_amount=discount_amount,
|
||||
)
|
||||
|
||||
|
||||
return _final_cost
|
||||
elif call_type == CallTypes.arealtime.value and isinstance(
|
||||
completion_response, LiteLLMRealtimeStreamLoggingObject
|
||||
|
||||
@ -12,6 +12,7 @@ Pattern Overview:
|
||||
4. Apply guardrail responses back to the original structure
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
@ -246,6 +247,115 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
|
||||
return response
|
||||
|
||||
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List[Any]:
    """
    Apply the output guardrail to the text streamed so far.

    The text accumulated from ``responses_so_far`` is sent to the guardrail
    as a single string. The guardrail's return value is intentionally
    discarded: the call exists so the guardrail can reject the response
    (presumably by raising - confirm against the guardrail implementation).
    The chunks themselves are returned unmodified.

    Args:
        responses_so_far: Streaming chunks received so far (raw SSE bytes or
            parsed dicts, see ``get_streaming_string_so_far``).
        guardrail_to_apply: The guardrail instance to apply.
        litellm_logging_obj: Optional logging object forwarded to the guardrail.
        user_api_key_dict: User API key metadata to pass to the guardrail.

    Returns:
        The same ``responses_so_far`` list that was passed in.
    """
    string_so_far = self.get_streaming_string_so_far(responses_so_far)
    if not string_so_far:
        # No text has been streamed yet (e.g. only non-text events so far),
        # so there is nothing for the guardrail to inspect.
        return responses_so_far

    request_data: dict = {}
    # Forward user API key metadata the same way the chat-completions handler
    # does. The helper's definition is outside this view, so it is looked up
    # defensively instead of being assumed to exist on this class.
    to_metadata = getattr(self, "transform_user_api_key_dict_to_metadata", None)
    if callable(to_metadata):
        user_metadata = to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

    # Result is ignored on purpose: this call only gives the guardrail the
    # chance to reject the response if it is invalid.
    await guardrail_to_apply.apply_guardrail(
        texts=[string_so_far],
        request_data=request_data,
        input_type="response",
        logging_obj=litellm_logging_obj,
        images=None,
    )
    return responses_so_far
|
||||
|
||||
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
    """
    Parse streaming responses and extract accumulated text content.

    Handles two formats:
    1. Raw bytes in SSE (Server-Sent Events) format from Anthropic API
    2. Parsed dict objects (for backwards compatibility)

    Chunks of any other type, and dict chunks whose ``delta`` is not a
    ``text_delta`` carrying a string, contribute nothing.

    SSE format example:
    b'event: content_block_delta\\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" curious"}}\\n\\n'

    Dict format example:
    {
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "text_delta",
            "text": " curious"
        }
    }

    Args:
        responses_so_far: Streaming chunks received so far.

    Returns:
        The concatenated text of all text deltas, in chunk order.
    """
    text_parts: List[str] = []
    for response in responses_so_far:
        if isinstance(response, bytes):
            # Raw SSE payload - may contain several events.
            text_parts.append(self._extract_text_from_sse(response))
        elif isinstance(response, dict):
            # Already-parsed event.
            delta = response.get("delta")
            if not isinstance(delta, dict) or delta.get("type") != "text_delta":
                continue
            text = delta.get("text")
            # Guard against null / non-string text so one odd chunk cannot
            # break the accumulation.
            if isinstance(text, str):
                text_parts.append(text)
    return "".join(text_parts)
|
||||
|
||||
def _extract_text_from_sse(self, sse_bytes: bytes) -> str:
|
||||
"""
|
||||
Extract text content from Server-Sent Events (SSE) format.
|
||||
|
||||
Args:
|
||||
sse_bytes: Raw bytes in SSE format
|
||||
|
||||
Returns:
|
||||
Accumulated text from all content_block_delta events
|
||||
"""
|
||||
text = ""
|
||||
try:
|
||||
# Decode bytes to string
|
||||
sse_string = sse_bytes.decode("utf-8")
|
||||
|
||||
# Split by double newline to get individual events
|
||||
events = sse_string.split("\n\n")
|
||||
|
||||
for event in events:
|
||||
if not event.strip():
|
||||
continue
|
||||
|
||||
# Parse event lines
|
||||
lines = event.strip().split("\n")
|
||||
event_type = None
|
||||
data_line = None
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_line = line[5:].strip()
|
||||
|
||||
# Only process content_block_delta events
|
||||
if event_type == "content_block_delta" and data_line:
|
||||
try:
|
||||
data = json.loads(data_line)
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text += delta.get("text", "")
|
||||
except json.JSONDecodeError:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Failed to parse JSON from SSE data: {data_line}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error extracting text from SSE: {e}")
|
||||
|
||||
return text
|
||||
|
||||
def _has_text_content(self, response: "AnthropicMessagesResponse") -> bool:
|
||||
"""
|
||||
Check if response has any text content to process.
|
||||
|
||||
@ -436,9 +436,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||
|
||||
else:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client(
|
||||
params={"timeout": timeout}
|
||||
)
|
||||
client = _get_httpx_client(params={"timeout": timeout})
|
||||
else:
|
||||
client = client
|
||||
|
||||
@ -528,9 +526,7 @@ class ModelResponseIterator:
|
||||
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
|
||||
)
|
||||
|
||||
def _content_block_delta_helper(
|
||||
self, chunk: dict
|
||||
) -> Tuple[
|
||||
def _content_block_delta_helper(self, chunk: dict) -> Tuple[
|
||||
str,
|
||||
Optional[ChatCompletionToolCallChunk],
|
||||
List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]],
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
@ -87,7 +87,7 @@ class BaseTranslation(ABC):
|
||||
|
||||
async def process_output_streaming_response(
|
||||
self,
|
||||
response: Any,
|
||||
responses_so_far: List[Any],
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
|
||||
@ -97,4 +97,4 @@ class BaseTranslation(ABC):
|
||||
|
||||
Optional to override in subclasses.
|
||||
"""
|
||||
return response
|
||||
return responses_so_far
|
||||
|
||||
@ -243,48 +243,122 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
|
||||
async def process_output_streaming_response(
    self,
    responses_so_far: List["ModelResponseStream"],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List["ModelResponseStream"]:
    """
    Process output streaming responses by applying guardrails to text content.

    The text of every chunk received so far is concatenated per
    (choice index, content index), checked in one guardrail call, and the
    guardrailed text is written back onto the chunks in place.

    Args:
        responses_so_far: List of LiteLLM ModelResponseStream objects
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object
        user_api_key_dict: User API key metadata to pass to guardrails

    Returns:
        Modified list of responses with guardrail applied to content

    Response Format Support:
        - String content: choice.delta.content / choice.message.content = "text here"
        - List content: content = [{"type": "text", "text": "text here"}, ...]
    """
    # Step 0: Check if any response has text content to process
    if not any(self._has_text_content(chunk) for chunk in responses_so_far):
        verbose_proxy_logger.warning(
            "OpenAI Chat Completions: No text content in streaming responses, skipping guardrail"
        )
        return responses_so_far

    # Step 1: Combine all streaming chunks into complete text per choice.
    # Key: (choice_idx, content_idx); content_idx is None for string content.
    text_parts: Dict[Tuple[int, Optional[int]], List[str]] = {}
    for chunk in responses_so_far:
        for choice_idx, choice in enumerate(chunk.choices):
            if isinstance(choice, litellm.StreamingChoices):
                content = choice.delta.content
            elif isinstance(choice, litellm.Choices):
                content = choice.message.content
            else:
                continue

            if isinstance(content, str):
                # String content - accumulate for this choice
                text_parts.setdefault((choice_idx, None), []).append(content)
            elif isinstance(content, list):
                # List content - accumulate for each content item
                for content_idx, content_item in enumerate(content):
                    text_str = (
                        content_item.get("text")
                        if isinstance(content_item, dict)
                        else None
                    )
                    if text_str and isinstance(text_str, str):
                        text_parts.setdefault(
                            (choice_idx, content_idx), []
                        ).append(text_str)

    # Step 2: Create parallel lists for guardrail processing.
    # task_mappings[i] is the (choice_index, content_index) of texts_to_check[i].
    task_mappings: List[Tuple[int, Optional[int]]] = list(text_parts)
    texts_to_check: List[str] = ["".join(text_parts[key]) for key in task_mappings]

    # Step 3: Apply guardrail to all combined texts in batch
    if texts_to_check:
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"responses": responses_so_far}

        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(
            user_api_key_dict
        )
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        # No images are collected from streaming chunks, so none are sent.
        guardrailed_texts, _ = await guardrail_to_apply.apply_guardrail(
            texts=texts_to_check,
            request_data=request_data,
            input_type="response",
            images=None,
            logging_obj=litellm_logging_obj,
        )

        # Step 4: Apply guardrailed text back to all streaming chunks
        await self._apply_guardrail_responses_to_output_streaming(
            responses=responses_so_far,
            guardrailed_texts=guardrailed_texts,
            task_mappings=task_mappings,
        )

    verbose_proxy_logger.debug(
        "OpenAI Chat Completions: Processed output streaming responses: %s",
        responses_so_far,
    )

    return responses_so_far
|
||||
|
||||
def _has_text_content(
|
||||
self, response: Union["ModelResponse", "ModelResponseStream"]
|
||||
) -> bool:
|
||||
@ -304,10 +378,8 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
return True
|
||||
elif isinstance(response, ModelResponseStream):
|
||||
for choice in response.choices:
|
||||
if isinstance(choice, litellm.Choices):
|
||||
if choice.message.content and isinstance(
|
||||
choice.message.content, str
|
||||
):
|
||||
if isinstance(choice, litellm.StreamingChoices):
|
||||
if choice.delta.content and isinstance(choice.delta.content, str):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -394,3 +466,79 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
][
|
||||
"text"
|
||||
] = guardrail_response
|
||||
|
||||
async def _apply_guardrail_responses_to_output_streaming(
    self,
    responses: List["ModelResponseStream"],
    guardrailed_texts: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Apply guardrail responses back to output streaming responses.

    For streaming responses, the guardrailed text (which is the combined text from all chunks)
    is placed in the first chunk that has content for a given
    (choice_idx, content_idx), and that content is cleared (set to "") in
    subsequent chunks. Chunks are modified in place.

    Args:
        responses: List of ModelResponseStream objects to modify
        guardrailed_texts: List of guardrailed text responses (combined from all chunks)
        task_mappings: List of tuples (choice_idx, content_idx), parallel to
            ``guardrailed_texts``; content_idx is None for string content

    Override this method to customize how responses are applied to streaming responses.
    """
    # Map (choice_idx, content_idx) -> guardrailed text. zip() stops at the
    # shorter list, so a guardrail returning extra texts cannot cause an
    # out-of-range lookup into task_mappings.
    guardrail_map: Dict[Tuple[int, Optional[int]], str] = {
        (mapping[0], mapping[1]): guardrailed_text
        for mapping, guardrailed_text in zip(task_mappings, guardrailed_texts)
    }

    # Keys whose guardrailed text has already been written to an earlier chunk.
    already_set: set = set()

    def _next_text(key: Tuple[int, Optional[int]]) -> str:
        """Return the full guardrailed text on first use of key, "" afterwards."""
        if key in already_set:
            return ""
        already_set.add(key)
        return guardrail_map[key]

    for response in responses:
        for choice_idx, choice in enumerate(response.choices):
            is_streaming_choice = isinstance(choice, litellm.StreamingChoices)
            if is_streaming_choice:
                content = choice.delta.content
            elif isinstance(choice, litellm.Choices):
                content = choice.message.content
            else:
                continue

            if isinstance(content, str):
                # String content
                key = (choice_idx, None)
                if key not in guardrail_map:
                    continue
                new_text = _next_text(key)
                if is_streaming_choice:
                    choice.delta.content = new_text
                else:
                    choice.message.content = new_text
            elif isinstance(content, list):
                # List content - handle each content item
                for content_idx, content_item in enumerate(content):
                    if not isinstance(content_item, dict) or "text" not in content_item:
                        continue
                    key = (choice_idx, content_idx)
                    if key in guardrail_map:
                        content_item["text"] = _next_text(key)
|
||||
|
||||
@ -30,6 +30,8 @@ Output: response.output is List[GenericResponseOutputItem] where each has:
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast
|
||||
|
||||
from openai import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.responses.main import GenericResponseOutputItem, OutputText
|
||||
@ -275,6 +277,32 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
|
||||
return response
|
||||
|
||||
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List[Any]:
    """
    Process output streaming response by applying guardrails to text content.

    The text accumulated from ``responses_so_far`` is sent to the guardrail
    as a single string. The guardrail's return value is not applied back to
    the chunks; they are returned unmodified.

    Args:
        responses_so_far: Streaming chunks received so far.
        guardrail_to_apply: The guardrail instance to apply.
        litellm_logging_obj: Optional logging object forwarded to the guardrail.
        user_api_key_dict: User API key metadata to pass to the guardrail.

    Returns:
        The same ``responses_so_far`` list that was passed in.
    """
    string_so_far = self.get_streaming_string_so_far(responses_so_far)
    if not string_so_far:
        # No text has been streamed yet, so there is nothing to check.
        return responses_so_far

    request_data: dict = {}
    # Forward user API key metadata the same way the chat-completions handler
    # does. The helper's definition is outside this view, so it is looked up
    # defensively instead of being assumed to exist on this class.
    to_metadata = getattr(self, "transform_user_api_key_dict_to_metadata", None)
    if callable(to_metadata):
        user_metadata = to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

    # The guardrailed text is deliberately not written back to the chunks.
    await guardrail_to_apply.apply_guardrail(
        texts=[string_so_far],
        request_data=request_data,
        input_type="response",
        logging_obj=litellm_logging_obj,
        images=None,
    )
    return responses_so_far
|
||||
|
||||
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
    """
    Get the string so far from the responses so far.

    Each chunk contributes its ``text`` value, read via ``.get("text")`` when
    the chunk offers ``get`` and via attribute access otherwise. Chunks
    without a string ``text`` (missing, ``None`` or another type) are skipped
    so that a single odd chunk cannot break the concatenation.

    Args:
        responses_so_far: Streaming chunks received so far.

    Returns:
        The concatenated ``text`` values, in chunk order.
    """
    parts: List[str] = []
    for response in responses_so_far:
        getter = getattr(response, "get", None)
        if callable(getter):
            text = getter("text", "")
        else:
            text = getattr(response, "text", None)
        if isinstance(text, str):
            parts.append(text)
    return "".join(parts)
|
||||
|
||||
def _has_text_content(self, response: "ResponsesAPIResponse") -> bool:
|
||||
"""
|
||||
Check if response has any text content to process.
|
||||
@ -285,6 +313,17 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
return False
|
||||
|
||||
for output_item in response.output:
|
||||
if isinstance(output_item, BaseModel):
|
||||
try:
|
||||
generic_response_output_item = (
|
||||
GenericResponseOutputItem.model_validate(
|
||||
output_item.model_dump()
|
||||
)
|
||||
)
|
||||
if generic_response_output_item.content:
|
||||
output_item = generic_response_output_item
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(output_item, (GenericResponseOutputItem, dict)):
|
||||
content = (
|
||||
output_item.content
|
||||
@ -296,9 +335,11 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
# Check if it's an OutputText with text
|
||||
if isinstance(content_item, OutputText):
|
||||
if content_item.text:
|
||||
|
||||
return True
|
||||
elif isinstance(content_item, dict):
|
||||
if content_item.get("text"):
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -316,8 +357,16 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
Override this method to customize text/image extraction logic.
|
||||
"""
|
||||
# Handle both GenericResponseOutputItem and dict
|
||||
if isinstance(output_item, GenericResponseOutputItem):
|
||||
content = output_item.content
|
||||
content: Optional[Union[List[OutputText], List[dict]]] = None
|
||||
if isinstance(output_item, BaseModel):
|
||||
try:
|
||||
generic_response_output_item = GenericResponseOutputItem.model_validate(
|
||||
output_item.model_dump()
|
||||
)
|
||||
if generic_response_output_item.content:
|
||||
content = generic_response_output_item.content
|
||||
except Exception:
|
||||
return
|
||||
elif isinstance(output_item, dict):
|
||||
content = output_item.get("content", [])
|
||||
else:
|
||||
|
||||
@ -3,23 +3,19 @@ model_list:
|
||||
litellm_params:
|
||||
model: openai/gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: claude-sonnet-4-5-20250929
|
||||
litellm_params:
|
||||
model: anthropic/claude-sonnet-4-5-20250929
|
||||
- model_name: gpt-4.1-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-4.1-mini
|
||||
|
||||
|
||||
guardrails:
|
||||
# - guardrail_name: model-armor-shield
|
||||
# litellm_params:
|
||||
# guardrail: model_armor
|
||||
# mode: "post_call" # Run on both input and output
|
||||
# template_id: "test-prompt-template" # Required: Your Model Armor template ID
|
||||
# project_id: "test-vector-store-db" # Your GCP project ID
|
||||
# location: "us" # GCP location (default: us-central1)
|
||||
# mask_request_content: true # Enable request content masking
|
||||
# mask_response_content: true # Enable response content masking
|
||||
# fail_on_error: true # Fail request if Model Armor errors (default: true)
|
||||
# default_on: true # Run by default for all requests
|
||||
- guardrail_name: generic-guardrail
|
||||
litellm_params:
|
||||
guardrail: generic_guardrail_api
|
||||
mode: [pre_call, post_call]
|
||||
mode: ["pre_call", "post_call", "during_call"]
|
||||
headers:
|
||||
Authorization: Bearer mock-bedrock-token-12345
|
||||
api_base: http://localhost:8080
|
||||
|
||||
@ -2,20 +2,14 @@
|
||||
Unified /v1/messages endpoint - (Anthropic Spec)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
create_streaming_response,
|
||||
)
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
|
||||
router = APIRouter()
|
||||
@ -49,169 +43,28 @@ async def anthropic_response( # noqa: PLR0915
|
||||
version,
|
||||
)
|
||||
|
||||
request_data = await _read_request_body(request=request)
|
||||
data: dict = {**request_data}
|
||||
data = await _read_request_body(request=request)
|
||||
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
|
||||
try:
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data.get("model", None) # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data, # type: ignore
|
||||
result = await base_llm_response_processor.base_process_llm_request(
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
fastapi_response=fastapi_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
route_type="anthropic_messages",
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
llm_router=llm_router,
|
||||
general_settings=general_settings,
|
||||
proxy_config=proxy_config,
|
||||
select_data_generator=None,
|
||||
model=None,
|
||||
user_model=user_model,
|
||||
user_temperature=user_temperature,
|
||||
user_request_timeout=user_request_timeout,
|
||||
user_max_tokens=user_max_tokens,
|
||||
user_api_base=user_api_base,
|
||||
version=version,
|
||||
)
|
||||
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if data["model"] in litellm.model_alias_map:
|
||||
data["model"] = litellm.model_alias_map[data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type=CallTypes.anthropic_messages.value
|
||||
)
|
||||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
proxy_logging_obj.during_call_hook(
|
||||
data=data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
|
||||
route_type="anthropic_messages" # type: ignore
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
|
||||
# skip router if user passed their key
|
||||
if (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
llm_coro = llm_router.aanthropic_messages(**data)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
llm_coro = llm_router.aanthropic_messages(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_coro = llm_router.aanthropic_messages(**data, specific_deployment=True)
|
||||
elif (
|
||||
llm_router is not None and llm_router.has_model_id(data["model"])
|
||||
): # model in router model list
|
||||
llm_coro = llm_router.aanthropic_messages(**data)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.pattern_router.patterns) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_coro = llm_router.aanthropic_messages(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
llm_coro = litellm.anthropic_messages(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "completion: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
|
||||
tasks.append(llm_coro)
|
||||
|
||||
# wait for call to end
|
||||
llm_responses = asyncio.gather(
|
||||
*tasks
|
||||
) # run the moderation check in parallel to the actual llm api call
|
||||
|
||||
responses = await llm_responses
|
||||
|
||||
response = responses[1]
|
||||
|
||||
# Extract model_id from request metadata (set by router during routing)
|
||||
litellm_metadata = data.get("litellm_metadata", {}) or {}
|
||||
model_info = litellm_metadata.get("model_info", {}) or {}
|
||||
model_id = model_info.get("id", "") or ""
|
||||
|
||||
# Get other metadata from hidden_params
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
request_data=data,
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
selected_data_generator = (
|
||||
ProxyBaseLLMRequestProcessing.async_sse_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
|
||||
return await create_streaming_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=dict(fastapi_response.headers),
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response))
|
||||
return response
|
||||
return result
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
@ -280,35 +133,30 @@ async def count_tokens(
|
||||
Returns: {"input_tokens": <number>}
|
||||
"""
|
||||
from litellm.proxy.proxy_server import token_counter as internal_token_counter
|
||||
|
||||
|
||||
try:
|
||||
request_data = await _read_request_body(request=request)
|
||||
data: dict = {**request_data}
|
||||
|
||||
|
||||
# Extract required fields
|
||||
model_name = data.get("model")
|
||||
messages = data.get("messages", [])
|
||||
|
||||
|
||||
if not model_name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "model parameter is required"}
|
||||
status_code=400, detail={"error": "model parameter is required"}
|
||||
)
|
||||
|
||||
|
||||
if not messages:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "messages parameter is required"}
|
||||
status_code=400, detail={"error": "messages parameter is required"}
|
||||
)
|
||||
|
||||
|
||||
# Create TokenCountRequest for the internal endpoint
|
||||
from litellm.proxy._types import TokenCountRequest
|
||||
|
||||
token_request = TokenCountRequest(
|
||||
model=model_name,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
|
||||
token_request = TokenCountRequest(model=model_name, messages=messages)
|
||||
|
||||
# Call the internal token counter function with direct request flag set to False
|
||||
token_response = await internal_token_counter(
|
||||
request=token_request,
|
||||
@ -319,17 +167,18 @@ async def count_tokens(
|
||||
_token_response_dict = token_response.model_dump()
|
||||
elif isinstance(token_response, dict):
|
||||
_token_response_dict = token_response
|
||||
|
||||
|
||||
# Convert the internal response to Anthropic API format
|
||||
return {"input_tokens": _token_response_dict.get("total_tokens", 0)}
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(str(e))
|
||||
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Internal server error: {str(e)}"}
|
||||
status_code=500, detail={"error": f"Internal server error: {str(e)}"}
|
||||
)
|
||||
|
||||
@ -337,6 +337,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||
"alist_skills",
|
||||
"aget_skill",
|
||||
"adelete_skill",
|
||||
"anthropic_messages",
|
||||
],
|
||||
version: Optional[str] = None,
|
||||
user_model: Optional[str] = None,
|
||||
@ -461,11 +462,12 @@ class ProxyBaseLLMRequestProcessing:
|
||||
"alist_skills",
|
||||
"aget_skill",
|
||||
"adelete_skill",
|
||||
"anthropic_messages",
|
||||
],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
general_settings: dict,
|
||||
proxy_config: ProxyConfig,
|
||||
select_data_generator: Callable,
|
||||
select_data_generator: Optional[Callable] = None,
|
||||
llm_router: Optional[Router] = None,
|
||||
model: Optional[str] = None,
|
||||
user_model: Optional[str] = None,
|
||||
@ -605,7 +607,21 @@ class ProxyBaseLLMRequestProcessing:
|
||||
status_code=response.status_code,
|
||||
headers=custom_headers,
|
||||
)
|
||||
else:
|
||||
elif route_type == "anthropic_messages":
|
||||
selected_data_generator = (
|
||||
ProxyBaseLLMRequestProcessing.async_sse_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=self.data,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
return await create_streaming_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
elif select_data_generator:
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
@ -804,21 +820,30 @@ class ProxyBaseLLMRequestProcessing:
|
||||
# This matches the original behavior before the refactor in commit 511d435f6f
|
||||
error_body = await e.response.aread()
|
||||
error_text = error_body.decode("utf-8")
|
||||
|
||||
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail={"error": error_text},
|
||||
)
|
||||
error_msg = f"{str(e)}"
|
||||
error_msg = f"{str(e)}"
|
||||
# Check for AttributeError in various places:
|
||||
# 1. Direct AttributeError (already handled above)
|
||||
# 2. In underlying exception (__cause__, __context__, original_exception)
|
||||
has_attribute_error = (
|
||||
(isinstance(e, Exception) and isinstance(getattr(e, "__cause__", None), AttributeError))
|
||||
or (isinstance(e, Exception) and isinstance(getattr(e, "__context__", None), AttributeError))
|
||||
or (isinstance(e, Exception) and isinstance(getattr(e, "original_exception", None), AttributeError))
|
||||
(
|
||||
isinstance(e, Exception)
|
||||
and isinstance(getattr(e, "__cause__", None), AttributeError)
|
||||
)
|
||||
or (
|
||||
isinstance(e, Exception)
|
||||
and isinstance(getattr(e, "__context__", None), AttributeError)
|
||||
)
|
||||
or (
|
||||
isinstance(e, Exception)
|
||||
and isinstance(getattr(e, "original_exception", None), AttributeError)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if has_attribute_error:
|
||||
raise ProxyException(
|
||||
message=f"Invalid request format: {error_msg}",
|
||||
@ -1110,7 +1135,9 @@ class ProxyBaseLLMRequestProcessing:
|
||||
return obj
|
||||
return None
|
||||
|
||||
def maybe_get_model_id(self, _logging_obj: Optional[LiteLLMLoggingObj]) -> Optional[str]:
|
||||
def maybe_get_model_id(
|
||||
self, _logging_obj: Optional[LiteLLMLoggingObj]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get model_id from logging object or request metadata.
|
||||
|
||||
@ -1120,10 +1147,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||
model_id = None
|
||||
if _logging_obj:
|
||||
# 1. Try getting from litellm_params (updated during call)
|
||||
if (
|
||||
hasattr(_logging_obj, "litellm_params")
|
||||
and _logging_obj.litellm_params
|
||||
):
|
||||
if hasattr(_logging_obj, "litellm_params") and _logging_obj.litellm_params:
|
||||
# First check direct model_info path (set by router.py with selected deployment)
|
||||
model_info = _logging_obj.litellm_params.get("model_info") or {}
|
||||
model_id = model_info.get("id", None)
|
||||
|
||||
@ -729,9 +729,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
#########################################################
|
||||
########## 1. Make the Bedrock API request ##########
|
||||
#########################################################
|
||||
bedrock_guardrail_response: Optional[
|
||||
Union[BedrockGuardrailResponse, str]
|
||||
] = None
|
||||
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
|
||||
None
|
||||
)
|
||||
try:
|
||||
bedrock_guardrail_response = await self.make_bedrock_api_request(
|
||||
source="INPUT", messages=filtered_messages, request_data=data
|
||||
@ -801,9 +801,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
#########################################################
|
||||
########## 1. Make the Bedrock API request ##########
|
||||
#########################################################
|
||||
bedrock_guardrail_response: Optional[
|
||||
Union[BedrockGuardrailResponse, str]
|
||||
] = None
|
||||
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
|
||||
None
|
||||
)
|
||||
try:
|
||||
bedrock_guardrail_response = await self.make_bedrock_api_request(
|
||||
source="INPUT", messages=filtered_messages, request_data=data
|
||||
@ -1276,10 +1276,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
|
||||
masked_texts = []
|
||||
|
||||
for text in texts:
|
||||
mock_messages: List[AllMessageValues] = [
|
||||
ChatCompletionUserMessage(role="user", content=text)
|
||||
]
|
||||
mock_messages: List[AllMessageValues] = [
|
||||
ChatCompletionUserMessage(role="user", content=text) for text in texts
|
||||
]
|
||||
request_messages = mock_messages
|
||||
filter_result = self._prepare_guardrail_messages_for_role(
|
||||
messages=request_messages
|
||||
@ -1292,19 +1291,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
request_data=request_data,
|
||||
)
|
||||
|
||||
bedrock_response = await self.make_bedrock_api_request(
|
||||
source="INPUT",
|
||||
messages=mock_messages,
|
||||
request_data=request_data,
|
||||
)
|
||||
|
||||
if bedrock_response.get("action") == "BLOCKED":
|
||||
raise Exception(
|
||||
f"Content blocked by Bedrock guardrail: {bedrock_response.get('reason', 'Unknown reason')}"
|
||||
)
|
||||
|
||||
# Apply any masking that was applied by the guardrail
|
||||
masked_text = text
|
||||
output_list = bedrock_response.get("output")
|
||||
if output_list:
|
||||
# If the guardrail returned modified content, use that
|
||||
@ -1312,7 +1304,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
text_content = output_item.get("text")
|
||||
if text_content:
|
||||
masked_text = str(text_content)
|
||||
break
|
||||
masked_texts.append(masked_text)
|
||||
else:
|
||||
outputs_list = bedrock_response.get("outputs")
|
||||
if outputs_list:
|
||||
@ -1321,15 +1313,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
text_content = output_item.get("text")
|
||||
if text_content:
|
||||
masked_text = str(text_content)
|
||||
break
|
||||
|
||||
masked_texts.append(masked_text)
|
||||
masked_texts.append(masked_text)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Bedrock Guardrail: Successfully applied guardrail"
|
||||
)
|
||||
|
||||
return masked_texts, images
|
||||
return masked_texts or texts, images
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
|
||||
@ -6,7 +6,7 @@ Unified Guardrail, leveraging LiteLLM's /applyGuardrail endpoint
|
||||
3. Implements a way to call /applyGuardrail endpoint for `/chat/completions` + `/v1/messages` requests on async_post_call_streaming_iterator_hook
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncGenerator, Union
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
@ -126,16 +126,21 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
)
|
||||
is not True
|
||||
):
|
||||
|
||||
return
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"async_post_call_success_hook response: %s", response
|
||||
)
|
||||
|
||||
call_type = _infer_call_type(call_type=None, completion_response=response)
|
||||
call_type: Optional[CallTypesLiteral] = None
|
||||
if user_api_key_dict.request_route is not None:
|
||||
call_types = get_call_types_for_route(user_api_key_dict.request_route)
|
||||
if call_types is not None:
|
||||
call_type = call_types[0]
|
||||
if call_type is None:
|
||||
call_type = _infer_call_type(call_type=None, completion_response=response)
|
||||
|
||||
if call_type is None:
|
||||
return response
|
||||
|
||||
if endpoint_guardrail_translation_mappings is None:
|
||||
@ -183,9 +188,6 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
"""
|
||||
|
||||
global endpoint_guardrail_translation_mappings
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
|
||||
guardrail_to_apply: CustomGuardrail = request_data.pop(
|
||||
"guardrail_to_apply", None
|
||||
@ -234,9 +236,11 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
# Infer call type from first chunk
|
||||
call_type = None
|
||||
chunk_counter = 0
|
||||
responses_so_far: List[Any] = []
|
||||
|
||||
async for item in response:
|
||||
chunk_counter += 1
|
||||
responses_so_far.append(item)
|
||||
|
||||
# Infer call type from first chunk if not already done
|
||||
if call_type is None and user_api_key_dict.request_route is not None:
|
||||
@ -244,19 +248,22 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
if call_types is not None:
|
||||
call_type = call_types[0]
|
||||
|
||||
# If call type not supported, just pass through all chunks
|
||||
if (
|
||||
call_type is None
|
||||
or CallTypes(call_type)
|
||||
not in endpoint_guardrail_translation_mappings
|
||||
):
|
||||
yield item
|
||||
async for remaining_item in response:
|
||||
yield remaining_item
|
||||
return
|
||||
if call_type is None:
|
||||
call_type = _infer_call_type(call_type=None, completion_response=item)
|
||||
|
||||
# If call type not supported, just pass through all chunks
|
||||
if (
|
||||
call_type is None
|
||||
or CallTypes(call_type) not in endpoint_guardrail_translation_mappings
|
||||
):
|
||||
yield item
|
||||
async for remaining_item in response:
|
||||
yield remaining_item
|
||||
return
|
||||
|
||||
# Process chunk based on sampling rate
|
||||
if chunk_counter % sampling_rate == 0:
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Processing streaming chunk %s (sampling_rate=%s) with guardrail %s",
|
||||
chunk_counter,
|
||||
@ -268,22 +275,17 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
CallTypes(call_type)
|
||||
]()
|
||||
|
||||
processed_item = (
|
||||
processed_items = (
|
||||
await endpoint_translation.process_output_streaming_response(
|
||||
response=item,
|
||||
responses_so_far=responses_so_far,
|
||||
guardrail_to_apply=guardrail_to_apply,
|
||||
litellm_logging_obj=request_data.get("litellm_logging_obj"),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
)
|
||||
|
||||
# Add guardrail to applied guardrails header (only once, on first processed chunk)
|
||||
if chunk_counter == sampling_rate:
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=request_data,
|
||||
guardrail_name=guardrail_to_apply.guardrail_name,
|
||||
)
|
||||
last_item = processed_items[-1]
|
||||
|
||||
yield processed_item
|
||||
yield last_item
|
||||
else:
|
||||
yield item
|
||||
|
||||
@ -75,12 +75,13 @@ def add_shared_session_to_data(data: dict) -> None:
|
||||
"""
|
||||
Add shared aiohttp session for connection reuse (prevents cold starts).
|
||||
Silently continues without session reuse if import fails or session is unavailable.
|
||||
|
||||
|
||||
Args:
|
||||
data: Dictionary to add the shared session to
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.proxy_server import shared_aiohttp_session
|
||||
|
||||
if shared_aiohttp_session is not None and not shared_aiohttp_session.closed:
|
||||
data["shared_session"] = shared_aiohttp_session
|
||||
except Exception:
|
||||
@ -136,13 +137,14 @@ async def route_request(
|
||||
"aget_skill",
|
||||
"adelete_skill",
|
||||
"aingest",
|
||||
"anthropic_messages",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Common helper to route the request
|
||||
"""
|
||||
add_shared_session_to_data(data)
|
||||
|
||||
|
||||
team_id = get_team_id_from_data(data)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
|
||||
@ -177,7 +179,12 @@ async def route_request(
|
||||
return llm_router.abatch_completion(models=models, **data)
|
||||
elif llm_router is not None:
|
||||
# Skip model-based routing for container operations
|
||||
if route_type in ["acreate_container", "alist_containers", "aretrieve_container", "adelete_container"]:
|
||||
if route_type in [
|
||||
"acreate_container",
|
||||
"alist_containers",
|
||||
"aretrieve_container",
|
||||
"adelete_container",
|
||||
]:
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
if route_type in [
|
||||
"avideo_list",
|
||||
@ -196,7 +203,7 @@ async def route_request(
|
||||
] and (data.get("model") is None or data.get("model") == ""):
|
||||
# These endpoints don't need a model, use custom_llm_provider directly
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
|
||||
|
||||
team_model_name = (
|
||||
llm_router.map_team_model(data["model"], team_id)
|
||||
if team_id is not None
|
||||
@ -206,9 +213,8 @@ async def route_request(
|
||||
data["model"] = team_model_name
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
|
||||
elif (
|
||||
data["model"] in router_model_names
|
||||
or llm_router.has_model_id(data["model"])
|
||||
elif data["model"] in router_model_names or llm_router.has_model_id(
|
||||
data["model"]
|
||||
):
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
|
||||
|
||||
@ -1417,6 +1417,7 @@ class ProxyLogging:
|
||||
3. /image/generation
|
||||
4. /files
|
||||
"""
|
||||
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
guardrail_callbacks: List[CustomGuardrail] = []
|
||||
@ -1451,6 +1452,7 @@ class ProxyLogging:
|
||||
continue
|
||||
|
||||
guardrail_response: Optional[Any] = None
|
||||
|
||||
if "apply_guardrail" in type(callback).__dict__:
|
||||
data["guardrail_to_apply"] = callback
|
||||
guardrail_response = (
|
||||
@ -1562,6 +1564,7 @@ class ProxyLogging:
|
||||
"""
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
@ -1575,6 +1578,7 @@ class ProxyLogging:
|
||||
) or _callback.should_run_guardrail(
|
||||
data=request_data, event_type=GuardrailEventHooks.post_call
|
||||
):
|
||||
|
||||
if "apply_guardrail" in type(callback).__dict__:
|
||||
request_data["guardrail_to_apply"] = callback
|
||||
response = (
|
||||
|
||||
@ -153,11 +153,7 @@ from litellm.types.utils import (
|
||||
)
|
||||
from litellm.types.utils import ModelInfo
|
||||
from litellm.types.utils import ModelInfo as ModelMapInfo
|
||||
from litellm.types.utils import (
|
||||
ModelResponseStream,
|
||||
StandardLoggingPayload,
|
||||
Usage,
|
||||
)
|
||||
from litellm.types.utils import ModelResponseStream, StandardLoggingPayload, Usage
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
EmbeddingResponse,
|
||||
@ -779,6 +775,9 @@ class Router:
|
||||
self.aanthropic_messages = self.factory_function(
|
||||
litellm.anthropic_messages, call_type="anthropic_messages"
|
||||
)
|
||||
self.anthropic_messages = self.factory_function(
|
||||
litellm.anthropic_messages, call_type="anthropic_messages"
|
||||
)
|
||||
self.agenerate_content = self.factory_function(
|
||||
litellm.agenerate_content, call_type="agenerate_content"
|
||||
)
|
||||
@ -886,6 +885,7 @@ class Router:
|
||||
from litellm.vector_store_files.main import (
|
||||
update as vector_store_file_update_fn,
|
||||
)
|
||||
|
||||
self.avector_store_file_create = self.factory_function(
|
||||
avector_store_file_create_fn, call_type="avector_store_file_create"
|
||||
)
|
||||
@ -3865,6 +3865,7 @@ class Router:
|
||||
"retrieve_container",
|
||||
"delete_container",
|
||||
):
|
||||
|
||||
def sync_wrapper(
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
@ -5944,36 +5945,38 @@ class Router:
|
||||
"""
|
||||
Get API credentials and provider info from a model name in model_list.
|
||||
Useful for passthrough endpoints (files, batches, etc.) that need credentials.
|
||||
|
||||
|
||||
This method tries to find a deployment by model_id first, and if not found,
|
||||
it tries to find by model_group_name (model_name).
|
||||
|
||||
|
||||
Args:
|
||||
model_id: Model ID or model name from model_list (e.g., "gpt-4o-litellm")
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary containing api_key, api_base, custom_llm_provider, etc.
|
||||
Returns None if model not found.
|
||||
|
||||
|
||||
Example:
|
||||
credentials = router.get_deployment_credentials_with_provider("gpt-4o-litellm")
|
||||
# Returns: {"api_key": "sk-...", "custom_llm_provider": "openai", ...}
|
||||
"""
|
||||
# Try to get deployment by model_id first
|
||||
deployment = self.get_deployment(model_id=model_id)
|
||||
|
||||
|
||||
# If not found, try by model_group_name
|
||||
if deployment is None:
|
||||
deployment = self.get_deployment_by_model_group_name(model_group_name=model_id)
|
||||
|
||||
deployment = self.get_deployment_by_model_group_name(
|
||||
model_group_name=model_id
|
||||
)
|
||||
|
||||
if deployment is None:
|
||||
return None
|
||||
|
||||
|
||||
# Get basic credentials
|
||||
credentials = CredentialLiteLLMParams(
|
||||
**deployment.litellm_params.model_dump(exclude_none=True)
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
|
||||
# Add custom_llm_provider
|
||||
if deployment.litellm_params.custom_llm_provider:
|
||||
credentials["custom_llm_provider"] = (
|
||||
@ -5986,7 +5989,7 @@ class Router:
|
||||
)[0]
|
||||
else:
|
||||
credentials["custom_llm_provider"] = "openai" # default
|
||||
|
||||
|
||||
return credentials
|
||||
|
||||
@overload
|
||||
|
||||
@ -731,6 +731,7 @@ API_ROUTE_TO_CALL_TYPES = {
|
||||
CallTypes.llm_passthrough_route,
|
||||
CallTypes.allm_passthrough_route,
|
||||
],
|
||||
"/v1/messages": [CallTypes.anthropic_messages],
|
||||
}
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -7,7 +7,7 @@ with guardrail transformations.
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List, Literal, Optional, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@ -16,6 +16,8 @@ sys.path.insert(
|
||||
0, os.path.abspath("../../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms import get_guardrail_translation_mapping
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
@ -27,12 +29,27 @@ from litellm.types.utils import CallTypes
|
||||
|
||||
|
||||
class MockGuardrail(CustomGuardrail):
|
||||
"""Mock guardrail for testing that transforms text"""
|
||||
"""Mock guardrail for testing that transforms text for requests and blocks responses"""
|
||||
|
||||
async def apply_guardrail(
|
||||
self, texts: List[str], request_data: dict, input_type: str, **kwargs
|
||||
self,
|
||||
texts: List[str],
|
||||
request_data: dict,
|
||||
input_type: Literal["request", "response"],
|
||||
logging_obj: Optional[Any] = None,
|
||||
images: Optional[List[str]] = None,
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
"""Append [GUARDRAILED] to text"""
|
||||
"""
|
||||
For requests: Append [GUARDRAILED] to text
|
||||
For responses: Block by raising HTTPException (masking responses is no longer supported)
|
||||
"""
|
||||
if input_type == "response":
|
||||
# Responses should be blocked, not masked
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Response blocked by guardrail", "texts": texts},
|
||||
)
|
||||
# For requests, we can still mask/transform
|
||||
return ([f"{text} [GUARDRAILED]" for text in texts], None)
|
||||
|
||||
|
||||
@ -169,11 +186,15 @@ class TestOpenAIResponsesHandlerOutputProcessing:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_output_response_simple(self):
|
||||
"""Test processing simple output response"""
|
||||
"""Test processing simple output response - should block, not mask
|
||||
|
||||
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
|
||||
This test verifies that the guardrail properly blocks responses.
|
||||
"""
|
||||
handler = OpenAIResponsesHandler()
|
||||
guardrail = MockGuardrail(guardrail_name="test")
|
||||
|
||||
# Create a mock response
|
||||
# Create a mock response with dict format (works with current handler)
|
||||
response = ResponsesAPIResponse(
|
||||
id="resp_123",
|
||||
created_at=1234567890,
|
||||
@ -181,30 +202,36 @@ class TestOpenAIResponsesHandlerOutputProcessing:
|
||||
object="response",
|
||||
status="completed",
|
||||
output=[
|
||||
GenericResponseOutputItem(
|
||||
type="message",
|
||||
id="msg_123",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
content=[
|
||||
OutputText(
|
||||
type="output_text", text="Hello user", annotations=None
|
||||
),
|
||||
{
|
||||
"type": "message",
|
||||
"id": "msg_123",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "Hello user"},
|
||||
],
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = await handler.process_output_response(response, guardrail)
|
||||
# Response should be blocked, not masked
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handler.process_output_response(response, guardrail)
|
||||
|
||||
assert result.output[0].content[0].text == "Hello user [GUARDRAILED]"
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Response blocked by guardrail" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_output_response_multiple_items(self):
|
||||
"""Test processing output response with multiple output items"""
|
||||
"""Test processing output response with multiple output items - should block, not mask
|
||||
|
||||
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
|
||||
This test verifies that the guardrail properly blocks responses with multiple items.
|
||||
"""
|
||||
handler = OpenAIResponsesHandler()
|
||||
guardrail = MockGuardrail(guardrail_name="test")
|
||||
|
||||
# Use dict format (works with current handler)
|
||||
response = ResponsesAPIResponse(
|
||||
id="resp_123",
|
||||
created_at=1234567890,
|
||||
@ -212,46 +239,45 @@ class TestOpenAIResponsesHandlerOutputProcessing:
|
||||
object="response",
|
||||
status="completed",
|
||||
output=[
|
||||
GenericResponseOutputItem(
|
||||
type="message",
|
||||
id="msg_123",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
content=[
|
||||
OutputText(
|
||||
type="output_text",
|
||||
text="First message",
|
||||
annotations=None,
|
||||
),
|
||||
{
|
||||
"type": "message",
|
||||
"id": "msg_123",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "First message"},
|
||||
],
|
||||
),
|
||||
GenericResponseOutputItem(
|
||||
type="message",
|
||||
id="msg_124",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
content=[
|
||||
OutputText(
|
||||
type="output_text",
|
||||
text="Second message",
|
||||
annotations=None,
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"id": "msg_124",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "Second message"},
|
||||
],
|
||||
),
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
result = await handler.process_output_response(response, guardrail)
|
||||
# Response should be blocked, not masked
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handler.process_output_response(response, guardrail)
|
||||
|
||||
assert result.output[0].content[0].text == "First message [GUARDRAILED]"
|
||||
assert result.output[1].content[0].text == "Second message [GUARDRAILED]"
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Response blocked by guardrail" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_output_response_multiple_content_items(self):
|
||||
"""Test processing output response with multiple content items in one output"""
|
||||
"""Test processing output response with multiple content items - should block, not mask
|
||||
|
||||
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
|
||||
This test verifies that the guardrail properly blocks responses with multiple content items.
|
||||
"""
|
||||
handler = OpenAIResponsesHandler()
|
||||
guardrail = MockGuardrail(guardrail_name="test")
|
||||
|
||||
# Use dict format (works with current handler)
|
||||
response = ResponsesAPIResponse(
|
||||
id="resp_123",
|
||||
created_at=1234567890,
|
||||
@ -259,27 +285,33 @@ class TestOpenAIResponsesHandlerOutputProcessing:
|
||||
object="response",
|
||||
status="completed",
|
||||
output=[
|
||||
GenericResponseOutputItem(
|
||||
type="message",
|
||||
id="msg_123",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
content=[
|
||||
OutputText(type="output_text", text="Part 1", annotations=None),
|
||||
OutputText(type="output_text", text="Part 2", annotations=None),
|
||||
{
|
||||
"type": "message",
|
||||
"id": "msg_123",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "Part 1"},
|
||||
{"type": "output_text", "text": "Part 2"},
|
||||
],
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = await handler.process_output_response(response, guardrail)
|
||||
# Response should be blocked, not masked
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handler.process_output_response(response, guardrail)
|
||||
|
||||
assert result.output[0].content[0].text == "Part 1 [GUARDRAILED]"
|
||||
assert result.output[0].content[1].text == "Part 2 [GUARDRAILED]"
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Response blocked by guardrail" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_output_response_with_dict_format(self):
|
||||
"""Test processing output response where content items are dicts instead of OutputText objects"""
|
||||
"""Test processing output response with dict format - should block, not mask
|
||||
|
||||
After unified_guardrail.py changes, responses can only be blocked/rejected, not masked.
|
||||
This test verifies blocking works even when content items are dicts instead of OutputText objects.
|
||||
"""
|
||||
handler = OpenAIResponsesHandler()
|
||||
guardrail = MockGuardrail(guardrail_name="test")
|
||||
|
||||
@ -303,9 +335,12 @@ class TestOpenAIResponsesHandlerOutputProcessing:
|
||||
],
|
||||
)
|
||||
|
||||
result = await handler.process_output_response(response, guardrail)
|
||||
# Response should be blocked, not masked
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handler.process_output_response(response, guardrail)
|
||||
|
||||
assert result.output[0]["content"][0]["text"] == "Hello from dict [GUARDRAILED]"
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Response blocked by guardrail" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_output_response_no_text_content(self):
|
||||
|
||||
@ -381,10 +381,17 @@ class TestContentFilterGuardrail:
|
||||
assert result is not None
|
||||
assert result[1] == "aws_access_key"
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Masking in streaming responses is no longer supported after unified_guardrail.py changes. Only blocking/rejecting is supported for responses."
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_hook_mask(self):
|
||||
"""
|
||||
Test streaming hook with MASK action
|
||||
|
||||
Note: After changes to unified_guardrail.py, masking responses to users
|
||||
is no longer supported. This test is skipped as the feature is deprecated.
|
||||
Only BLOCK actions (test_streaming_hook_block) are supported for streaming responses.
|
||||
"""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@ -431,7 +438,7 @@ class TestContentFilterGuardrail:
|
||||
user_api_key_dict = MagicMock()
|
||||
request_data = {}
|
||||
|
||||
# Process streaming response
|
||||
# Process streaming response - no masking expected
|
||||
result_chunks = []
|
||||
async for chunk in guardrail.async_post_call_streaming_iterator_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
@ -440,12 +447,8 @@ class TestContentFilterGuardrail:
|
||||
):
|
||||
result_chunks.append(chunk)
|
||||
|
||||
# Chunks should pass through unchanged since masking is no longer supported
|
||||
assert len(result_chunks) == 2
|
||||
# First chunk should have email masked
|
||||
assert "[EMAIL_REDACTED]" in result_chunks[0].choices[0].delta.content
|
||||
assert "test@example.com" not in result_chunks[0].choices[0].delta.content
|
||||
# Second chunk should be unchanged
|
||||
assert result_chunks[1].choices[0].delta.content == " for more info"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_hook_block(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user