Reapply "feat: add model_cost aliases expansion support"
This reverts commit 3d2df7e8b5.
This commit is contained in:
parent
fa68d69bcf
commit
feed274aa3
@ -348,6 +348,8 @@ class DualCache(BaseCache):
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
@ -369,6 +371,8 @@ class DualCache(BaseCache):
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
@ -398,6 +398,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
from openai.types.responses.response_output_item import ResponseApplyPatchToolCall
|
||||
|
||||
from litellm.types.utils import Choices, Message
|
||||
|
||||
@ -456,6 +457,18 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
accumulated_tool_calls.append(tool_call_dict)
|
||||
tool_call_index += 1
|
||||
|
||||
elif isinstance(item, ResponseApplyPatchToolCall):
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
|
||||
tool_call_dict = LiteLLMCompletionResponsesConfig.convert_apply_patch_tool_call_to_chat_completion_tool_call(
|
||||
tool_call_item=item,
|
||||
index=tool_call_index,
|
||||
)
|
||||
accumulated_tool_calls.append(tool_call_dict)
|
||||
tool_call_index += 1
|
||||
|
||||
elif isinstance(item, dict) and handle_raw_dict_callback is not None:
|
||||
# Handle raw dict responses (e.g., from GPT-5 Codex)
|
||||
choice, index = handle_raw_dict_callback(item=item, index=index)
|
||||
|
||||
@ -64,12 +64,10 @@ def duration_in_seconds(duration: str) -> int:
|
||||
now = time.time()
|
||||
current_time = datetime.fromtimestamp(now)
|
||||
|
||||
if current_time.month == 12:
|
||||
target_year = current_time.year + 1
|
||||
target_month = 1
|
||||
else:
|
||||
target_year = current_time.year
|
||||
target_month = current_time.month + value
|
||||
# Calculate target month and year, handling overflow past December
|
||||
total_months = current_time.month - 1 + value # 0-indexed months
|
||||
target_year = current_time.year + total_months // 12
|
||||
target_month = total_months % 12 + 1 # back to 1-indexed
|
||||
|
||||
# Determine the day to set for next month
|
||||
target_day = current_time.day
|
||||
|
||||
@ -75,6 +75,53 @@ def _redact_responses_api_output(output_items):
|
||||
summary_item.text = "redacted-by-litellm"
|
||||
|
||||
|
||||
def _redact_standard_logging_object(model_call_details: dict):
|
||||
"""Redact messages and response inside standard_logging_object if present."""
|
||||
standard_logging_object = model_call_details.get("standard_logging_object")
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
|
||||
redacted_str = "redacted-by-litellm"
|
||||
|
||||
if standard_logging_object.get("messages") is not None:
|
||||
standard_logging_object["messages"] = [
|
||||
{"role": "user", "content": redacted_str}
|
||||
]
|
||||
|
||||
response = standard_logging_object.get("response")
|
||||
if response is not None:
|
||||
if isinstance(response, dict) and "output" in response:
|
||||
# ResponsesAPIResponse format - redact content in output items
|
||||
if isinstance(response.get("output"), list):
|
||||
for output_item in response["output"]:
|
||||
if isinstance(output_item, dict) and "content" in output_item:
|
||||
if isinstance(output_item["content"], list):
|
||||
for content_item in output_item["content"]:
|
||||
if (
|
||||
isinstance(content_item, dict)
|
||||
and "text" in content_item
|
||||
):
|
||||
content_item["text"] = redacted_str
|
||||
elif isinstance(response, dict) and "choices" in response:
|
||||
# ModelResponse dict format - redact content in choices
|
||||
if isinstance(response.get("choices"), list):
|
||||
for choice in response["choices"]:
|
||||
if isinstance(choice, dict):
|
||||
if "message" in choice and isinstance(choice["message"], dict):
|
||||
choice["message"]["content"] = redacted_str
|
||||
if "audio" in choice["message"]:
|
||||
choice["message"]["audio"] = None
|
||||
elif "delta" in choice and isinstance(choice["delta"], dict):
|
||||
choice["delta"]["content"] = redacted_str
|
||||
if "audio" in choice["delta"]:
|
||||
choice["delta"]["audio"] = None
|
||||
elif isinstance(response, str):
|
||||
standard_logging_object["response"] = redacted_str
|
||||
else:
|
||||
# For other formats (empty dict, None, etc.), use simple text format
|
||||
standard_logging_object["response"] = {"text": redacted_str}
|
||||
|
||||
|
||||
def perform_redaction(model_call_details: dict, result):
|
||||
"""
|
||||
Performs the actual redaction on the logging object and result.
|
||||
|
||||
@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
CompletionTokensDetailsWrapper,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
@ -63,6 +64,7 @@ from litellm.utils import (
|
||||
has_tool_call_blocks,
|
||||
last_assistant_with_tool_calls_has_no_thinking_blocks,
|
||||
supports_reasoning,
|
||||
token_counter,
|
||||
)
|
||||
|
||||
from ..common_utils import (
|
||||
@ -1637,7 +1639,11 @@ class AmazonConverseConfig(BaseConfig):
|
||||
thinking_blocks_list.append(_redacted_block)
|
||||
return thinking_blocks_list
|
||||
|
||||
def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
|
||||
def _transform_usage(
|
||||
self,
|
||||
usage: ConverseTokenUsageBlock,
|
||||
reasoning_content: Optional[str] = None,
|
||||
) -> Usage:
|
||||
input_tokens = usage["inputTokens"]
|
||||
output_tokens = usage["outputTokens"]
|
||||
total_tokens = usage["totalTokens"]
|
||||
@ -1654,6 +1660,19 @@ class AmazonConverseConfig(BaseConfig):
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
reasoning_tokens = (
|
||||
token_counter(text=reasoning_content, count_response_tokens=True)
|
||||
if reasoning_content
|
||||
else 0
|
||||
)
|
||||
completion_tokens_details = CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
text_tokens=(
|
||||
output_tokens - reasoning_tokens
|
||||
if reasoning_tokens > 0
|
||||
else output_tokens
|
||||
),
|
||||
)
|
||||
openai_usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
@ -1661,6 +1680,7 @@ class AmazonConverseConfig(BaseConfig):
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_tokens_details,
|
||||
)
|
||||
return openai_usage
|
||||
|
||||
@ -1997,7 +2017,10 @@ class AmazonConverseConfig(BaseConfig):
|
||||
chat_completion_message["tool_calls"] = filtered_tools
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
usage = self._transform_usage(completion_response["usage"])
|
||||
usage = self._transform_usage(
|
||||
completion_response["usage"],
|
||||
reasoning_content=chat_completion_message.get("reasoning_content"),
|
||||
)
|
||||
|
||||
## HANDLE TOOL CALLS
|
||||
_message = Message(**chat_completion_message)
|
||||
|
||||
@ -429,8 +429,11 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
||||
"FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
|
||||
)
|
||||
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[: -len("/v1")]
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{api_base}/v1/accounts/{account_id}/models",
|
||||
url=f"{base}/v1/accounts/{account_id}/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
|
||||
@ -583,35 +583,17 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
### BOTO3 INIT
|
||||
import boto3
|
||||
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
# Use _load_credentials to support role assumption (aws_role_name, aws_session_name)
|
||||
credentials, aws_region_name = self._load_credentials(optional_params)
|
||||
|
||||
if aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
# Create boto3 session with the loaded credentials
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials.access_key,
|
||||
aws_secret_access_key=credentials.secret_key,
|
||||
aws_session_token=credentials.token,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
client = session.client(service_name="sagemaker-runtime")
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
@ -628,7 +610,9 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
#### EMBEDDING LOGIC
|
||||
# Transform request based on model type
|
||||
provider_config = SagemakerEmbeddingConfig.get_model_config(model)
|
||||
request_data = provider_config.transform_embedding_request(model, input, optional_params, {})
|
||||
request_data = provider_config.transform_embedding_request(
|
||||
model, input, optional_params, {}
|
||||
)
|
||||
data = json.dumps(request_data).encode("utf-8")
|
||||
|
||||
## LOGGING
|
||||
@ -673,19 +657,19 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
)
|
||||
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
|
||||
|
||||
# Transform response based on model type
|
||||
from httpx import Response as HttpxResponse
|
||||
|
||||
|
||||
# Create a mock httpx Response object for the transformation
|
||||
mock_response = HttpxResponse(
|
||||
status_code=200,
|
||||
content=json.dumps(response).encode('utf-8'),
|
||||
headers={"content-type": "application/json"}
|
||||
content=json.dumps(response).encode("utf-8"),
|
||||
headers={"content-type": "application/json"},
|
||||
)
|
||||
|
||||
|
||||
model_response = EmbeddingResponse()
|
||||
|
||||
|
||||
# Use the request_data that was already transformed above
|
||||
return provider_config.transform_embedding_response(
|
||||
model=model,
|
||||
@ -695,5 +679,5 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
api_key=None,
|
||||
request_data=request_data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params or {}
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
|
||||
@ -593,6 +593,10 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||
raise e
|
||||
|
||||
|
||||
# Keys that LiteLLM consumes internally and must never be forwarded to the upstream provider in the request body.
|
||||
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
|
||||
|
||||
|
||||
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
|
||||
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
|
||||
extra_body: Optional[dict] = optional_params.pop("extra_body", None)
|
||||
|
||||
@ -718,6 +718,7 @@ if MCP_AVAILABLE:
|
||||
|
||||
Checks both the full tool name and unprefixed version (without server prefix).
|
||||
This allows users to configure simple tool names regardless of prefixing.
|
||||
Comparison is case-insensitive to handle OpenAPI operationIds that may be in camelCase.
|
||||
|
||||
Args:
|
||||
tool_name: The tool name to check (may be prefixed like "server-tool_name")
|
||||
@ -730,13 +731,15 @@ if MCP_AVAILABLE:
|
||||
split_server_prefix_from_name,
|
||||
)
|
||||
|
||||
# Check if the full name is in the list
|
||||
if tool_name in filter_list:
|
||||
# Normalize filter list to lowercase for case-insensitive comparison
|
||||
filter_list_lower = [f.lower() for f in filter_list]
|
||||
|
||||
if tool_name.lower() in filter_list_lower:
|
||||
return True
|
||||
|
||||
# Check if the unprefixed name is in the list
|
||||
# Check if the unprefixed name is in the list (case-insensitive)
|
||||
unprefixed_name, _ = split_server_prefix_from_name(tool_name)
|
||||
return unprefixed_name in filter_list
|
||||
return unprefixed_name.lower() in filter_list_lower
|
||||
|
||||
def filter_tools_by_allowed_tools(
|
||||
tools: List[MCPTool],
|
||||
|
||||
@ -111,12 +111,19 @@ def get_key_models(
|
||||
if SpecialModelNames.all_team_models.value in all_models:
|
||||
all_models = list(user_api_key_dict.team_models) # copy to avoid mutating cached objects
|
||||
if SpecialModelNames.all_proxy_models.value in all_models:
|
||||
all_models = proxy_model_list
|
||||
all_models = list(proxy_model_list) # copy to avoid mutating caller's list
|
||||
if include_model_access_groups:
|
||||
all_models.extend(model_access_groups.keys())
|
||||
|
||||
all_models = _get_models_from_access_groups(
|
||||
model_access_groups=model_access_groups, all_models=all_models
|
||||
model_access_groups=model_access_groups,
|
||||
all_models=all_models,
|
||||
include_model_access_groups=include_model_access_groups,
|
||||
)
|
||||
|
||||
# deduplicate while preserving order
|
||||
all_models = list(dict.fromkeys(all_models))
|
||||
|
||||
verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
|
||||
return all_models
|
||||
|
||||
@ -140,8 +147,8 @@ def get_team_models(
|
||||
all_models_set.update(team_models)
|
||||
if SpecialModelNames.all_proxy_models.value in all_models_set:
|
||||
all_models_set.update(proxy_model_list)
|
||||
|
||||
all_models = list(all_models_set)
|
||||
if include_model_access_groups:
|
||||
all_models_set.update(model_access_groups.keys())
|
||||
|
||||
all_models = _get_models_from_access_groups(
|
||||
model_access_groups=model_access_groups,
|
||||
@ -149,6 +156,9 @@ def get_team_models(
|
||||
include_model_access_groups=include_model_access_groups,
|
||||
)
|
||||
|
||||
# deduplicate while preserving order
|
||||
all_models = list(dict.fromkeys(all_models))
|
||||
|
||||
verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
|
||||
return all_models
|
||||
|
||||
|
||||
@ -142,17 +142,47 @@ async def get_credentials(
|
||||
tags=["credential management"],
|
||||
response_model=CredentialItem,
|
||||
)
|
||||
async def get_credential_by_name(
    request: Request,
    fastapi_response: Response,
    credential_name: str = Path(..., description="The credential name, percent-decoded; may contain slashes"),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] endpoint. This might change unexpectedly.

    Look up a stored credential by exact name in ``litellm.credential_list``
    and return it with its values masked.

    Raises:
        A 404 ``HTTPException`` when no credential carries the requested name.
        Every exception (the 404 included) is logged and re-raised through
        ``handle_exception_on_proxy``.
    """
    try:
        # First entry whose name matches exactly, or None when there is no match.
        found = next(
            (
                candidate
                for candidate in litellm.credential_list
                if candidate.credential_name == credential_name
            ),
            None,
        )
        if found is None:
            raise HTTPException(
                status_code=404,
                detail="Credential not found. Got credential name: " + credential_name,
            )

        # Never return raw secret values; mask them before building the response.
        masked_values = _get_masked_values(
            found.credential_values,
            unmasked_length=4,
            number_of_asterisks=4,
        )
        return CredentialItem(
            credential_name=found.credential_name,
            credential_values=masked_values,
            credential_info=found.credential_info,
        )
    except Exception as e:
        verbose_proxy_logger.exception(e)
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/credentials/by_model/{model_id}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
response_model=CredentialItem,
|
||||
)
|
||||
async def get_credential(
|
||||
async def get_credential_by_model(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
credential_name: str = Path(..., description="The credential name, percent-decoded; may contain slashes"),
|
||||
model_id: Optional[str] = None,
|
||||
model_id: str = Path(..., description="The model ID to look up credentials for"),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
@ -161,48 +191,25 @@ async def get_credential(
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
try:
|
||||
if model_id:
|
||||
if llm_router is None:
|
||||
raise HTTPException(status_code=500, detail="LLM router not found")
|
||||
model = llm_router.get_deployment(model_id)
|
||||
if model is None:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
credential_values = llm_router.get_deployment_credentials(model_id)
|
||||
if credential_values is None:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
masked_credential_values = _get_masked_values(
|
||||
credential_values,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
)
|
||||
credential = CredentialItem(
|
||||
credential_name="{}-credential-{}".format(model.model_name, model_id),
|
||||
credential_values=masked_credential_values,
|
||||
credential_info={},
|
||||
)
|
||||
# return credential object
|
||||
return credential
|
||||
elif credential_name:
|
||||
for credential in litellm.credential_list:
|
||||
if credential.credential_name == credential_name:
|
||||
masked_credential = CredentialItem(
|
||||
credential_name=credential.credential_name,
|
||||
credential_values=_get_masked_values(
|
||||
credential.credential_values,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
),
|
||||
credential_info=credential.credential_info,
|
||||
)
|
||||
return masked_credential
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credential not found. Got credential name: " + credential_name,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credential name or model ID required"
|
||||
)
|
||||
if llm_router is None:
|
||||
raise HTTPException(status_code=500, detail="LLM router not found")
|
||||
model = llm_router.get_deployment(model_id)
|
||||
if model is None:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
credential_values = llm_router.get_deployment_credentials(model_id)
|
||||
if credential_values is None:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
masked_credential_values = _get_masked_values(
|
||||
credential_values,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
)
|
||||
credential = CredentialItem(
|
||||
credential_name="{}-credential-{}".format(model.model_name, model_id),
|
||||
credential_values=masked_credential_values,
|
||||
credential_info={},
|
||||
)
|
||||
return credential
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(e)
|
||||
raise handle_exception_on_proxy(e)
|
||||
|
||||
@ -2827,21 +2827,6 @@ async def validate_membership(
|
||||
)
|
||||
|
||||
|
||||
def _unfurl_all_proxy_models(
|
||||
team_info: LiteLLM_TeamTable, llm_router: Router
|
||||
) -> LiteLLM_TeamTable:
|
||||
if (
|
||||
SpecialModelNames.all_proxy_models.value in team_info.models
|
||||
and llm_router is not None
|
||||
):
|
||||
team_models: set[str] = set() # make set to avoid duplicates
|
||||
for model in team_info.models:
|
||||
if model != SpecialModelNames.all_proxy_models.value:
|
||||
team_models.add(model)
|
||||
for model in llm_router.get_model_names():
|
||||
team_models.add(model)
|
||||
team_info.models = list(team_models)
|
||||
return team_info
|
||||
|
||||
|
||||
async def _add_team_member_budget_table(
|
||||
@ -2972,9 +2957,6 @@ async def team_info(
|
||||
team_info_response_object=_team_info,
|
||||
)
|
||||
|
||||
# ## UNFURL 'all-proxy-models' into the team_info.models list ##
|
||||
# if llm_router is not None:
|
||||
# _team_info = _unfurl_all_proxy_models(_team_info, llm_router)
|
||||
response_object = TeamInfoResponseObject(
|
||||
team_id=team_id,
|
||||
team_info=_team_info,
|
||||
|
||||
@ -2059,7 +2059,8 @@ class InitPassThroughEndpointHelpers:
|
||||
"""
|
||||
## CHECK IF MAPPED PASS THROUGH ENDPOINT
|
||||
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
|
||||
if route.startswith(mapped_route):
|
||||
full_mapped_route = InitPassThroughEndpointHelpers._build_full_path_with_root(mapped_route)
|
||||
if route.startswith(full_mapped_route):
|
||||
return True
|
||||
|
||||
# Fast path: check if any registered route key contains this path
|
||||
|
||||
@ -1461,11 +1461,21 @@ async def _get_spend_report_for_time_range(
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
responses={
|
||||
200: {
|
||||
"cost": {
|
||||
"description": "The calculated cost",
|
||||
"example": 0.0,
|
||||
"type": "float",
|
||||
}
|
||||
"description": "The calculated cost",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cost": {
|
||||
"type": "number",
|
||||
"description": "The calculated cost",
|
||||
"example": 0.0,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@ -292,21 +292,21 @@ class LiteLLMCompletionResponsesConfig:
|
||||
)
|
||||
_messages = litellm_completion_request.get("messages") or []
|
||||
session_messages = chat_completion_session.get("messages") or []
|
||||
|
||||
|
||||
# If session messages are empty (e.g., no database in test environment),
|
||||
# we still need to process the new input messages
|
||||
# Store original _messages before combining for safety check
|
||||
original_new_messages = _messages.copy() if _messages else []
|
||||
|
||||
|
||||
combined_messages = session_messages + _messages
|
||||
|
||||
|
||||
# Fix: Ensure tool_results have corresponding tool_calls in previous assistant message
|
||||
# Pass tools parameter to help reconstruct tool_calls if not in cache
|
||||
tools = litellm_completion_request.get("tools") or []
|
||||
combined_messages = LiteLLMCompletionResponsesConfig._ensure_tool_results_have_corresponding_tool_calls(
|
||||
messages=combined_messages, tools=tools
|
||||
)
|
||||
|
||||
|
||||
# Safety check: Ensure we don't end up with empty messages
|
||||
# This can happen when using previous_response_id without a database (e.g., in tests)
|
||||
# and session messages are empty but new input messages exist
|
||||
@ -340,7 +340,7 @@ class LiteLLMCompletionResponsesConfig:
|
||||
"custom_llm_provider", ""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
litellm_completion_request["messages"] = combined_messages
|
||||
litellm_completion_request["litellm_trace_id"] = chat_completion_session.get(
|
||||
"litellm_session_id"
|
||||
@ -386,10 +386,45 @@ class LiteLLMCompletionResponsesConfig:
|
||||
if call_id_raw:
|
||||
existing_tool_call_ids.add(str(call_id_raw))
|
||||
|
||||
#########################################################
|
||||
# Merge consecutive function_call items into a single assistant
|
||||
# message. Anthropic requires that all tool_use blocks appear in
|
||||
# ONE assistant message immediately followed by the tool_result
|
||||
# blocks. Without this merging, each function_call creates its own
|
||||
# assistant message, producing back-to-back assistant messages that
|
||||
# Anthropic rejects with "tool_use ids were found without
|
||||
# tool_result blocks immediately after".
|
||||
#########################################################
|
||||
if messages:
|
||||
last_msg = messages[-1]
|
||||
last_role = (
|
||||
last_msg.get("role")
|
||||
if isinstance(last_msg, dict)
|
||||
else getattr(last_msg, "role", None)
|
||||
)
|
||||
if last_role == "assistant":
|
||||
for new_msg in chat_completion_messages:
|
||||
new_role = (
|
||||
new_msg.get("role")
|
||||
if isinstance(new_msg, dict)
|
||||
else getattr(new_msg, "role", None)
|
||||
)
|
||||
if new_role == "assistant":
|
||||
new_tcs = (
|
||||
new_msg.get("tool_calls")
|
||||
if isinstance(new_msg, dict)
|
||||
else getattr(new_msg, "tool_calls", None)
|
||||
) or []
|
||||
for tc in new_tcs:
|
||||
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
|
||||
last_msg, tc
|
||||
)
|
||||
continue
|
||||
|
||||
#########################################################
|
||||
# If Input Item is a Tool Call Output, add it to the tool_call_output_messages list
|
||||
# preserving the ordering of tool call outputs. Some models require the tool
|
||||
# result to immediately follow the assistant tool call.
|
||||
# preserving the ordering of tool call outputs. Some models require the tool
|
||||
# result to immediately follow the assistant tool call.
|
||||
#########################################################
|
||||
if LiteLLMCompletionResponsesConfig._is_input_item_tool_call_output(
|
||||
input_item=_input
|
||||
@ -774,14 +809,14 @@ class LiteLLMCompletionResponsesConfig:
|
||||
]:
|
||||
"""
|
||||
Ensure that tool_result messages have corresponding tool_calls in the previous assistant message.
|
||||
|
||||
|
||||
This is critical for Anthropic API which requires that each tool_result block has a
|
||||
corresponding tool_use block in the previous assistant message.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of messages that may include tool_result messages
|
||||
tools: Optional list of tools that can be used to reconstruct tool_calls if not in cache
|
||||
|
||||
|
||||
Returns:
|
||||
List of messages with tool_calls added to assistant messages when needed
|
||||
"""
|
||||
@ -801,18 +836,18 @@ class LiteLLMCompletionResponsesConfig:
|
||||
]
|
||||
] = list(copy.deepcopy(messages))
|
||||
messages_to_remove = []
|
||||
|
||||
|
||||
# Count non-tool messages to avoid removing all messages
|
||||
# This prevents empty messages list when using previous_response_id without a database
|
||||
non_tool_messages_count = sum(
|
||||
1 for msg in fixed_messages if msg.get("role") != "tool"
|
||||
)
|
||||
|
||||
|
||||
for i, message in enumerate(fixed_messages):
|
||||
# Only process tool messages - check role first to narrow the type
|
||||
if message.get("role") != "tool":
|
||||
continue
|
||||
|
||||
|
||||
# At this point, we know it's a tool message, so it should have tool_call_id
|
||||
# Use get() with default to safely access tool_call_id
|
||||
tool_call_id_raw = (
|
||||
@ -823,11 +858,11 @@ class LiteLLMCompletionResponsesConfig:
|
||||
tool_call_id: str = (
|
||||
str(tool_call_id_raw) if tool_call_id_raw is not None else ""
|
||||
)
|
||||
|
||||
|
||||
prev_assistant_idx = LiteLLMCompletionResponsesConfig._find_previous_assistant_idx(
|
||||
fixed_messages, i
|
||||
)
|
||||
|
||||
|
||||
# Try to recover empty tool_call_id from previous assistant message
|
||||
if not tool_call_id and prev_assistant_idx is not None:
|
||||
prev_assistant = fixed_messages[prev_assistant_idx]
|
||||
@ -842,7 +877,7 @@ class LiteLLMCompletionResponsesConfig:
|
||||
message_dict["tool_call_id"] = tool_call_id
|
||||
elif hasattr(message, "tool_call_id"):
|
||||
setattr(message, "tool_call_id", tool_call_id)
|
||||
|
||||
|
||||
# Only remove messages with empty tool_call_id if we have other non-tool messages
|
||||
# This prevents ending up with an empty messages list when using previous_response_id
|
||||
# without a database (e.g., in tests where session messages are empty)
|
||||
@ -854,7 +889,7 @@ class LiteLLMCompletionResponsesConfig:
|
||||
# If no non-tool messages, keep the tool message even with empty call_id
|
||||
# The API will return a proper error message about the missing tool_use block
|
||||
continue
|
||||
|
||||
|
||||
# Check if the previous assistant message has the corresponding tool_call
|
||||
# This needs to run for ALL tool messages with a valid tool_call_id,
|
||||
# not just those that had an empty tool_call_id initially
|
||||
@ -863,12 +898,12 @@ class LiteLLMCompletionResponsesConfig:
|
||||
tool_calls = LiteLLMCompletionResponsesConfig._get_tool_calls_list(
|
||||
prev_assistant
|
||||
)
|
||||
|
||||
|
||||
if not LiteLLMCompletionResponsesConfig._check_tool_call_exists(
|
||||
tool_calls, tool_call_id
|
||||
):
|
||||
_tool_use_definition = TOOL_CALLS_CACHE.get_cache(key=tool_call_id)
|
||||
|
||||
|
||||
if not _tool_use_definition and tools:
|
||||
_tool_use_definition = LiteLLMCompletionResponsesConfig._reconstruct_tool_call_from_tools(
|
||||
tool_call_id, tools
|
||||
@ -891,11 +926,11 @@ class LiteLLMCompletionResponsesConfig:
|
||||
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
|
||||
prev_assistant, tool_call_chunk
|
||||
)
|
||||
|
||||
|
||||
# Remove messages with empty tool_call_id that couldn't be fixed
|
||||
for idx in reversed(messages_to_remove):
|
||||
fixed_messages.pop(idx)
|
||||
|
||||
|
||||
return fixed_messages
|
||||
|
||||
@staticmethod
|
||||
@ -1547,6 +1582,39 @@ class LiteLLMCompletionResponsesConfig:
|
||||
|
||||
return tool_call_dict
|
||||
|
||||
@staticmethod
|
||||
def convert_apply_patch_tool_call_to_chat_completion_tool_call(
|
||||
tool_call_item: Any,
|
||||
index: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert ResponseApplyPatchToolCall to ChatCompletionToolCallChunk format.
|
||||
|
||||
The operation (create_file / update_file / delete_file) is serialised
|
||||
as JSON so it appears in function.arguments, just like any other
|
||||
tool call.
|
||||
|
||||
Args:
|
||||
tool_call_item: ResponseApplyPatchToolCall object with call_id and operation
|
||||
index: The index of this tool call
|
||||
|
||||
Returns:
|
||||
Dictionary in ChatCompletionToolCallChunk format
|
||||
"""
|
||||
import json
|
||||
|
||||
operation_dict = tool_call_item.operation.model_dump()
|
||||
tool_call_dict: Dict[str, Any] = {
|
||||
"id": tool_call_item.call_id,
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": json.dumps(operation_dict),
|
||||
},
|
||||
"type": "function",
|
||||
"index": index,
|
||||
}
|
||||
return tool_call_dict
|
||||
|
||||
@staticmethod
|
||||
def transform_chat_completion_response_to_responses_api_response(
|
||||
request_input: Union[str, ResponseInputParam],
|
||||
|
||||
@ -5505,6 +5505,10 @@ class Router:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Always track the latest error so we raise the most
|
||||
# recent exception instead of the first one.
|
||||
original_exception = e
|
||||
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
remaining_retries = num_retries - current_attempt - 1
|
||||
@ -5519,6 +5523,24 @@ class Router:
|
||||
)
|
||||
else:
|
||||
_healthy_deployments = []
|
||||
|
||||
# Check if this error is non-retryable (e.g., 400 context
|
||||
# window exceeded). If so, raise immediately instead of
|
||||
# continuing the retry loop. Respect retry policy
|
||||
# precedence - only check when no retry policy applies.
|
||||
if not _retry_policy_applies:
|
||||
try:
|
||||
self.should_retry_this_error(
|
||||
error=e,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
regular_fallbacks=fallbacks,
|
||||
content_policy_fallbacks=content_policy_fallbacks,
|
||||
)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=e,
|
||||
remaining_retries=remaining_retries,
|
||||
|
||||
@ -498,20 +498,22 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
||||
|
||||
# get average latency or average ttft (depending on streaming/non-streaming)
|
||||
total: float = 0.0
|
||||
if (
|
||||
use_ttft = (
|
||||
request_kwargs is not None
|
||||
and request_kwargs.get("stream", None) is not None
|
||||
and request_kwargs["stream"] is True
|
||||
and len(item_ttft_latency) > 0
|
||||
):
|
||||
)
|
||||
if use_ttft:
|
||||
for _call_latency in item_ttft_latency:
|
||||
if isinstance(_call_latency, float):
|
||||
total += _call_latency
|
||||
item_latency = total / len(item_ttft_latency)
|
||||
else:
|
||||
for _call_latency in item_latency:
|
||||
if isinstance(_call_latency, float):
|
||||
total += _call_latency
|
||||
item_latency = total / len(item_latency)
|
||||
item_latency = total / len(item_latency)
|
||||
|
||||
# -------------- #
|
||||
# Debugging Logic
|
||||
|
||||
@ -458,6 +458,24 @@
|
||||
"interactions": true
|
||||
}
|
||||
},
|
||||
"charity_engine": {
|
||||
"display_name": "Charity Engine (`charity_engine`)",
|
||||
"url": "https://docs.litellm.ai/docs/providers/charity_engine",
|
||||
"endpoints": {
|
||||
"chat_completions": true,
|
||||
"messages": true,
|
||||
"responses": true,
|
||||
"embeddings": false,
|
||||
"image_generations": false,
|
||||
"audio_transcriptions": false,
|
||||
"audio_speech": false,
|
||||
"moderations": false,
|
||||
"batches": false,
|
||||
"rerank": false,
|
||||
"a2a": false,
|
||||
"interactions": false
|
||||
}
|
||||
},
|
||||
"chutes": {
|
||||
"display_name": "Chutes (`chutes`)",
|
||||
"endpoints": {
|
||||
|
||||
@ -44,19 +44,25 @@ def create_skill_zip(skill_name: str, unique_suffix: Optional[str] = None):
|
||||
skill_dir = test_dir / skill_name
|
||||
|
||||
# Create a zip file containing the skill directory
|
||||
# When unique_suffix is set, folder name must match skill name in SKILL.md (Anthropic requirement)
|
||||
zip_folder_name = f"{skill_name}-{unique_suffix}" if unique_suffix else skill_name
|
||||
zip_path = test_dir / f"{skill_name}.zip"
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.write(skill_dir, arcname=skill_name)
|
||||
|
||||
if unique_suffix is not None:
|
||||
# Rewrite SKILL.md with a unique name to avoid API conflicts
|
||||
# Rewrite SKILL.md with a unique name and use matching folder name
|
||||
skill_md = (skill_dir / "SKILL.md").read_text()
|
||||
skill_md = skill_md.replace(
|
||||
f"name: {skill_name}",
|
||||
f"name: {skill_name}-{unique_suffix}",
|
||||
f"name: {zip_folder_name}",
|
||||
)
|
||||
zf.writestr(f"{skill_name}/SKILL.md", skill_md)
|
||||
zf.writestr(f"{zip_folder_name}/SKILL.md", skill_md)
|
||||
# Add any other files in the skill dir (e.g. subdirs) under the new folder name
|
||||
for f in skill_dir.rglob("*"):
|
||||
if f.is_file() and f.name != "SKILL.md":
|
||||
rel = f.relative_to(skill_dir)
|
||||
zf.write(f, arcname=f"{zip_folder_name}/{rel}")
|
||||
else:
|
||||
zf.write(skill_dir, arcname=skill_name)
|
||||
zf.write(skill_dir / "SKILL.md", arcname=f"{skill_name}/SKILL.md")
|
||||
|
||||
try:
|
||||
|
||||
@ -1300,9 +1300,11 @@ def test_logging_async_cache_hit_sync_call(turn_off_message_logging):
|
||||
"redacted-by-litellm"
|
||||
== standard_logging_object["messages"][0]["content"]
|
||||
)
|
||||
assert {"text": "redacted-by-litellm"} == standard_logging_object[
|
||||
"response"
|
||||
]
|
||||
# response is a full ModelResponse dict (choices format) since d84e5e381acf
|
||||
assert (
|
||||
standard_logging_object["response"]["choices"][0]["message"]["content"]
|
||||
== "redacted-by-litellm"
|
||||
)
|
||||
|
||||
|
||||
def test_logging_standard_payload_failure_call():
|
||||
|
||||
@ -45,7 +45,8 @@ async def test_global_redaction_on():
|
||||
await asyncio.sleep(1)
|
||||
standard_logging_payload = test_custom_logger.logged_standard_logging_payload
|
||||
assert standard_logging_payload is not None
|
||||
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
|
||||
response = standard_logging_payload["response"]
|
||||
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
|
||||
print(
|
||||
"logged standard logging payload",
|
||||
@ -75,7 +76,8 @@ async def test_global_redaction_with_dynamic_params(turn_off_message_logging):
|
||||
)
|
||||
|
||||
if turn_off_message_logging is True:
|
||||
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
|
||||
response = standard_logging_payload["response"]
|
||||
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert (
|
||||
standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
|
||||
)
|
||||
@ -108,7 +110,8 @@ async def test_global_redaction_off_with_dynamic_params(turn_off_message_logging
|
||||
json.dumps(standard_logging_payload, indent=2),
|
||||
)
|
||||
if turn_off_message_logging is True:
|
||||
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
|
||||
response = standard_logging_payload["response"]
|
||||
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert (
|
||||
standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
|
||||
)
|
||||
@ -390,7 +393,8 @@ async def test_redaction_with_streaming_response():
|
||||
assert standard_logging_payload is not None
|
||||
|
||||
# Verify that redaction worked without pickle errors
|
||||
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
|
||||
response = standard_logging_payload["response"]
|
||||
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
|
||||
print(
|
||||
"logged standard logging payload for streaming with coroutine handling",
|
||||
@ -477,5 +481,6 @@ async def test_redaction_with_metadata_completion_api():
|
||||
|
||||
# Verify the helper function works correctly - with get_metadata_variable_name_from_kwargs,
|
||||
# the system checks the appropriate field for headers
|
||||
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
|
||||
response = standard_logging_payload["response"]
|
||||
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
|
||||
@ -56,3 +58,104 @@ async def test_dual_cache_async_batch_get_cache_rolls_back_redis_reservation_on_
|
||||
assert mock_async_batch_get_cache.call_count == 2
|
||||
assert "shared_a" not in dual_cache.last_redis_batch_access_time
|
||||
assert "shared_b" not in dual_cache.last_redis_batch_access_time
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_injects_default_in_memory_ttl():
    """
    async_set_cache must apply DualCache.default_in_memory_ttl when the
    caller passes no explicit ttl, matching the sync set_cache behavior.

    Regression test: without the injection, InMemoryCache fell back to its
    own default_ttl (600s) instead of DualCache's default_in_memory_ttl.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)

    started = time.time()
    await cache.async_set_cache(key="test_key", value="test_value")
    finished = time.time()

    # Expiry must sit 60s (DualCache default) after the write, not 600s.
    stored_expiry = memory_layer.ttl_dict["test_key"]
    assert started + 60 <= stored_expiry
    assert stored_expiry <= finished + 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_respects_explicit_ttl():
    """
    async_set_cache must keep a ttl the caller passed explicitly rather
    than replacing it with default_in_memory_ttl.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)

    started = time.time()
    await cache.async_set_cache(key="test_key", value="test_value", ttl=30)
    finished = time.time()

    # ttl=30 was requested, so the 60s DualCache default must not apply.
    stored_expiry = memory_layer.ttl_dict["test_key"]
    assert started + 30 <= stored_expiry
    assert stored_expiry <= finished + 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_pipeline_injects_default_in_memory_ttl():
    """
    async_set_cache_pipeline must apply default_in_memory_ttl to every
    entry when the caller passes no explicit ttl.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)
    entries = [("key_a", "value_a"), ("key_b", "value_b")]

    started = time.time()
    await cache.async_set_cache_pipeline(cache_list=entries)
    finished = time.time()

    # Each pipelined key should expire 60s after the write.
    for cache_key, _ in entries:
        stored_expiry = memory_layer.ttl_dict[cache_key]
        assert started + 60 <= stored_expiry
        assert stored_expiry <= finished + 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dual_cache_sync_and_async_set_cache_use_same_ttl():
    """
    The sync set_cache and async async_set_cache paths must store the same
    TTL when no explicit ttl is provided.
    """
    sync_memory = InMemoryCache(default_ttl=600)
    async_memory = InMemoryCache(default_ttl=600)
    sync_cache = DualCache(in_memory_cache=sync_memory, default_in_memory_ttl=60)
    async_cache = DualCache(in_memory_cache=async_memory, default_in_memory_ttl=60)

    sync_cache.set_cache(key="test_key", value="test_value")
    await async_cache.async_set_cache(key="test_key", value="test_value")

    # Both writes happen back to back with the same 60s default, so the
    # stored expiry timestamps should differ by well under a second.
    expiry_gap = sync_memory.ttl_dict["test_key"] - async_memory.ttl_dict["test_key"]
    assert abs(expiry_gap) < 1.0
|
||||
|
||||
@ -738,7 +738,58 @@ def test_response_completed_with_message_only_emits_stop_finish_reason():
|
||||
)
|
||||
|
||||
|
||||
def test_function_call_done_does_not_emit_finish_reason():
|
||||
|
||||
def test_response_completed_preserves_usage_with_cached_tokens():
    """
    A response.completed chunk must translate Responses API usage
    (input_tokens_details) into chat completion usage (prompt_tokens_details).

    Regression test: streaming through the Responses API bridge
    (e.g. gpt-5.2-codex) dropped prompt_tokens_details, so cached_tokens
    was always None.
    """
    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        OpenAiResponsesToChatCompletionStreamIterator,
    )

    responses_usage = {
        "input_tokens": 1226,
        "output_tokens": 5,
        "total_tokens": 1231,
        "input_tokens_details": {"cached_tokens": 1024},
        "output_tokens_details": {"reasoning_tokens": 0},
    }
    assistant_message = {
        "type": "message",
        "id": "msg_abc",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "Six"}],
        "status": "completed",
    }
    completed_event = {
        "type": "response.completed",
        "response": {
            "id": "resp_789",
            "status": "completed",
            "output": [assistant_message],
            "usage": responses_usage,
        },
    }

    stream_iterator = OpenAiResponsesToChatCompletionStreamIterator(
        streaming_response=None, sync_stream=True
    )
    result = stream_iterator.chunk_parser(completed_event)

    assert result.usage is not None, "usage should be set on response.completed chunk"
    assert result.usage.prompt_tokens == 1226, "prompt_tokens should map from input_tokens"
    assert result.usage.completion_tokens == 5, "completion_tokens should map from output_tokens"
    prompt_details = result.usage.prompt_tokens_details
    assert prompt_details is not None, "prompt_tokens_details should be set"
    assert prompt_details.cached_tokens == 1024, (
        "cached_tokens should be preserved from input_tokens_details"
    )
|
||||
|
||||
|
||||
def test_function_call_done_emits_is_finished():
|
||||
"""
|
||||
Test that OUTPUT_ITEM_DONE for a function_call does NOT emit finish_reason.
|
||||
The response.completed event handles the terminal finish_reason correctly.
|
||||
@ -1327,6 +1378,138 @@ def test_transform_response_preserves_annotations():
|
||||
print("✓ Annotations from Responses API are correctly preserved in Chat Completions format")
|
||||
|
||||
|
||||
def test_apply_patch_tool_call_converted_to_chat_completion_tool_call():
    """
    The Responses -> Chat Completions bridge must turn a
    ResponseApplyPatchToolCall output item into a chat-completion tool call.

    Regression test: litellm.completion() with a responses/ model prefix
    crashed when the model returned an apply_patch_call, because
    _convert_response_output_to_choices did not handle
    ResponseApplyPatchToolCall items. The model DID use the tool, but the
    bridge silently dropped it (or raised an error), while the native
    litellm.responses() path worked correctly.
    """
    import json
    from unittest.mock import Mock

    from openai.types.responses.response_apply_patch_tool_call import (
        OperationCreateFile,
    )
    from openai.types.responses.response_output_item import (
        ResponseApplyPatchToolCall,
    )

    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        LiteLLMResponsesTransformationHandler,
    )
    from litellm.types.llms.openai import (
        InputTokensDetails,
        OutputTokensDetails,
        ResponseAPIUsage,
        ResponsesAPIResponse,
    )
    from litellm.types.utils import ModelResponse, Usage

    handler = LiteLLMResponsesTransformationHandler()

    # An apply_patch_call output item shaped like the model would return it.
    create_file_op = OperationCreateFile(
        type="create_file",
        path="hello.py",
        diff="--- /dev/null\n+++ b/hello.py\n@@ -0,0 +1 @@\n+print('hello world')\n",
    )
    patch_call = ResponseApplyPatchToolCall(
        type="apply_patch_call",
        id="apc_001",
        call_id="call_patch_hello",
        status="completed",
        operation=create_file_op,
    )

    # Minimal usage block.
    token_usage = ResponseAPIUsage(
        input_tokens=30,
        input_tokens_details=InputTokensDetails(cached_tokens=0),
        output_tokens=40,
        output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
        total_tokens=70,
    )

    responses_payload = ResponsesAPIResponse(
        id="resp_apply_patch_test",
        created_at=1234567890,
        error=None,
        incomplete_details=None,
        instructions=None,
        metadata={},
        model="gpt-5.2-codex",
        object="response",
        output=[patch_call],
        parallel_tool_calls=True,
        temperature=1.0,
        tool_choice="auto",
        tools=[],
        top_p=1.0,
        max_output_tokens=None,
        previous_response_id=None,
        reasoning=None,
        status="completed",
        text=None,
        truncation="disabled",
        usage=token_usage,
        user=None,
        store=True,
        background=False,
    )

    empty_chat_response = ModelResponse(
        id="chatcmpl-apply-patch",
        created=1234567890,
        model=None,
        object="chat.completion",
        choices=[],
        usage=Usage(completion_tokens=0, prompt_tokens=0, total_tokens=0),
    )

    conversation = [
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Create hello.py"},
    ]

    result = handler.transform_response(
        model="gpt-5.2-codex",
        raw_response=responses_payload,
        model_response=empty_chat_response,
        logging_obj=Mock(),
        request_data={"model": "gpt-5.2-codex"},
        messages=conversation,
        optional_params={},
        litellm_params={},
        encoding=Mock(),
    )

    # Exactly one choice, finished because of tool calls.
    assert len(result.choices) == 1, f"Expected 1 choice, got {len(result.choices)}"
    only_choice = result.choices[0]
    assert only_choice.finish_reason == "tool_calls"

    # That choice carries a single apply_patch tool call.
    tool_calls = only_choice.message.tool_calls
    assert tool_calls is not None, "tool_calls should not be None"
    assert len(tool_calls) == 1, f"Expected 1 tool_call, got {len(tool_calls)}"

    patch_tool_call = tool_calls[0]
    assert patch_tool_call["id"] == "call_patch_hello"
    assert patch_tool_call["type"] == "function"
    assert patch_tool_call["function"]["name"] == "apply_patch"

    # The operation travels as JSON inside function.arguments.
    decoded_args = json.loads(patch_tool_call["function"]["arguments"])
    assert decoded_args["type"] == "create_file"
    assert decoded_args["path"] == "hello.py"
    assert "print('hello world')" in decoded_args["diff"]


|
||||
def test_multi_tool_call_stream_no_premature_finish():
|
||||
"""
|
||||
Regression test for multi-tool-call streaming bug.
|
||||
@ -1778,3 +1961,35 @@ def test_parallel_tool_calls_comprehensive_streaming_integration():
|
||||
)
|
||||
|
||||
print("✓ Parallel tool calls with split argument deltas stream correctly end-to-end")
|
||||
|
||||
|
||||
def test_map_optional_params_preserves_reasoning_summary():
    """A reasoning_effort dict carrying a summary field must survive the mapping.

    Regression test: the summary field was reported as dropped when routing
    to the Responses API. The dict form should be passed through in full as
    the ``reasoning`` request parameter.
    """
    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        LiteLLMResponsesTransformationHandler,
    )
    from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams

    chat_params = {
        "stream": False,
        "tools": [{"type": "function", "function": {"name": "test_tool"}}],
        "tool_choice": "auto",
        "reasoning_effort": {"effort": "high", "summary": "detailed"},
    }
    mapped_request = ResponsesAPIOptionalRequestParams()

    bridge = LiteLLMResponsesTransformationHandler()
    bridge._map_optional_params_to_responses_api_request(chat_params, mapped_request)

    # The whole dict, summary included, must land under "reasoning".
    assert "reasoning" in mapped_request
    mapped_reasoning = mapped_request["reasoning"]
    assert mapped_reasoning == {"effort": "high", "summary": "detailed"}
    assert mapped_reasoning["effort"] == "high"
    assert mapped_reasoning["summary"] == "detailed"
|
||||
|
||||
@ -192,6 +192,23 @@ def test_azure_gpt5_1_series_temperature_handling(config: AzureOpenAIGPT5Config)
|
||||
assert params["temperature"] == 0.6
|
||||
|
||||
|
||||
def test_azure_gpt5_4_drops_reasoning_effort_when_tools_present(config: AzureOpenAIGPT5Config):
    """Azure Chat Completions: gpt-5.4+ drops reasoning_effort when tools are present.

    OpenAI routes tools+reasoning to the Responses API; Azure does not, so
    reasoning_effort is dropped while the tools are kept.
    """
    tools = [{"type": "function", "function": {"name": "test", "description": "test"}}]
    requested = {"reasoning_effort": "high", "tools": tools}

    mapped = config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt5_series/gpt-5.4",
        drop_params=False,
        api_version="2024-05-01-preview",
    )

    assert "reasoning_effort" not in mapped
    assert mapped["tools"] == tools
|
||||
|
||||
|
||||
def test_azure_gpt5_reasoning_effort_none_error(config: AzureOpenAIGPT5Config):
|
||||
"""Test that Azure GPT-5 (non-5.1) raises error for reasoning_effort='none' when drop_params=False."""
|
||||
with pytest.raises(litellm.utils.UnsupportedParamsError):
|
||||
|
||||
@ -43,6 +43,29 @@ def test_transform_usage():
|
||||
)
|
||||
assert openai_usage._cache_creation_input_tokens == usage["cacheWriteInputTokens"]
|
||||
assert openai_usage._cache_read_input_tokens == usage["cacheReadInputTokens"]
|
||||
# completion_tokens_details should always be populated
|
||||
assert openai_usage.completion_tokens_details is not None
|
||||
assert openai_usage.completion_tokens_details.reasoning_tokens == 0
|
||||
assert openai_usage.completion_tokens_details.text_tokens == usage["outputTokens"]
|
||||
|
||||
|
||||
def test_transform_usage_with_reasoning_content():
    """completion_tokens_details must split output tokens into reasoning and text tokens."""
    token_counts = {
        "inputTokens": 10,
        "outputTokens": 100,
        "totalTokens": 110,
    }
    usage = ConverseTokenUsageBlock(**token_counts)

    openai_usage = AmazonConverseConfig()._transform_usage(
        usage, reasoning_content="Let me think about this step by step."
    )

    details = openai_usage.completion_tokens_details
    assert details is not None
    assert details.reasoning_tokens > 0
    # Text tokens are whatever remains of the output once reasoning is removed.
    assert details.text_tokens == (usage["outputTokens"] - details.reasoning_tokens)
|
||||
|
||||
|
||||
def test_transform_system_message():
|
||||
@ -3170,6 +3193,33 @@ def test_transform_request_with_output_config():
|
||||
assert result["outputConfig"]["textFormat"]["structure"]["jsonSchema"]["name"] == "TestSchema"
|
||||
|
||||
|
||||
def test_output_config_snake_case_stripped_from_bedrock_converse_request():
    """output_config (snake_case) must not reach a Bedrock Converse request.

    Bedrock Converse API doesn't support the output_config parameter
    (Anthropic-only). Nova and other Converse models reject requests with
    an extraneous output_config.
    """
    request = AmazonConverseConfig()._transform_request(
        model="us.amazon.nova-pro-v1:0",
        messages=[{"role": "user", "content": "test"}],
        optional_params={"output_config": {"effort": "high"}},
        litellm_params={},
        headers={},
    )

    # The parameter must not leak into additionalModelRequestFields.
    additional = request.get("additionalModelRequestFields", {})
    assert "output_config" not in additional, (
        f"output_config should be stripped for Bedrock Converse, got: {list(additional.keys())}"
    )
|
||||
|
||||
|
||||
def test_transform_response_native_structured_output():
|
||||
"""Test response handling when model returns JSON as text content (native structured output)."""
|
||||
response_json = {
|
||||
|
||||
@ -110,6 +110,60 @@ def test_get_supported_openai_params_reasoning_effort():
|
||||
assert "reasoning_effort" not in unsupported_params
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "api_base, expected_url_prefix",
    [
        (
            "https://api.fireworks.ai/inference/v1",
            "https://api.fireworks.ai/inference/v1/accounts/",
        ),
        (
            "https://api.fireworks.ai/inference/v1/",
            "https://api.fireworks.ai/inference/v1/accounts/",
        ),
        (
            "https://custom-host.example.com/v1",
            "https://custom-host.example.com/v1/accounts/",
        ),
        (
            "https://custom-host.example.com/api",
            "https://custom-host.example.com/api/v1/accounts/",
        ),
    ],
    ids=["default", "trailing-slash", "custom-with-v1", "custom-without-v1"],
)
def test_get_models_url_no_double_v1(api_base, expected_url_prefix):
    """Ensure get_models never produces a /v1/v1/ URL segment (fixes #23106)."""
    # Values handed back by the patched get_secret_str; unknown keys yield None.
    fake_secrets = {
        "FIREWORKS_API_KEY": "test-key",
        "FIREWORKS_API_BASE": api_base,
        "FIREWORKS_ACCOUNT_ID": "fireworks",
    }

    models_response = MagicMock()
    models_response.status_code = 200
    models_response.json.return_value = {
        "models": [{"name": "accounts/fireworks/models/llama-v3-70b"}]
    }

    fireworks_config = FireworksAIConfig()
    with patch("litellm.module_level_client.get", return_value=models_response) as mock_get:
        with patch(
            "litellm.llms.fireworks_ai.chat.transformation.get_secret_str",
            side_effect=lambda key: fake_secrets.get(key),
        ):
            result = fireworks_config.get_models(api_key="test-key", api_base=api_base)

    # Take the url keyword from the recorded call; fall back to "" if absent.
    called_url = mock_get.call_args.kwargs.get("url")
    if not called_url:
        called_url = mock_get.call_args[1].get("url", "")

    assert "/v1/v1/" not in called_url, f"Double /v1/ detected in URL: {called_url}"
    assert called_url.startswith(expected_url_prefix), (
        f"URL {called_url} does not start with {expected_url_prefix}"
    )
    assert result == ["fireworks_ai/accounts/fireworks/models/llama-v3-70b"]
|
||||
|
||||
|
||||
def test_transform_messages_helper_removes_provider_specific_fields():
|
||||
"""
|
||||
Test that _transform_messages_helper removes provider_specific_fields from messages.
|
||||
|
||||
@ -13,6 +13,7 @@ from litellm.llms.openai.chat.gpt_transformation import (
|
||||
OpenAIChatCompletionStreamingHandler,
|
||||
OpenAIGPTConfig,
|
||||
)
|
||||
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
|
||||
|
||||
|
||||
class TestOpenAIGPTConfig:
|
||||
@ -324,3 +325,195 @@ class TestPromptCacheParams:
|
||||
)
|
||||
assert optional_params.get("prompt_cache_key") == "my-cache-key"
|
||||
assert optional_params.get("prompt_cache_retention") == "24h"
|
||||
|
||||
|
||||
class TestGPT5ReasoningEffortPreservation:
|
||||
"""Tests for GPT-5 reasoning_effort dict preservation for Responses API."""
|
||||
|
||||
def setup_method(self):
    """Create a fresh OpenAIGPT5Config for each test and store it on self.config."""
    self.config = OpenAIGPT5Config()
|
||||
|
||||
def test_reasoning_effort_string_preserved(self):
    """A plain-string reasoning_effort is still "high" after map_openai_params."""
    requested = {"reasoning_effort": "high"}

    self.config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    # The string form needs no normalisation and must be unchanged.
    assert requested.get("reasoning_effort") == "high"
|
||||
|
||||
def test_reasoning_effort_dict_with_only_effort_normalized(self):
    """A reasoning_effort dict holding only 'effort' is collapsed to its string value."""
    requested = {"reasoning_effort": {"effort": "high"}}

    self.config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    # {"effort": "high"} carries nothing beyond the effort, so it becomes "high".
    assert requested.get("reasoning_effort") == "high"
|
||||
|
||||
def test_reasoning_effort_dict_with_summary_preserved(self):
    """A reasoning_effort dict with a 'summary' field is kept as a dict for the Responses API.

    Regression test: the summary field was reported as dropped when routing
    to the Responses API. A dict with additional fields must be preserved so
    the Responses API transformation can handle it.
    """
    requested = {"reasoning_effort": {"effort": "high", "summary": "detailed"}}

    self.config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    # The extra field means the dict must come through intact.
    kept = requested.get("reasoning_effort")
    assert kept == {"effort": "high", "summary": "detailed"}
    assert isinstance(kept, dict)
    assert requested["reasoning_effort"]["effort"] == "high"
    assert requested["reasoning_effort"]["summary"] == "detailed"
|
||||
|
||||
def test_reasoning_effort_dict_with_generate_summary_preserved(self):
    """A reasoning_effort dict with a 'generate_summary' field is kept as a dict."""
    requested = {"reasoning_effort": {"effort": "medium", "generate_summary": "auto"}}

    self.config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    # The extra field means the dict must come through intact.
    kept = requested.get("reasoning_effort")
    assert kept == {"effort": "medium", "generate_summary": "auto"}
    assert isinstance(kept, dict)
|
||||
|
||||
def test_reasoning_effort_dict_with_all_fields_preserved(self):
    """A reasoning_effort dict with effort, summary and generate_summary is kept whole."""
    requested = {
        "reasoning_effort": {
            "effort": "high",
            "summary": "detailed",
            "generate_summary": "concise",
        }
    }

    self.config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    # Every field must still be present with its original value.
    kept = requested.get("reasoning_effort")
    assert isinstance(kept, dict)
    assert kept["effort"] == "high"
    assert kept["summary"] == "detailed"
    assert kept["generate_summary"] == "concise"
|
||||
|
||||
def test_reasoning_effort_dict_xhigh_triggers_validation(self):
    """xhigh-dict: the effective effort is extracted for model-support validation.

    When reasoning_effort={"effort": "xhigh", "summary": "detailed"} is passed
    to a model that doesn't support xhigh (e.g. gpt-5.1), the xhigh guard must
    fire and raise UnsupportedParamsError.
    """
    import litellm

    requested = {"reasoning_effort": {"effort": "xhigh", "summary": "detailed"}}

    with pytest.raises(litellm.utils.UnsupportedParamsError):
        self.config.map_openai_params(
            non_default_params=requested,
            optional_params={},
            model="gpt-5.1",
            drop_params=False,
        )
|
||||
|
||||
def test_reasoning_effort_dict_xhigh_dropped_when_requested(self):
    """xhigh-dict with drop_params=True: the reasoning_effort entry is removed."""
    request_params = {
        "reasoning_effort": {"effort": "xhigh", "summary": "detailed"},
    }

    self.config.map_openai_params(
        non_default_params=request_params,
        optional_params={},
        model="gpt-5.1",
        drop_params=True,
    )

    # With drop_params enabled the key is expected to be gone afterwards.
    assert "reasoning_effort" not in request_params
|
||||
|
||||
def test_reasoning_effort_dict_none_treated_as_none_for_tools(self):
    """none-dict: {"effort": "none", "summary": "detailed"} counts as effort=none.

    The tool-drop guard must stay quiet, so both reasoning_effort and tools
    are expected to remain in the request params.
    """
    tool_defs = [
        {"type": "function", "function": {"name": "test", "description": "test"}},
    ]
    request_params = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "tools": tool_defs,
    }

    self.config.map_openai_params(
        non_default_params=request_params,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    assert request_params.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}
    assert request_params.get("tools") == tool_defs
|
||||
|
||||
def test_reasoning_effort_dict_none_treated_as_none_for_sampling(self):
    """none-dict: {"effort": "none", "summary": "detailed"} permits logprobs.

    The sampling-param guard must not fire, so logprobs is expected to survive
    alongside the untouched reasoning_effort dict.
    """
    request_params = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "logprobs": True,
    }

    self.config.map_openai_params(
        non_default_params=request_params,
        optional_params={},
        model="gpt-5.1",
        drop_params=False,
    )

    assert request_params.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}
    assert request_params.get("logprobs") is True
|
||||
|
||||
def test_reasoning_effort_dict_none_allows_temperature(self):
    """none-dict: a non-default temperature is accepted when effort is 'none'."""
    request_params = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "temperature": 0.5,
    }
    mapped_params = {}

    self.config.map_openai_params(
        non_default_params=request_params,
        optional_params=mapped_params,
        model="gpt-5.1",
        drop_params=False,
    )

    # Temperature lands in the mapped output; the reasoning dict stays as sent.
    assert mapped_params.get("temperature") == 0.5
    assert request_params.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}
|
||||
|
||||
@ -324,10 +324,11 @@ def test_gpt5_4_pro_allows_reasoning_effort_xhigh(config: OpenAIConfig):
|
||||
assert params["reasoning_effort"] == "xhigh"
|
||||
|
||||
|
||||
def test_gpt5_normalizes_reasoning_effort_dict_to_string(config: OpenAIConfig):
|
||||
"""Chat completion API expects reasoning_effort as a string, not a dict.
|
||||
def test_gpt5_preserves_reasoning_effort_dict_with_summary(config: OpenAIConfig):
|
||||
"""Dict with summary/generate_summary is preserved for Responses API.
|
||||
|
||||
Config/deployments may pass Responses API format: {'effort': 'high', 'summary': 'detailed'}.
|
||||
We preserve the full dict so it reaches the Responses API transformation.
|
||||
"""
|
||||
params = config.map_openai_params(
|
||||
non_default_params={"reasoning_effort": {"effort": "high", "summary": "detailed"}},
|
||||
@ -335,18 +336,82 @@ def test_gpt5_normalizes_reasoning_effort_dict_to_string(config: OpenAIConfig):
|
||||
model="gpt-5.4",
|
||||
drop_params=False,
|
||||
)
|
||||
assert params["reasoning_effort"] == "high"
|
||||
assert params["reasoning_effort"] == {"effort": "high", "summary": "detailed"}
|
||||
|
||||
|
||||
def test_gpt5_normalizes_reasoning_effort_dict_from_optional_params(config: OpenAIConfig):
|
||||
"""reasoning_effort dict in optional_params (e.g. from model config) is normalized."""
|
||||
def test_gpt5_xhigh_dict_triggers_validation(config: OpenAIConfig):
    """A dict whose effort is 'xhigh' must still hit xhigh model-support validation.

    Regression: with a dict-valued reasoning_effort the effective effort has to be
    used by the xhigh guard, otherwise validation would be skipped silently.
    """
    call_kwargs = {
        "non_default_params": {
            "reasoning_effort": {"effort": "xhigh", "summary": "detailed"},
        },
        "optional_params": {},
        "model": "gpt-5.1",
        "drop_params": False,
    }

    with pytest.raises(litellm.utils.UnsupportedParamsError):
        config.map_openai_params(**call_kwargs)
|
||||
|
||||
|
||||
def test_gpt5_xhigh_dict_accepted_for_supported_model(config: OpenAIConfig):
    """A dict with effort='xhigh' is passed through untouched for gpt-5.4."""
    reasoning = {"effort": "xhigh", "summary": "detailed"}

    mapped = config.map_openai_params(
        non_default_params={"reasoning_effort": reasoning},
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "xhigh", "summary": "detailed"}
|
||||
|
||||
|
||||
def test_gpt5_none_dict_with_tools_no_tool_drop(config: OpenAIConfig):
    """Dict with effort='none' plus tools: nothing is dropped.

    Regression: the tool-drop guard must look at the effective effort ('none'),
    so {"effort": "none", "summary": "detailed"} is not mistaken for a non-none value.
    """
    tool_defs = [
        {"type": "function", "function": {"name": "test", "description": "test"}},
    ]
    request_params = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "tools": tool_defs,
    }

    mapped = config.map_openai_params(
        non_default_params=request_params,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "none", "summary": "detailed"}
    assert mapped["tools"] == tool_defs
|
||||
|
||||
|
||||
def test_gpt5_none_dict_with_sampling_params_allowed(config: OpenAIConfig):
    """Dict with effort='none' lets sampling params (logprobs, top_p) through.

    Regression: the sampling guard must use the effective effort ('none') so that
    {"effort": "none", "summary": "detailed"} does not raise a sampling error.
    """
    request_params = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "logprobs": True,
        "top_p": 0.9,
    }

    mapped = config.map_openai_params(
        non_default_params=request_params,
        optional_params={},
        model="gpt-5.1",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "none", "summary": "detailed"}
    assert mapped["logprobs"] is True
    assert mapped["top_p"] == 0.9
|
||||
|
||||
|
||||
def test_gpt5_preserves_reasoning_effort_dict_with_summary_from_optional_params(config: OpenAIConfig):
    """A reasoning_effort dict supplied via optional_params is returned unchanged.

    The dict arrives in ``optional_params`` rather than ``non_default_params``;
    the mapped result must still carry the full dict, not just the effort string.
    """
    mapped = config.map_openai_params(
        non_default_params={},
        optional_params={"reasoning_effort": {"effort": "medium", "summary": "detailed"}},
        model="gpt-5.4",
        drop_params=False,
    )

    # Only the "preserved dict" expectation is kept: a check against the bare
    # string "medium" contradicts it, and the two could never both pass.
    assert mapped["reasoning_effort"] == {"effort": "medium", "summary": "detailed"}
|
||||
|
||||
|
||||
def test_gpt5_4_drops_reasoning_effort_when_tools_present(config: OpenAIConfig):
|
||||
|
||||
@ -0,0 +1,243 @@
|
||||
"""
|
||||
Test cases for SageMaker embedding role assumption support
|
||||
|
||||
This module tests that the SageMaker embedding handler properly supports
|
||||
AWS IAM role assumption via aws_role_name and aws_session_name parameters,
|
||||
matching the behavior of the completion handler.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../../.."))
|
||||
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
from litellm.llms.sagemaker.completion.handler import SagemakerLLM
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
|
||||
class TestSagemakerEmbeddingRoleAssumption:
    """Checks that SageMaker embedding handles role assumption like completion does."""

    def setup_method(self):
        self.sagemaker_llm = SagemakerLLM()

    @staticmethod
    def _fake_runtime_client():
        """Build a stand-in sagemaker-runtime client returning one fixed embedding."""
        payload = json.dumps({"embedding": [[0.1, 0.2, 0.3]]}).encode()
        runtime_client = MagicMock()
        runtime_client.invoke_endpoint.return_value = {
            "Body": MagicMock(read=MagicMock(return_value=payload))
        }
        return runtime_client

    def _fake_session(self):
        """Build a stand-in boto3 session whose client() yields the fake runtime client."""
        session = MagicMock()
        session.client.return_value = self._fake_runtime_client()
        return session

    def _call_embedding(self, optional_params):
        """Invoke embedding() with the fixed arguments shared by every test here."""
        return self.sagemaker_llm.embedding(
            model="test-endpoint",
            input=["hello world"],
            model_response=EmbeddingResponse(),
            print_verbose=print,
            encoding=None,
            logging_obj=MagicMock(),
            optional_params=optional_params,
        )

    def test_embedding_uses_load_credentials(self):
        """
        embedding() must go through _load_credentials(), which is the hook that
        lets aws_role_name / aws_session_name take effect.
        """
        # Credentials as they would look after a role has been assumed
        assumed = Credentials(
            access_key="assumed-access-key",
            secret_key="assumed-secret-key",
            token="assumed-session-token",
        )
        fake_session = self._fake_session()

        with patch.object(
            self.sagemaker_llm, "_load_credentials", return_value=(assumed, "us-east-1")
        ) as load_creds_spy, patch("boto3.Session", return_value=fake_session):
            self._call_embedding(
                {
                    "aws_role_name": "arn:aws:iam::123456789012:role/TestRole",
                    "aws_session_name": "test-session",
                }
            )

            # The credential loader must have been consulted exactly once
            load_creds_spy.assert_called_once()

            # Exactly one client was requested from the session, for sagemaker-runtime
            client_calls = fake_session.client.call_args_list
            assert len(client_calls) == 1
            assert client_calls[0] == call(service_name="sagemaker-runtime")

    def test_embedding_role_assumption_with_sts(self):
        """
        Full role-assumption flow for embeddings: when aws_role_name is given,
        STS assume_role must be called with that role and session name.
        """
        sts_client = MagicMock()

        # Expiration stand-in: tz-aware, and subtracting from it yields a
        # timedelta-like object reporting 3600 seconds.
        remaining = MagicMock()
        remaining.total_seconds.return_value = 3600
        expiry = MagicMock()
        expiry.tzinfo = timezone.utc
        expiry.__sub__ = MagicMock(return_value=remaining)

        sts_client.assume_role.return_value = {
            "Credentials": {
                "AccessKeyId": "assumed-access-key",
                "SecretAccessKey": "assumed-secret-key",
                "SessionToken": "assumed-session-token",
                "Expiration": expiry,
            }
        }

        runtime_client = self._fake_runtime_client()
        fake_session = MagicMock()
        fake_session.client.return_value = runtime_client

        def route_client(service_name, **kwargs):
            # STS requests get the STS double; anything else gets the runtime client
            return sts_client if service_name == "sts" else runtime_client

        with patch("boto3.client", side_effect=route_client), patch(
            "boto3.Session", return_value=fake_session
        ):
            self._call_embedding(
                {
                    "aws_role_name": "arn:aws:iam::123456789012:role/CrossAccountRole",
                    "aws_session_name": "litellm-embedding-session",
                    "aws_region_name": "us-east-1",
                }
            )

            # assume_role must have been invoked once with the configured role/session
            sts_client.assume_role.assert_called_once()
            _, sts_kwargs = sts_client.assume_role.call_args
            assert sts_kwargs["RoleArn"] == "arn:aws:iam::123456789012:role/CrossAccountRole"
            assert sts_kwargs["RoleSessionName"] == "litellm-embedding-session"

    def test_embedding_without_role_assumption(self):
        """
        Without aws_role_name the embedding call must still succeed using
        whatever credentials _load_credentials() hands back.
        """
        fake_session = self._fake_session()

        # Credentials as they would come from the environment (no session token)
        env_credentials = Credentials(
            access_key="env-access-key",
            secret_key="env-secret-key",
            token=None,
        )

        with patch.object(
            self.sagemaker_llm, "_load_credentials", return_value=(env_credentials, "us-west-2")
        ), patch("boto3.Session", return_value=fake_session):
            # Only a region is supplied; no role name
            outcome = self._call_embedding({"aws_region_name": "us-west-2"})

            # A response object must come back
            assert outcome is not None

    def test_embedding_session_created_with_assumed_credentials(self):
        """
        boto3.Session must be constructed from the credentials and region that
        _load_credentials() returned.
        """
        assumed = Credentials(
            access_key="assumed-key",
            secret_key="assumed-secret",
            token="assumed-token",
        )

        with patch.object(
            self.sagemaker_llm, "_load_credentials", return_value=(assumed, "us-east-1")
        ), patch("boto3.Session") as session_factory:
            session_factory.return_value = self._fake_session()

            self._call_embedding({})

            # The session constructor received the assumed credentials verbatim
            session_factory.assert_called_once_with(
                aws_access_key_id="assumed-key",
                aws_secret_access_key="assumed-secret",
                aws_session_token="assumed-token",
                region_name="us-east-1",
            )
|
||||
@ -128,6 +128,75 @@ def test_vertex_ai_includes_labels():
|
||||
|
||||
|
||||
|
||||
def test_extra_body_cache_not_forwarded_to_vertex_ai():
    """
    'cache' inside extra_body is a LiteLLM-internal proxy caching control and
    must NOT end up in the Vertex AI request body.

    Regression test for the Vertex AI 400 error
    'Invalid JSON payload received. Unknown name "cache": Cannot find field.'
    """
    extra_body = {
        "cache": {"use-cache": True, "ttl": 86400},  # LiteLLM-internal
        "some_vertex_param": "value",  # legitimate provider extra
    }

    request_body = _transform_request_body(
        messages=[{"role": "user", "content": "test"}],
        model="gemini-2.5-pro",
        optional_params={"extra_body": extra_body},
        custom_llm_provider="vertex_ai",
        litellm_params={},
        cached_content=None,
    )

    # The internal key must be stripped before the body is sent
    assert "cache" not in request_body, (
        "extra_body.cache must not be forwarded to Vertex AI. "
        "Vertex AI rejects it with 400: Unknown name \"cache\": Cannot find field."
    )

    # Provider-facing extras still pass through untouched
    assert "some_vertex_param" in request_body
    assert request_body["some_vertex_param"] == "value"

    # The core request field is still present
    assert "contents" in request_body
|
||||
|
||||
|
||||
def test_extra_body_tags_not_forwarded_to_vertex_ai():
    """
    'tags' inside extra_body is a LiteLLM-internal logging/tracking param and
    must NOT be forwarded to the Vertex AI request body, while other
    extra_body keys still are.
    """
    extra_body = {
        "tags": ["user:alice", "env:prod"],
        "custom_param": "allowed",
    }

    request_body = _transform_request_body(
        messages=[{"role": "user", "content": "test"}],
        model="gemini-2.5-pro",
        optional_params={"extra_body": extra_body},
        custom_llm_provider="vertex_ai",
        litellm_params={},
        cached_content=None,
    )

    assert "tags" not in request_body
    assert "custom_param" in request_body
    assert request_body["custom_param"] == "allowed"
|
||||
|
||||
|
||||
def test_metadata_to_labels_vertex_only():
|
||||
"""Test that metadata->labels conversion only happens for Vertex AI"""
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
|
||||
@ -2093,3 +2093,150 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab
|
||||
assert spend_meta["tool_count_total"] == 1
|
||||
assert spend_meta["allowed_server_count"] == 1
|
||||
assert spend_meta["per_server_tool_counts"]["server_a"] == 1
|
||||
|
||||
|
||||
def test_tool_name_matches_case_insensitive():
    """_tool_name_matches must compare tool names without regard to case.

    This matters for OpenAPI-based MCP servers, where operationIds are often
    camelCase ('addPet'), tool names are lowercased at registration ('addpet'),
    and allowed_tools may still use the original camelCase spelling. A
    case-sensitive comparison would filter every tool out.
    """
    try:
        from litellm.proxy._experimental.mcp_server.server import _tool_name_matches
    except ImportError:
        pytest.skip("MCP server not available")

    # (tool name, allowed list, expected result)
    scenarios = [
        # 1: unprefixed tool name, camelCase filter list
        ("addpet", ["addPet", "updatePet"], True),
        ("updatepet", ["addPet", "updatePet"], True),
        ("deletepet", ["addPet", "updatePet"], False),
        # 2: prefixed tool name, camelCase filter list
        ("per_store-addpet", ["addPet", "updatePet"], True),
        ("per_store-updatepet", ["addPet", "updatePet"], True),
        ("per_store-deletepet", ["addPet", "updatePet"], False),
        # 3: mixed case variations
        ("findPetsByStatus", ["findpetsbystatus"], True),
        ("findpetsbystatus", ["findPetsByStatus"], True),
        ("FINDPETSBYSTATUS", ["findPetsByStatus"], True),
        # 4: full prefixed name in the filter list
        ("server-addPet", ["server-addpet"], True),
        ("server-addpet", ["server-addPet"], True),
        # 5: names that must still not match
        ("addpet", ["deletePet", "updatePet"], False),
        ("server-addpet", ["deletePet", "updatePet"], False),
    ]

    for tool_name, allowed_names, should_match in scenarios:
        assert _tool_name_matches(tool_name, allowed_names) is should_match
|
||||
|
||||
|
||||
def test_filter_tools_by_allowed_tools_case_insensitive():
    """filter_tools_by_allowed_tools must match allowed_tools case-insensitively.

    Tools registered with lowercase names must be selectable through a
    camelCase allowed_tools list taken from an OpenAPI spec.
    """
    try:
        from litellm.proxy._experimental.mcp_server.server import (
            filter_tools_by_allowed_tools,
        )
        from litellm.types.mcp_server.tool_registry import MCPTool
    except ImportError:
        pytest.skip("MCP server not available")

    def mock_handler(**kwargs):
        # Echo handler; only needed to satisfy the MCPTool constructor
        return kwargs

    # Lowercase names, as they look after registration from OpenAPI
    tool_specs = (
        ("per_store-addpet", "Add a pet"),
        ("per_store-updatepet", "Update a pet"),
        ("per_store-deletepet", "Delete a pet"),
        ("per_store-findpetsbystatus", "Find pets by status"),
    )
    tools = [
        MCPTool(
            name=tool_name,
            description=tool_description,
            input_schema={"type": "object"},
            handler=mock_handler,
        )
        for tool_name, tool_description in tool_specs
    ]

    # Server configured with camelCase allowed_tools
    server = MCPServer(
        server_id="test-server",
        name="per_store",
        transport=MCPTransport.http,
        allowed_tools=["addPet", "updatePet", "findPetsByStatus"],
    )

    filtered_tools = filter_tools_by_allowed_tools(tools, server)

    # Three of the four tools are allowed; deletepet is not
    assert len(filtered_tools) == 3
    kept_names = [tool.name for tool in filtered_tools]
    assert "per_store-addpet" in kept_names
    assert "per_store-updatepet" in kept_names
    assert "per_store-findpetsbystatus" in kept_names
    assert "per_store-deletepet" not in kept_names
|
||||
|
||||
|
||||
def test_filter_tools_by_allowed_tools_no_filter():
    """With allowed_tools=None, filter_tools_by_allowed_tools returns every tool."""
    try:
        from litellm.proxy._experimental.mcp_server.server import (
            filter_tools_by_allowed_tools,
        )
        from litellm.types.mcp_server.tool_registry import MCPTool
    except ImportError:
        pytest.skip("MCP server not available")

    def mock_handler(**kwargs):
        # Echo handler; only needed to satisfy the MCPTool constructor
        return kwargs

    tool_specs = (
        ("fusion_litellm_mcp-model_list", "List models"),
        ("fusion_litellm_mcp-chat_completion", "Chat completion"),
    )
    tools = [
        MCPTool(
            name=tool_name,
            description=tool_description,
            input_schema={"type": "object"},
            handler=mock_handler,
        )
        for tool_name, tool_description in tool_specs
    ]

    # No allowed_tools filter configured on the server
    server = MCPServer(
        server_id="test-server",
        name="fusion_litellm_mcp",
        transport=MCPTransport.http,
        allowed_tools=None,
    )

    filtered_tools = filter_tools_by_allowed_tools(tools, server)

    # Nothing is filtered out
    assert len(filtered_tools) == 2
|
||||
|
||||
@ -21,6 +21,140 @@ def test_get_team_models_for_all_models_and_team_only_models():
|
||||
assert set(result) == set(combined_models)
|
||||
|
||||
|
||||
def test_get_team_models_all_proxy_models_includes_access_groups():
    """
    A team with 'all-proxy-models' and include_model_access_groups=True gets
    the access group names in the result alongside the individual model names.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    resolved = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "group-b", "model1", "model2"):
        assert expected_name in resolved
    assert len(resolved) == len(set(resolved)), "result should have no duplicates"
|
||||
|
||||
|
||||
def test_get_team_models_all_proxy_models_without_include_flag():
    """
    With include_model_access_groups=False the access group names must stay
    out of the result, even for a team with 'all-proxy-models'.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    resolved = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=False,
    )

    for group_name in ("group-a", "group-b"):
        assert group_name not in resolved
    for model_name in ("model1", "model2"):
        assert model_name in resolved
|
||||
|
||||
|
||||
def test_get_key_models_all_proxy_models_includes_access_groups():
    """
    A key with 'all-proxy-models' and include_model_access_groups=True gets
    the access group names included in the result.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    key_auth = UserAPIKeyAuth(
        models=["all-proxy-models"],
        api_key="test-key",
    )

    resolved = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1"]},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in resolved
    assert len(resolved) == len(set(resolved)), "result should have no duplicates"
|
||||
|
||||
|
||||
def test_get_key_models_passes_include_model_access_groups():
    """
    A key that lists an access group name directly keeps that group name in the
    result when include_model_access_groups=True, next to the group's models.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    key_auth = UserAPIKeyAuth(
        models=["group-a"],
        api_key="test-key",
    )

    resolved = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1", "model2"]},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in resolved
|
||||
|
||||
|
||||
def test_get_key_models_does_not_mutate_input():
    """
    get_key_models must leave user_api_key_dict.models as it found it.

    An in-place change to that list would corrupt cached UserAPIKeyAuth
    objects, so the call has to work on a copy rather than an alias.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    models_before = ["group-a", "extra-model"]
    key_auth = UserAPIKeyAuth(
        models=list(models_before),  # separate list so the baseline stays pristine
        api_key="test-key",
    )

    get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1", "model2"]},
        include_model_access_groups=False,
    )

    # The list held by the auth object must match the baseline exactly
    assert key_auth.models == models_before
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key_models,team_models,proxy_model_list,model_list,expected",
|
||||
[
|
||||
|
||||
@ -2410,12 +2410,13 @@ def test_mapped_pass_through_routes_with_server_root_path():
|
||||
)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_passthrough_preserves_boundary():
|
||||
"""
|
||||
Test that multipart/form-data requests through passthrough preserve the boundary
|
||||
and can be correctly parsed by the upstream server.
|
||||
|
||||
|
||||
Regression test for multipart boundary stripping issue.
|
||||
"""
|
||||
from io import BytesIO
|
||||
@ -2426,41 +2427,41 @@ async def test_multipart_passthrough_preserves_boundary():
|
||||
mock_response.headers = httpx.Headers({"content-type": "application/json"})
|
||||
mock_response.aread = AsyncMock(return_value=b'{"filename": "test.txt", "size": 17}')
|
||||
mock_response.text = '{"filename": "test.txt", "size": 17}'
|
||||
|
||||
|
||||
async def mock_httpx_request(method, url, **kwargs):
|
||||
# Verify that files parameter is passed (not json)
|
||||
assert "files" in kwargs, "Files should be passed for multipart requests"
|
||||
assert "file" in kwargs["files"], "File field should be in files dict"
|
||||
|
||||
|
||||
# Verify content-type is NOT in headers (httpx will set it with correct boundary)
|
||||
headers = kwargs.get("headers", {})
|
||||
assert "content-type" not in headers, "content-type should be removed for multipart"
|
||||
|
||||
|
||||
filename, content, content_type = kwargs["files"]["file"]
|
||||
assert filename == "test.txt"
|
||||
assert content == b"test file content"
|
||||
assert content_type == "text/plain"
|
||||
|
||||
|
||||
return mock_response
|
||||
|
||||
|
||||
async_client = MagicMock()
|
||||
async_client.request = AsyncMock(side_effect=mock_httpx_request)
|
||||
|
||||
|
||||
# Create mock request
|
||||
request = MagicMock(spec=Request)
|
||||
request.method = "POST"
|
||||
request.headers = Headers({"content-type": "multipart/form-data; boundary=test123"})
|
||||
|
||||
|
||||
# Mock form data
|
||||
file_content = b"test file content"
|
||||
file = BytesIO(file_content)
|
||||
headers = Headers({"content-type": "text/plain"})
|
||||
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
|
||||
upload_file.read = AsyncMock(return_value=file_content)
|
||||
|
||||
|
||||
form_data = {"file": upload_file}
|
||||
request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
|
||||
# Test the multipart handler directly
|
||||
response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
|
||||
request=request,
|
||||
@ -2469,7 +2470,7 @@ async def test_multipart_passthrough_preserves_boundary():
|
||||
headers={},
|
||||
requested_query_params=None,
|
||||
)
|
||||
|
||||
|
||||
# Verify the response
|
||||
assert response.status_code == 200
|
||||
async_client.request.assert_called_once()
|
||||
|
||||
@ -1071,9 +1071,10 @@ def test_spend_logs_redacts_request_and_response_when_turn_off_message_logging_e
|
||||
response_result = _get_response_for_spend_logs_payload(payload=payload, kwargs=kwargs)
|
||||
|
||||
# When redaction is enabled and response is a dict (not ModelResponse),
|
||||
# perform_redaction returns {"text": "redacted-by-litellm"}
|
||||
# perform_redaction redacts content in-place within the choices structure
|
||||
parsed_response = json.loads(response_result)
|
||||
assert parsed_response == {"text": "redacted-by-litellm"}
|
||||
assert parsed_response["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
assert parsed_response["choices"][0]["message"]["role"] == "assistant"
|
||||
|
||||
|
||||
@patch("litellm.secret_managers.main.get_secret_bool")
|
||||
|
||||
142
tests/test_litellm/proxy/test_openapi_schema_validation.py
Normal file
142
tests/test_litellm/proxy/test_openapi_schema_validation.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Test that the OpenAPI schema generated by FastAPI is valid for specific endpoints.
|
||||
|
||||
Validates fixes for:
|
||||
- /spend/calculate response schema (must use proper OpenAPI 3.x content wrapper)
|
||||
- /credentials/by_model/{model_id} path parameter (must not leak credential_name)
|
||||
|
||||
Related issue: https://github.com/BerriAI/litellm/issues/21305
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSpendCalculateOpenAPISchema:
    """Test /spend/calculate response schema is valid OpenAPI 3.x."""

    @staticmethod
    def _spend_calculate_200_response():
        """
        Return the 200-response entry declared on the /spend/calculate route.

        Looks the route up on the spend-management router and returns the dict
        stored under key ``200`` of its ``responses`` mapping (an empty dict
        when nothing is declared). Fails the calling test when the route is
        not registered on the router.
        """
        from litellm.proxy.spend_tracking.spend_management_endpoints import router

        for route in router.routes:
            # getattr with a default mirrors the hasattr guard the tests used
            # before: entries without a `path` attribute are simply skipped.
            if getattr(route, "path", None) == "/spend/calculate":
                return (route.responses or {}).get(200, {})
        pytest.fail("/spend/calculate route not found in router")

    def test_response_schema_has_description(self):
        """The 200 response must have a 'description' field per OpenAPI 3.x spec."""
        response_200 = self._spend_calculate_200_response()
        assert "description" in response_200, (
            "/spend/calculate 200 response must have a 'description' field"
        )

    def test_response_schema_has_content_wrapper(self):
        """The 200 response must use 'content' wrapper, not bare properties."""
        response_200 = self._spend_calculate_200_response()

        # Must NOT have 'cost' as a top-level key (invalid OpenAPI)
        assert "cost" not in response_200, (
            "/spend/calculate 200 response must not have 'cost' as a "
            "top-level property - use 'content' wrapper instead"
        )
        # Must have 'content' wrapper
        assert "content" in response_200, (
            "/spend/calculate 200 response must have a 'content' field"
        )
        content = response_200["content"]
        assert "application/json" in content
        assert "schema" in content["application/json"]
|
||||
|
||||
|
||||
class TestCredentialEndpointsOpenAPISchema:
    """Check that each /credentials handler declares only its own path parameter."""

    def test_by_name_and_by_model_are_separate_handlers(self):
        """
        The by_name and by_model routes must each be registered exactly once
        and must be served by two distinct endpoint functions, so that neither
        handler's signature leaks the other's path parameter.
        """
        from litellm.proxy.credential_endpoints.endpoints import router

        routes_with_path = [r for r in router.routes if hasattr(r, "path")]
        by_name_routes = [r for r in routes_with_path if "by_name" in r.path]
        # A path matching "by_name" is never also counted as "by_model".
        by_model_routes = [
            r
            for r in routes_with_path
            if "by_name" not in r.path and "by_model" in r.path
        ]

        assert len(by_name_routes) == 1, "Expected exactly one by_name route"
        assert len(by_model_routes) == 1, "Expected exactly one by_model route"

        # Identity comparison: the two routes must not share a handler object.
        (name_route,) = by_name_routes
        (model_route,) = by_model_routes
        assert name_route.endpoint is not model_route.endpoint, (
            "by_name and by_model must be separate handler functions "
            "to avoid path parameter conflicts in OpenAPI spec"
        )

    def test_by_model_route_does_not_require_credential_name(self):
        """get_credential_by_model's signature must not contain credential_name."""
        import inspect

        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_model,
        )

        declared = inspect.signature(get_credential_by_model).parameters
        assert "credential_name" not in declared, (
            "get_credential_by_model must not have a credential_name parameter"
        )

    def test_by_name_route_does_not_require_model_id(self):
        """get_credential_by_name's signature must not contain model_id."""
        import inspect

        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_name,
        )

        declared = inspect.signature(get_credential_by_name).parameters
        assert "model_id" not in declared, (
            "get_credential_by_name must not have a model_id parameter"
        )

    def test_by_model_has_model_id_path_param(self):
        """get_credential_by_model's signature must contain model_id."""
        import inspect

        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_model,
        )

        declared = inspect.signature(get_credential_by_model).parameters
        assert "model_id" in declared, (
            "get_credential_by_model must have a model_id parameter"
        )

    def test_by_name_has_credential_name_path_param(self):
        """get_credential_by_name's signature must contain credential_name."""
        import inspect

        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_name,
        )

        declared = inspect.signature(get_credential_by_name).parameters
        assert "credential_name" in declared, (
            "get_credential_by_name must have a credential_name parameter"
        )
|
||||
@ -1774,3 +1774,128 @@ class TestStreamingIDConsistency:
|
||||
# Verify it matches the cached ID
|
||||
assert iterator._cached_item_id is not None
|
||||
assert iterator._cached_item_id == text_done_id
|
||||
|
||||
def test_parallel_tool_calls_merged_into_single_assistant_message(self):
    """
    Regression test for parallel tool calls replayed through the Responses API.

    The input holds a user message, two consecutive function_call items
    (get_weather for SF and NYC) and their two function_call_output items.
    After transformation to chat-completion messages there must be:

    * no two assistant messages in a row,
    * exactly one assistant message carrying both tool calls
      (ids toolu_01 and toolu_02),
    * exactly two tool messages.

    Per the original description, emitting one assistant message per
    function_call yields back-to-back assistant messages that
    Anthropic/Vertex AI rejects with
    "tool_use ids were found without tool_result blocks immediately after".
    """

    def _field(obj, name):
        # Messages / tool calls may come back as dicts or as objects.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    user_turn = {"type": "message", "role": "user", "content": "Weather in SF and NYC?"}
    # Two parallel tool calls from the previous assistant response
    parallel_calls = [
        {
            "type": "function_call",
            "call_id": "toolu_01",
            "name": "get_weather",
            "arguments": '{"city": "SF"}',
        },
        {
            "type": "function_call",
            "call_id": "toolu_02",
            "name": "get_weather",
            "arguments": '{"city": "NYC"}',
        },
    ]
    # Tool results
    call_outputs = [
        {"type": "function_call_output", "call_id": "toolu_01", "output": "72°F"},
        {"type": "function_call_output", "call_id": "toolu_02", "output": "55°F"},
    ]
    input_items = [user_turn, *parallel_calls, *call_outputs]

    messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
        input=input_items
    )
    roles = [_field(m, "role") for m in messages]

    # Walk adjacent role pairs: an assistant must never directly follow an assistant.
    for i, (current, following) in enumerate(zip(roles, roles[1:])):
        assert (
            current != "assistant" or following != "assistant"
        ), f"Consecutive assistant messages at indices {i} and {i+1}: {roles}"

    # The single assistant message must contain BOTH tool_calls
    assistant_messages = [m for m in messages if _field(m, "role") == "assistant"]
    assert len(assistant_messages) == 1, (
        f"Expected 1 assistant message, got {len(assistant_messages)}"
    )

    tool_calls = _field(assistant_messages[0], "tool_calls")
    assert tool_calls is not None and len(tool_calls) == 2, (
        f"Expected 2 tool_calls in the merged assistant message, got: {tool_calls}"
    )

    call_ids = [_field(tc, "id") for tc in tool_calls]
    for wanted in ("toolu_01", "toolu_02"):
        assert wanted in call_ids, f"{wanted} missing from tool_calls: {call_ids}"

    # Both tool messages must be present
    tool_messages = [m for m in messages if _field(m, "role") == "tool"]
    assert len(tool_messages) == 2, (
        f"Expected 2 tool messages, got {len(tool_messages)}"
    )
|
||||
|
||||
def test_single_tool_call_still_works_after_merge_fix(self):
    """
    Sanity check for the non-parallel path: one function_call followed by its
    function_call_output must still produce user, assistant and tool roles,
    with exactly one assistant message that carries exactly one tool call.
    """

    def _field(obj, name):
        # Transformed messages may be dicts or attribute-style objects.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    single_call = {
        "type": "function_call",
        "call_id": "toolu_01",
        "name": "get_weather",
        "arguments": '{"city": "SF"}',
    }
    input_items = [
        {"type": "message", "role": "user", "content": "Weather in SF?"},
        single_call,
        {"type": "function_call_output", "call_id": "toolu_01", "output": "72°F"},
    ]

    messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
        input=input_items
    )
    roles = [_field(m, "role") for m in messages]

    # Every participant of the exchange must be represented.
    for expected_role in ("user", "assistant", "tool"):
        assert expected_role in roles

    assistant_messages = [m for m in messages if _field(m, "role") == "assistant"]
    assert len(assistant_messages) == 1

    tool_calls = _field(assistant_messages[0], "tool_calls")
    assert tool_calls is not None and len(tool_calls) == 1
|
||||
|
||||
251
tests/test_litellm/test_router_retry_non_retryable_errors.py
Normal file
251
tests/test_litellm/test_router_retry_non_retryable_errors.py
Normal file
@ -0,0 +1,251 @@
|
||||
"""
|
||||
Test that the Router retry loop correctly handles non-retryable errors.
|
||||
|
||||
Verifies that:
|
||||
1. Non-retryable errors (e.g., 400 ContextWindowExceeded) inside the retry loop
|
||||
break out immediately instead of being swallowed.
|
||||
2. original_exception is updated to the latest error, not stuck on the first.
|
||||
3. Retryable errors (e.g., 429 RateLimitError) still retry normally.
|
||||
|
||||
Regression tests for https://github.com/BerriAI/litellm/issues/21343
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
|
||||
def _make_rate_limit_error(message="Rate limited"):
    """Build a litellm.RateLimitError (provider "bedrock", model "anthropic.claude-v2") carrying *message*."""
    error_fields = {
        "message": message,
        "llm_provider": "bedrock",
        "model": "anthropic.claude-v2",
    }
    return litellm.RateLimitError(**error_fields)
|
||||
|
||||
|
||||
def _make_context_window_error(message="prompt is too long: 1205821 tokens > 200000"):
    """Build a litellm.ContextWindowExceededError (provider "vertex_ai", model "claude-3-opus") carrying *message*."""
    error_fields = {
        "message": message,
        "llm_provider": "vertex_ai",
        "model": "claude-3-opus",
    }
    return litellm.ContextWindowExceededError(**error_fields)
|
||||
|
||||
|
||||
def _make_bad_request_error(message="Invalid request"):
    """Build a litellm.BadRequestError (provider "openai", model "gpt-4") carrying *message*."""
    error_fields = {
        "message": message,
        "llm_provider": "openai",
        "model": "gpt-4",
    }
    return litellm.BadRequestError(**error_fields)
|
||||
|
||||
|
||||
def _make_not_found_error(message="Model not found"):
    """Build a litellm.NotFoundError (provider "openai", model "gpt-99") carrying *message*."""
    error_fields = {
        "message": message,
        "llm_provider": "openai",
        "model": "gpt-99",
    }
    return litellm.NotFoundError(**error_fields)
|
||||
|
||||
|
||||
def _create_router(num_retries=2):
    """
    Build a Router whose "test-model" group has two deployments.

    Both deployments point at "openai/gpt-4" and differ only in their fake
    API key. *num_retries* is forwarded to the Router unchanged.
    """
    deployments = [
        {
            "model_name": "test-model",
            "litellm_params": {
                "model": "openai/gpt-4",
                "api_key": fake_key,
            },
        }
        for fake_key in ("fake-key-1", "fake-key-2")
    ]
    return Router(model_list=deployments, num_retries=num_retries)
|
||||
|
||||
|
||||
def _base_kwargs():
    """
    Produce a fresh kwargs dict for calling async_function_with_retries.

    Every call returns a new dict with a new AsyncMock under
    "original_function", so tests never share mutable state.
    """
    user_turn = dict(role="user", content="test")
    return dict(
        model="test-model",
        messages=[user_turn],
        original_function=AsyncMock(),
        metadata={},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_retryable_error_in_retry_loop_raises_immediately():
    """
    When a non-retryable error (400 ContextWindowExceeded) occurs inside the
    retry loop, the router should raise it immediately instead of swallowing it
    and raising the original error.

    Scenario: First call -> 429, Retry -> 400 (non-retryable)
    Expected: ContextWindowExceededError is raised, NOT RateLimitError, and
    no further attempts are made after the 400.
    """
    router = _create_router(num_retries=2)

    rate_limit_error = _make_rate_limit_error()
    context_window_error = _make_context_window_error()

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        # First attempt is rate limited; every later attempt hits the 400.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise rate_limit_error
        raise context_window_error

    # Stub out deployment lookup, back-off sleep and retry logging so the test
    # exercises only the retry-loop control flow.
    with patch.object(router, "make_call", side_effect=mock_make_call), \
            patch.object(router, "_async_get_healthy_deployments",
                         return_value=(["d1", "d2"], ["d1", "d2"])), \
            patch.object(router, "_time_to_sleep_before_retry", return_value=0), \
            patch.object(router, "log_retry", side_effect=lambda kwargs, e: kwargs):
        with pytest.raises(litellm.ContextWindowExceededError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # Only 2 calls: initial + first retry that hits non-retryable.
    # Without this check the test would also pass if the router kept retrying
    # and merely surfaced the last error, which is not "immediately".
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bad_request_error_in_retry_loop_raises_immediately():
    """
    A generic 400 BadRequestError inside the retry loop should also break out
    immediately since 400 is not retryable.

    Scenario: First call -> 429, Retry -> 400 BadRequestError
    Expected: BadRequestError is raised and no further attempts are made.
    """
    router = _create_router(num_retries=2)

    rate_limit_error = _make_rate_limit_error()
    bad_request_error = _make_bad_request_error()

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        # First attempt is rate limited; every later attempt hits the 400.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise rate_limit_error
        raise bad_request_error

    # Stub out deployment lookup, back-off sleep and retry logging so the test
    # exercises only the retry-loop control flow.
    with patch.object(router, "make_call", side_effect=mock_make_call), \
            patch.object(router, "_async_get_healthy_deployments",
                         return_value=(["d1", "d2"], ["d1", "d2"])), \
            patch.object(router, "_time_to_sleep_before_retry", return_value=0), \
            patch.object(router, "log_retry", side_effect=lambda kwargs, e: kwargs):
        with pytest.raises(litellm.BadRequestError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # Only 2 calls: initial + first retry that hits non-retryable.
    # Without this check the test would also pass if the router kept retrying
    # and merely surfaced the last error, which is not "immediately".
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_original_exception_updated_to_latest_error():
    """
    With only retryable failures, the exception that finally escapes must be
    the one from the last attempt rather than the one from the first.

    Each attempt raises a RateLimitError whose message embeds the attempt
    number; with num_retries=2 the surfaced message must mention attempt 3.
    """
    router = _create_router(num_retries=2)
    call_count = 0

    async def always_rate_limited(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        raise _make_rate_limit_error(f"Rate limit attempt {call_count}")

    # Build the patchers up front; they only take effect inside the `with`.
    stub_make_call = patch.object(router, "make_call", side_effect=always_rate_limited)
    stub_healthy = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    stub_sleep = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    stub_log = patch.object(router, "log_retry", side_effect=lambda kwargs, e: kwargs)

    with stub_make_call, stub_healthy, stub_sleep, stub_log:
        with pytest.raises(litellm.RateLimitError) as exc_info:
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # Should be the LAST error, not the first
    assert "Rate limit attempt 3" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retryable_errors_still_retry_normally():
    """
    A 429 RateLimitError on every attempt must consume the full retry budget:
    with num_retries=3 the call is made four times (1 initial + 3 retries)
    before RateLimitError is raised.
    """
    router = _create_router(num_retries=3)
    call_count = 0

    async def always_rate_limited(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        raise _make_rate_limit_error(f"Rate limit attempt {call_count}")

    # Build the patchers up front; they only take effect inside the `with`.
    stub_make_call = patch.object(router, "make_call", side_effect=always_rate_limited)
    stub_healthy = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    stub_sleep = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    stub_log = patch.object(router, "log_retry", side_effect=lambda kwargs, e: kwargs)

    with stub_make_call, stub_healthy, stub_sleep, stub_log:
        with pytest.raises(litellm.RateLimitError):
            await router.async_function_with_retries(
                num_retries=3,
                **_base_kwargs(),
            )

    # Initial call + 3 retries = 4 total calls
    assert call_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_not_found_error_in_retry_loop_raises_immediately():
    """
    A 404 NotFoundError raised by a retry attempt must end the retry loop at
    once: after an initial 429 and one retry that returns 404, NotFoundError
    is raised and exactly two calls have been made.
    """
    router = _create_router(num_retries=2)

    rate_limit_error = _make_rate_limit_error()
    not_found_error = _make_not_found_error()
    call_count = 0

    async def rate_limit_then_not_found(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        # Only the very first attempt is rate limited.
        if call_count == 1:
            raise rate_limit_error
        raise not_found_error

    # Build the patchers up front; they only take effect inside the `with`.
    stub_make_call = patch.object(
        router, "make_call", side_effect=rate_limit_then_not_found
    )
    stub_healthy = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    stub_sleep = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    stub_log = patch.object(router, "log_retry", side_effect=lambda kwargs, e: kwargs)

    with stub_make_call, stub_healthy, stub_sleep, stub_log:
        with pytest.raises(litellm.NotFoundError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # Only 2 calls: initial + first retry that hits non-retryable
    assert call_count == 2
|
||||
@ -275,8 +275,8 @@ it("should display user email correctly", async () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should show skeleton loaders when isLoading is true", () => {
|
||||
// Mock loading state
|
||||
it("should show loading message only on initial load (isPending)", () => {
|
||||
// Mock initial loading state
|
||||
mockUseKeys.mockReturnValue({
|
||||
data: null,
|
||||
isPending: true,
|
||||
@ -296,7 +296,7 @@ it("should show skeleton loaders when isLoading is true", () => {
|
||||
|
||||
renderWithProviders(<VirtualKeysTable {...mockProps} />);
|
||||
|
||||
// Check that loading message is shown
|
||||
// Check that loading message is shown on initial load
|
||||
expect(screen.getByText("🚅 Loading keys...")).toBeInTheDocument();
|
||||
|
||||
// Check that actual key data is not shown
|
||||
@ -810,3 +810,79 @@ describe("pagination display – total count and page count", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("refetch button", () => {
  // Single-key payload used by the tests that need useKeys to report loaded data.
  const singleKeyResponse = () =>
    ({
      keys: [mockKey],
      total_count: 1,
      current_page: 1,
      total_pages: 1,
    }) as KeysResponse;

  // Renders the table under test with the default props.
  const renderTable = () => renderWithProviders(<VirtualKeysTable {...defaultMockProps} />);

  it("should show Fetch button in normal state", () => {
    renderTable();

    const button = screen.getByTitle("Fetch data");
    expect(button).toBeInTheDocument();
    expect(button).not.toBeDisabled();
    expect(screen.getByText("Fetch")).toBeInTheDocument();
  });

  it("should show Fetching state and keep table data visible during refetch", () => {
    // Data already loaded (isPending false) while a background refetch is running.
    const refetchingState = {
      data: singleKeyResponse(),
      isPending: false,
      isFetching: true,
      refetch: vi.fn(),
    };
    mockUseKeys.mockReturnValue(refetchingState as any);

    renderTable();

    // Button should show "Fetching" and be disabled
    expect(screen.getByText("Fetching")).toBeInTheDocument();
    expect(screen.getByTitle("Fetch data")).toBeDisabled();

    // Table data should still be visible (stale data)
    expect(screen.getByText("Test Key Alias")).toBeInTheDocument();

    // "Loading keys..." should NOT appear during refetch
    expect(screen.queryByText("🚅 Loading keys...")).not.toBeInTheDocument();
  });

  it("should call refetch when Fetch button is clicked", () => {
    const mockRefetch = vi.fn();
    const idleState = {
      data: singleKeyResponse(),
      isPending: false,
      isFetching: false,
      refetch: mockRefetch,
    };
    mockUseKeys.mockReturnValue(idleState as any);

    renderTable();
    fireEvent.click(screen.getByTitle("Fetch data"));

    expect(mockRefetch).toHaveBeenCalledTimes(1);
  });

  it("should show Fetch button enabled on error so user can retry", () => {
    // No data, nothing in flight, and the query reports an error.
    const erroredState = {
      data: null,
      isPending: false,
      isFetching: false,
      isError: true,
      refetch: vi.fn(),
    };
    mockUseKeys.mockReturnValue(erroredState as any);

    renderTable();

    expect(screen.getByTitle("Fetch data")).not.toBeDisabled();
    expect(screen.getByText("Fetch")).toBeInTheDocument();
  });
});
|
||||
|
||||
@ -85,6 +85,7 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
|
||||
data: keys,
|
||||
isPending: isLoading,
|
||||
isFetching,
|
||||
isError,
|
||||
refetch,
|
||||
} = useKeys(tablePagination.pageIndex + 1, tablePagination.pageSize, {
|
||||
sortBy: sortBy || undefined,
|
||||
@ -102,6 +103,15 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
|
||||
organizations,
|
||||
});
|
||||
|
||||
// Defer the transition so the button stays in loading state until the table
|
||||
// has rendered with the new data (mirrors the spend-logs pattern)
|
||||
const isFetchingDeferred = useDeferredValue(isFetching);
|
||||
const isButtonLoading = (isFetching || isFetchingDeferred) && !isError;
|
||||
|
||||
const handleRefresh = () => {
|
||||
refetch();
|
||||
};
|
||||
|
||||
const totalCount = filteredTotalCount ?? keys?.total_count ?? 0;
|
||||
|
||||
// Add a useEffect to call refresh when a key is created
|
||||
@ -669,16 +679,28 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between w-full mb-4">
|
||||
{isLoading || isFetching ? (
|
||||
<Skeleton.Node active style={{ width: 200, height: 20 }} />
|
||||
) : (
|
||||
<span className="inline-flex text-sm text-gray-700">
|
||||
Showing {rangeLabel} of {totalCount} results
|
||||
</span>
|
||||
)}
|
||||
<div className="inline-flex items-center gap-2">
|
||||
{isLoading ? (
|
||||
<Skeleton.Node active style={{ width: 200, height: 20 }} />
|
||||
) : (
|
||||
<span className="inline-flex text-sm text-gray-700">
|
||||
Showing {rangeLabel} of {totalCount} results
|
||||
</span>
|
||||
)}
|
||||
|
||||
<AntButton
|
||||
type="default"
|
||||
icon={<SyncOutlined spin={isButtonLoading} />}
|
||||
onClick={handleRefresh}
|
||||
disabled={isButtonLoading}
|
||||
title="Fetch data"
|
||||
>
|
||||
{isButtonLoading ? "Fetching" : "Fetch"}
|
||||
</AntButton>
|
||||
</div>
|
||||
|
||||
<div className="inline-flex items-center gap-2">
|
||||
{isLoading || isFetching ? (
|
||||
{isLoading ? (
|
||||
<Skeleton.Node active style={{ width: 74, height: 20 }} />
|
||||
) : (
|
||||
<span className="text-sm text-gray-700">
|
||||
@ -686,24 +708,24 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
|
||||
</span>
|
||||
)}
|
||||
|
||||
{isLoading || isFetching ? (
|
||||
{isLoading ? (
|
||||
<Skeleton.Button active size="small" style={{ width: 84, height: 30 }} />
|
||||
) : (
|
||||
<button
|
||||
onClick={() => table.previousPage()}
|
||||
disabled={isLoading || isFetching || !table.getCanPreviousPage()}
|
||||
disabled={isLoading || !table.getCanPreviousPage()}
|
||||
className="px-3 py-1 text-sm border rounded-md hover:bg-gray-50 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
)}
|
||||
|
||||
{isLoading || isFetching ? (
|
||||
{isLoading ? (
|
||||
<Skeleton.Button active size="small" style={{ width: 58, height: 30 }} />
|
||||
) : (
|
||||
<button
|
||||
onClick={() => table.nextPage()}
|
||||
disabled={isLoading || isFetching || !table.getCanNextPage()}
|
||||
disabled={isLoading || !table.getCanNextPage()}
|
||||
className="px-3 py-1 text-sm border rounded-md hover:bg-gray-50 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Next
|
||||
@ -788,7 +810,7 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
|
||||
))}
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{isLoading || isFetching ? (
|
||||
{isLoading ? (
|
||||
<TableRow>
|
||||
<TableCell colSpan={columns.length} className="h-8 text-center">
|
||||
<div className="text-center text-gray-500">
|
||||
|
||||
Loading…
Reference in New Issue
Block a user