Reapply "feat: add model_cost aliases expansion support"

This reverts commit 3d2df7e8b5.
This commit is contained in:
Chesars 2026-03-12 13:36:57 -03:00
parent fa68d69bcf
commit feed274aa3
39 changed files with 2315 additions and 195 deletions

View File

@ -348,6 +348,8 @@ class DualCache(BaseCache):
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
@ -369,6 +371,8 @@ class DualCache(BaseCache):
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)

View File

@ -398,6 +398,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
ResponseOutputMessage,
ResponseReasoningItem,
)
from openai.types.responses.response_output_item import ResponseApplyPatchToolCall
from litellm.types.utils import Choices, Message
@ -456,6 +457,18 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
accumulated_tool_calls.append(tool_call_dict)
tool_call_index += 1
elif isinstance(item, ResponseApplyPatchToolCall):
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
tool_call_dict = LiteLLMCompletionResponsesConfig.convert_apply_patch_tool_call_to_chat_completion_tool_call(
tool_call_item=item,
index=tool_call_index,
)
accumulated_tool_calls.append(tool_call_dict)
tool_call_index += 1
elif isinstance(item, dict) and handle_raw_dict_callback is not None:
# Handle raw dict responses (e.g., from GPT-5 Codex)
choice, index = handle_raw_dict_callback(item=item, index=index)

View File

@ -64,12 +64,10 @@ def duration_in_seconds(duration: str) -> int:
now = time.time()
current_time = datetime.fromtimestamp(now)
if current_time.month == 12:
target_year = current_time.year + 1
target_month = 1
else:
target_year = current_time.year
target_month = current_time.month + value
# Calculate target month and year, handling overflow past December
total_months = current_time.month - 1 + value # 0-indexed months
target_year = current_time.year + total_months // 12
target_month = total_months % 12 + 1 # back to 1-indexed
# Determine the day to set for next month
target_day = current_time.day

View File

@ -75,6 +75,53 @@ def _redact_responses_api_output(output_items):
summary_item.text = "redacted-by-litellm"
def _redact_standard_logging_object(model_call_details: dict):
"""Redact messages and response inside standard_logging_object if present."""
standard_logging_object = model_call_details.get("standard_logging_object")
if standard_logging_object is None:
return
redacted_str = "redacted-by-litellm"
if standard_logging_object.get("messages") is not None:
standard_logging_object["messages"] = [
{"role": "user", "content": redacted_str}
]
response = standard_logging_object.get("response")
if response is not None:
if isinstance(response, dict) and "output" in response:
# ResponsesAPIResponse format - redact content in output items
if isinstance(response.get("output"), list):
for output_item in response["output"]:
if isinstance(output_item, dict) and "content" in output_item:
if isinstance(output_item["content"], list):
for content_item in output_item["content"]:
if (
isinstance(content_item, dict)
and "text" in content_item
):
content_item["text"] = redacted_str
elif isinstance(response, dict) and "choices" in response:
# ModelResponse dict format - redact content in choices
if isinstance(response.get("choices"), list):
for choice in response["choices"]:
if isinstance(choice, dict):
if "message" in choice and isinstance(choice["message"], dict):
choice["message"]["content"] = redacted_str
if "audio" in choice["message"]:
choice["message"]["audio"] = None
elif "delta" in choice and isinstance(choice["delta"], dict):
choice["delta"]["content"] = redacted_str
if "audio" in choice["delta"]:
choice["delta"]["audio"] = None
elif isinstance(response, str):
standard_logging_object["response"] = redacted_str
else:
# For other formats (empty dict, None, etc.), use simple text format
standard_logging_object["response"] = {"text": redacted_str}
def perform_redaction(model_call_details: dict, result):
"""
Performs the actual redaction on the logging object and result.

View File

@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
CompletionTokensDetailsWrapper,
Function,
Message,
ModelResponse,
@ -63,6 +64,7 @@ from litellm.utils import (
has_tool_call_blocks,
last_assistant_with_tool_calls_has_no_thinking_blocks,
supports_reasoning,
token_counter,
)
from ..common_utils import (
@ -1637,7 +1639,11 @@ class AmazonConverseConfig(BaseConfig):
thinking_blocks_list.append(_redacted_block)
return thinking_blocks_list
def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
def _transform_usage(
self,
usage: ConverseTokenUsageBlock,
reasoning_content: Optional[str] = None,
) -> Usage:
input_tokens = usage["inputTokens"]
output_tokens = usage["outputTokens"]
total_tokens = usage["totalTokens"]
@ -1654,6 +1660,19 @@ class AmazonConverseConfig(BaseConfig):
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
reasoning_tokens = (
token_counter(text=reasoning_content, count_response_tokens=True)
if reasoning_content
else 0
)
completion_tokens_details = CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens,
text_tokens=(
output_tokens - reasoning_tokens
if reasoning_tokens > 0
else output_tokens
),
)
openai_usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
@ -1661,6 +1680,7 @@ class AmazonConverseConfig(BaseConfig):
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_tokens_details,
)
return openai_usage
@ -1997,7 +2017,10 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["tool_calls"] = filtered_tools
## CALCULATING USAGE - bedrock returns usage in the headers
usage = self._transform_usage(completion_response["usage"])
usage = self._transform_usage(
completion_response["usage"],
reasoning_content=chat_completion_message.get("reasoning_content"),
)
## HANDLE TOOL CALLS
_message = Message(**chat_completion_message)

View File

@ -429,8 +429,11 @@ class FireworksAIConfig(OpenAIGPTConfig):
"FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
)
base = api_base.rstrip("/")
if base.endswith("/v1"):
base = base[: -len("/v1")]
response = litellm.module_level_client.get(
url=f"{api_base}/v1/accounts/{account_id}/models",
url=f"{base}/v1/accounts/{account_id}/models",
headers={"Authorization": f"Bearer {api_key}"},
)

View File

@ -583,35 +583,17 @@ class SagemakerLLM(BaseAWSLLM):
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
# Use _load_credentials to support role assumption (aws_role_name, aws_session_name)
credentials, aws_region_name = self._load_credentials(optional_params)
if aws_access_key_id is not None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
client = boto3.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or aws_region_name # get region from config file if specified
or "us-west-2" # default to us-west-2 if region not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# Create boto3 session with the loaded credentials
session = boto3.Session(
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
region_name=aws_region_name,
)
client = session.client(service_name="sagemaker-runtime")
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
@ -628,7 +610,9 @@ class SagemakerLLM(BaseAWSLLM):
#### EMBEDDING LOGIC
# Transform request based on model type
provider_config = SagemakerEmbeddingConfig.get_model_config(model)
request_data = provider_config.transform_embedding_request(model, input, optional_params, {})
request_data = provider_config.transform_embedding_request(
model, input, optional_params, {}
)
data = json.dumps(request_data).encode("utf-8")
## LOGGING
@ -673,19 +657,19 @@ class SagemakerLLM(BaseAWSLLM):
)
print_verbose(f"raw model_response: {response}")
# Transform response based on model type
from httpx import Response as HttpxResponse
# Create a mock httpx Response object for the transformation
mock_response = HttpxResponse(
status_code=200,
content=json.dumps(response).encode('utf-8'),
headers={"content-type": "application/json"}
content=json.dumps(response).encode("utf-8"),
headers={"content-type": "application/json"},
)
model_response = EmbeddingResponse()
# Use the request_data that was already transformed above
return provider_config.transform_embedding_response(
model=model,
@ -695,5 +679,5 @@ class SagemakerLLM(BaseAWSLLM):
api_key=None,
request_data=request_data,
optional_params=optional_params,
litellm_params=litellm_params or {}
litellm_params=litellm_params or {},
)

View File

@ -593,6 +593,10 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
raise e
# Keys that LiteLLM consumes internally and must never be forwarded to the upstream provider API in the request body.
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
extra_body: Optional[dict] = optional_params.pop("extra_body", None)

View File

@ -718,6 +718,7 @@ if MCP_AVAILABLE:
Checks both the full tool name and unprefixed version (without server prefix).
This allows users to configure simple tool names regardless of prefixing.
Comparison is case-insensitive to handle OpenAPI operationIds that may be in camelCase.
Args:
tool_name: The tool name to check (may be prefixed like "server-tool_name")
@ -730,13 +731,15 @@ if MCP_AVAILABLE:
split_server_prefix_from_name,
)
# Check if the full name is in the list
if tool_name in filter_list:
# Normalize filter list to lowercase for case-insensitive comparison
filter_list_lower = [f.lower() for f in filter_list]
if tool_name.lower() in filter_list_lower:
return True
# Check if the unprefixed name is in the list
# Check if the unprefixed name is in the list (case-insensitive)
unprefixed_name, _ = split_server_prefix_from_name(tool_name)
return unprefixed_name in filter_list
return unprefixed_name.lower() in filter_list_lower
def filter_tools_by_allowed_tools(
tools: List[MCPTool],

View File

@ -111,12 +111,19 @@ def get_key_models(
if SpecialModelNames.all_team_models.value in all_models:
all_models = list(user_api_key_dict.team_models) # copy to avoid mutating cached objects
if SpecialModelNames.all_proxy_models.value in all_models:
all_models = proxy_model_list
all_models = list(proxy_model_list) # copy to avoid mutating caller's list
if include_model_access_groups:
all_models.extend(model_access_groups.keys())
all_models = _get_models_from_access_groups(
model_access_groups=model_access_groups, all_models=all_models
model_access_groups=model_access_groups,
all_models=all_models,
include_model_access_groups=include_model_access_groups,
)
# deduplicate while preserving order
all_models = list(dict.fromkeys(all_models))
verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
return all_models
@ -140,8 +147,8 @@ def get_team_models(
all_models_set.update(team_models)
if SpecialModelNames.all_proxy_models.value in all_models_set:
all_models_set.update(proxy_model_list)
all_models = list(all_models_set)
if include_model_access_groups:
all_models_set.update(model_access_groups.keys())
all_models = _get_models_from_access_groups(
model_access_groups=model_access_groups,
@ -149,6 +156,9 @@ def get_team_models(
include_model_access_groups=include_model_access_groups,
)
# deduplicate while preserving order
all_models = list(dict.fromkeys(all_models))
verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
return all_models

View File

@ -142,17 +142,47 @@ async def get_credentials(
tags=["credential management"],
response_model=CredentialItem,
)
async def get_credential_by_name(
    request: Request,
    fastapi_response: Response,
    credential_name: str = Path(..., description="The credential name, percent-decoded; may contain slashes"),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] endpoint. This might change unexpectedly.
    """
    # Looks up a stored credential by exact name in litellm.credential_list and
    # returns it with its values masked. Responds 404 when no credential matches.
    # Any exception (including the 404) is logged and re-raised through
    # handle_exception_on_proxy.
    try:
        # First credential whose name matches exactly, or None when there is no match.
        matching_credential = next(
            (
                candidate
                for candidate in litellm.credential_list
                if candidate.credential_name == credential_name
            ),
            None,
        )
        if matching_credential is None:
            raise HTTPException(
                status_code=404,
                detail="Credential not found. Got credential name: " + credential_name,
            )

        # Never return raw secrets: mask the values before building the response item.
        masked_values = _get_masked_values(
            matching_credential.credential_values,
            unmasked_length=4,
            number_of_asterisks=4,
        )
        return CredentialItem(
            credential_name=matching_credential.credential_name,
            credential_values=masked_values,
            credential_info=matching_credential.credential_info,
        )
    except Exception as e:
        verbose_proxy_logger.exception(e)
        raise handle_exception_on_proxy(e)
@router.get(
"/credentials/by_model/{model_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["credential management"],
response_model=CredentialItem,
)
async def get_credential(
async def get_credential_by_model(
request: Request,
fastapi_response: Response,
credential_name: str = Path(..., description="The credential name, percent-decoded; may contain slashes"),
model_id: Optional[str] = None,
model_id: str = Path(..., description="The model ID to look up credentials for"),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@ -161,48 +191,25 @@ async def get_credential(
from litellm.proxy.proxy_server import llm_router
try:
if model_id:
if llm_router is None:
raise HTTPException(status_code=500, detail="LLM router not found")
model = llm_router.get_deployment(model_id)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
credential_values = llm_router.get_deployment_credentials(model_id)
if credential_values is None:
raise HTTPException(status_code=404, detail="Model not found")
masked_credential_values = _get_masked_values(
credential_values,
unmasked_length=4,
number_of_asterisks=4,
)
credential = CredentialItem(
credential_name="{}-credential-{}".format(model.model_name, model_id),
credential_values=masked_credential_values,
credential_info={},
)
# return credential object
return credential
elif credential_name:
for credential in litellm.credential_list:
if credential.credential_name == credential_name:
masked_credential = CredentialItem(
credential_name=credential.credential_name,
credential_values=_get_masked_values(
credential.credential_values,
unmasked_length=4,
number_of_asterisks=4,
),
credential_info=credential.credential_info,
)
return masked_credential
raise HTTPException(
status_code=404,
detail="Credential not found. Got credential name: " + credential_name,
)
else:
raise HTTPException(
status_code=404, detail="Credential name or model ID required"
)
if llm_router is None:
raise HTTPException(status_code=500, detail="LLM router not found")
model = llm_router.get_deployment(model_id)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
credential_values = llm_router.get_deployment_credentials(model_id)
if credential_values is None:
raise HTTPException(status_code=404, detail="Model not found")
masked_credential_values = _get_masked_values(
credential_values,
unmasked_length=4,
number_of_asterisks=4,
)
credential = CredentialItem(
credential_name="{}-credential-{}".format(model.model_name, model_id),
credential_values=masked_credential_values,
credential_info={},
)
return credential
except Exception as e:
verbose_proxy_logger.exception(e)
raise handle_exception_on_proxy(e)

View File

@ -2827,21 +2827,6 @@ async def validate_membership(
)
def _unfurl_all_proxy_models(
    team_info: LiteLLM_TeamTable, llm_router: Router
) -> LiteLLM_TeamTable:
    """Expand the 'all proxy models' marker in ``team_info.models``.

    When the marker is present and a router is available, ``team_info.models``
    is replaced (in place) by the de-duplicated union of the team's other
    models and every model name the router reports. Otherwise ``team_info``
    is returned unchanged.

    Args:
        team_info: Team record whose ``models`` list may contain the marker.
        llm_router: Router queried via ``get_model_names()``; may be None.

    Returns:
        The same ``team_info`` object that was passed in.
    """
    marker = SpecialModelNames.all_proxy_models.value
    if marker not in team_info.models or llm_router is None:
        return team_info

    # A set is used so models listed both on the team and on the router appear once.
    expanded_models = {name for name in team_info.models if name != marker}
    expanded_models.update(llm_router.get_model_names())
    team_info.models = list(expanded_models)
    return team_info
async def _add_team_member_budget_table(
@ -2972,9 +2957,6 @@ async def team_info(
team_info_response_object=_team_info,
)
# ## UNFURL 'all-proxy-models' into the team_info.models list ##
# if llm_router is not None:
# _team_info = _unfurl_all_proxy_models(_team_info, llm_router)
response_object = TeamInfoResponseObject(
team_id=team_id,
team_info=_team_info,

View File

@ -2059,7 +2059,8 @@ class InitPassThroughEndpointHelpers:
"""
## CHECK IF MAPPED PASS THROUGH ENDPOINT
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
if route.startswith(mapped_route):
full_mapped_route = InitPassThroughEndpointHelpers._build_full_path_with_root(mapped_route)
if route.startswith(full_mapped_route):
return True
# Fast path: check if any registered route key contains this path

View File

@ -1461,11 +1461,21 @@ async def _get_spend_report_for_time_range(
dependencies=[Depends(user_api_key_auth)],
responses={
200: {
"cost": {
"description": "The calculated cost",
"example": 0.0,
"type": "float",
}
"description": "The calculated cost",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"cost": {
"type": "number",
"description": "The calculated cost",
"example": 0.0,
}
},
}
}
},
}
},
)

View File

@ -292,21 +292,21 @@ class LiteLLMCompletionResponsesConfig:
)
_messages = litellm_completion_request.get("messages") or []
session_messages = chat_completion_session.get("messages") or []
# If session messages are empty (e.g., no database in test environment),
# we still need to process the new input messages
# Store original _messages before combining for safety check
original_new_messages = _messages.copy() if _messages else []
combined_messages = session_messages + _messages
# Fix: Ensure tool_results have corresponding tool_calls in previous assistant message
# Pass tools parameter to help reconstruct tool_calls if not in cache
tools = litellm_completion_request.get("tools") or []
combined_messages = LiteLLMCompletionResponsesConfig._ensure_tool_results_have_corresponding_tool_calls(
messages=combined_messages, tools=tools
)
# Safety check: Ensure we don't end up with empty messages
# This can happen when using previous_response_id without a database (e.g., in tests)
# and session messages are empty but new input messages exist
@ -340,7 +340,7 @@ class LiteLLMCompletionResponsesConfig:
"custom_llm_provider", ""
),
)
litellm_completion_request["messages"] = combined_messages
litellm_completion_request["litellm_trace_id"] = chat_completion_session.get(
"litellm_session_id"
@ -386,10 +386,45 @@ class LiteLLMCompletionResponsesConfig:
if call_id_raw:
existing_tool_call_ids.add(str(call_id_raw))
#########################################################
# Merge consecutive function_call items into a single assistant
# message. Anthropic requires that all tool_use blocks appear in
# ONE assistant message immediately followed by the tool_result
# blocks. Without this merging, each function_call creates its own
# assistant message, producing back-to-back assistant messages that
# Anthropic rejects with "tool_use ids were found without
# tool_result blocks immediately after".
#########################################################
if messages:
last_msg = messages[-1]
last_role = (
last_msg.get("role")
if isinstance(last_msg, dict)
else getattr(last_msg, "role", None)
)
if last_role == "assistant":
for new_msg in chat_completion_messages:
new_role = (
new_msg.get("role")
if isinstance(new_msg, dict)
else getattr(new_msg, "role", None)
)
if new_role == "assistant":
new_tcs = (
new_msg.get("tool_calls")
if isinstance(new_msg, dict)
else getattr(new_msg, "tool_calls", None)
) or []
for tc in new_tcs:
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
last_msg, tc
)
continue
#########################################################
# If Input Item is a Tool Call Output, add it to the tool_call_output_messages list
# preserving the ordering of tool call outputs. Some models require the tool
# result to immediately follow the assistant tool call.
# preserving the ordering of tool call outputs. Some models require the tool
# result to immediately follow the assistant tool call.
#########################################################
if LiteLLMCompletionResponsesConfig._is_input_item_tool_call_output(
input_item=_input
@ -774,14 +809,14 @@ class LiteLLMCompletionResponsesConfig:
]:
"""
Ensure that tool_result messages have corresponding tool_calls in the previous assistant message.
This is critical for Anthropic API which requires that each tool_result block has a
corresponding tool_use block in the previous assistant message.
Args:
messages: List of messages that may include tool_result messages
tools: Optional list of tools that can be used to reconstruct tool_calls if not in cache
Returns:
List of messages with tool_calls added to assistant messages when needed
"""
@ -801,18 +836,18 @@ class LiteLLMCompletionResponsesConfig:
]
] = list(copy.deepcopy(messages))
messages_to_remove = []
# Count non-tool messages to avoid removing all messages
# This prevents empty messages list when using previous_response_id without a database
non_tool_messages_count = sum(
1 for msg in fixed_messages if msg.get("role") != "tool"
)
for i, message in enumerate(fixed_messages):
# Only process tool messages - check role first to narrow the type
if message.get("role") != "tool":
continue
# At this point, we know it's a tool message, so it should have tool_call_id
# Use get() with default to safely access tool_call_id
tool_call_id_raw = (
@ -823,11 +858,11 @@ class LiteLLMCompletionResponsesConfig:
tool_call_id: str = (
str(tool_call_id_raw) if tool_call_id_raw is not None else ""
)
prev_assistant_idx = LiteLLMCompletionResponsesConfig._find_previous_assistant_idx(
fixed_messages, i
)
# Try to recover empty tool_call_id from previous assistant message
if not tool_call_id and prev_assistant_idx is not None:
prev_assistant = fixed_messages[prev_assistant_idx]
@ -842,7 +877,7 @@ class LiteLLMCompletionResponsesConfig:
message_dict["tool_call_id"] = tool_call_id
elif hasattr(message, "tool_call_id"):
setattr(message, "tool_call_id", tool_call_id)
# Only remove messages with empty tool_call_id if we have other non-tool messages
# This prevents ending up with an empty messages list when using previous_response_id
# without a database (e.g., in tests where session messages are empty)
@ -854,7 +889,7 @@ class LiteLLMCompletionResponsesConfig:
# If no non-tool messages, keep the tool message even with empty call_id
# The API will return a proper error message about the missing tool_use block
continue
# Check if the previous assistant message has the corresponding tool_call
# This needs to run for ALL tool messages with a valid tool_call_id,
# not just those that had an empty tool_call_id initially
@ -863,12 +898,12 @@ class LiteLLMCompletionResponsesConfig:
tool_calls = LiteLLMCompletionResponsesConfig._get_tool_calls_list(
prev_assistant
)
if not LiteLLMCompletionResponsesConfig._check_tool_call_exists(
tool_calls, tool_call_id
):
_tool_use_definition = TOOL_CALLS_CACHE.get_cache(key=tool_call_id)
if not _tool_use_definition and tools:
_tool_use_definition = LiteLLMCompletionResponsesConfig._reconstruct_tool_call_from_tools(
tool_call_id, tools
@ -891,11 +926,11 @@ class LiteLLMCompletionResponsesConfig:
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
prev_assistant, tool_call_chunk
)
# Remove messages with empty tool_call_id that couldn't be fixed
for idx in reversed(messages_to_remove):
fixed_messages.pop(idx)
return fixed_messages
@staticmethod
@ -1547,6 +1582,39 @@ class LiteLLMCompletionResponsesConfig:
return tool_call_dict
@staticmethod
def convert_apply_patch_tool_call_to_chat_completion_tool_call(
tool_call_item: Any,
index: int = 0,
) -> Dict[str, Any]:
"""
Convert ResponseApplyPatchToolCall to ChatCompletionToolCallChunk format.
The operation (create_file / update_file / delete_file) is serialised
as JSON so it appears in function.arguments, just like any other
tool call.
Args:
tool_call_item: ResponseApplyPatchToolCall object with call_id and operation
index: The index of this tool call
Returns:
Dictionary in ChatCompletionToolCallChunk format
"""
import json
operation_dict = tool_call_item.operation.model_dump()
tool_call_dict: Dict[str, Any] = {
"id": tool_call_item.call_id,
"function": {
"name": "apply_patch",
"arguments": json.dumps(operation_dict),
},
"type": "function",
"index": index,
}
return tool_call_dict
@staticmethod
def transform_chat_completion_response_to_responses_api_response(
request_input: Union[str, ResponseInputParam],

View File

@ -5505,6 +5505,10 @@ class Router:
return response
except Exception as e:
# Always track the latest error so we raise the most
# recent exception instead of the first one.
original_exception = e
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt - 1
@ -5519,6 +5523,24 @@ class Router:
)
else:
_healthy_deployments = []
# Check if this error is non-retryable (e.g., 400 context
# window exceeded). If so, raise immediately instead of
# continuing the retry loop. Respect retry policy
# precedence - only check when no retry policy applies.
if not _retry_policy_applies:
try:
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
except Exception:
raise e
_timeout = self._time_to_sleep_before_retry(
e=e,
remaining_retries=remaining_retries,

View File

@ -498,20 +498,22 @@ class LowestLatencyLoggingHandler(CustomLogger):
# get average latency or average ttft (depending on streaming/non-streaming)
total: float = 0.0
if (
use_ttft = (
request_kwargs is not None
and request_kwargs.get("stream", None) is not None
and request_kwargs["stream"] is True
and len(item_ttft_latency) > 0
):
)
if use_ttft:
for _call_latency in item_ttft_latency:
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total / len(item_ttft_latency)
else:
for _call_latency in item_latency:
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total / len(item_latency)
item_latency = total / len(item_latency)
# -------------- #
# Debugging Logic

View File

@ -458,6 +458,24 @@
"interactions": true
}
},
"charity_engine": {
"display_name": "Charity Engine (`charity_engine`)",
"url": "https://docs.litellm.ai/docs/providers/charity_engine",
"endpoints": {
"chat_completions": true,
"messages": true,
"responses": true,
"embeddings": false,
"image_generations": false,
"audio_transcriptions": false,
"audio_speech": false,
"moderations": false,
"batches": false,
"rerank": false,
"a2a": false,
"interactions": false
}
},
"chutes": {
"display_name": "Chutes (`chutes`)",
"endpoints": {

View File

@ -44,19 +44,25 @@ def create_skill_zip(skill_name: str, unique_suffix: Optional[str] = None):
skill_dir = test_dir / skill_name
# Create a zip file containing the skill directory
# When unique_suffix is set, folder name must match skill name in SKILL.md (Anthropic requirement)
zip_folder_name = f"{skill_name}-{unique_suffix}" if unique_suffix else skill_name
zip_path = test_dir / f"{skill_name}.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
zf.write(skill_dir, arcname=skill_name)
if unique_suffix is not None:
# Rewrite SKILL.md with a unique name to avoid API conflicts
# Rewrite SKILL.md with a unique name and use matching folder name
skill_md = (skill_dir / "SKILL.md").read_text()
skill_md = skill_md.replace(
f"name: {skill_name}",
f"name: {skill_name}-{unique_suffix}",
f"name: {zip_folder_name}",
)
zf.writestr(f"{skill_name}/SKILL.md", skill_md)
zf.writestr(f"{zip_folder_name}/SKILL.md", skill_md)
# Add any other files in the skill dir (e.g. subdirs) under the new folder name
for f in skill_dir.rglob("*"):
if f.is_file() and f.name != "SKILL.md":
rel = f.relative_to(skill_dir)
zf.write(f, arcname=f"{zip_folder_name}/{rel}")
else:
zf.write(skill_dir, arcname=skill_name)
zf.write(skill_dir / "SKILL.md", arcname=f"{skill_name}/SKILL.md")
try:

View File

@ -1300,9 +1300,11 @@ def test_logging_async_cache_hit_sync_call(turn_off_message_logging):
"redacted-by-litellm"
== standard_logging_object["messages"][0]["content"]
)
assert {"text": "redacted-by-litellm"} == standard_logging_object[
"response"
]
# response is a full ModelResponse dict (choices format) since d84e5e381acf
assert (
standard_logging_object["response"]["choices"][0]["message"]["content"]
== "redacted-by-litellm"
)
def test_logging_standard_payload_failure_call():

View File

@ -45,7 +45,8 @@ async def test_global_redaction_on():
await asyncio.sleep(1)
standard_logging_payload = test_custom_logger.logged_standard_logging_payload
assert standard_logging_payload is not None
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
response = standard_logging_payload["response"]
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
print(
"logged standard logging payload",
@ -75,7 +76,8 @@ async def test_global_redaction_with_dynamic_params(turn_off_message_logging):
)
if turn_off_message_logging is True:
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
response = standard_logging_payload["response"]
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert (
standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
)
@ -108,7 +110,8 @@ async def test_global_redaction_off_with_dynamic_params(turn_off_message_logging
json.dumps(standard_logging_payload, indent=2),
)
if turn_off_message_logging is True:
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
response = standard_logging_payload["response"]
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert (
standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
)
@ -390,7 +393,8 @@ async def test_redaction_with_streaming_response():
assert standard_logging_payload is not None
# Verify that redaction worked without pickle errors
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
response = standard_logging_payload["response"]
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"
print(
"logged standard logging payload for streaming with coroutine handling",
@ -477,5 +481,6 @@ async def test_redaction_with_metadata_completion_api():
# Verify the helper function works correctly - with get_metadata_variable_name_from_kwargs,
# the system checks the appropriate field for headers
assert standard_logging_payload["response"] == {"text": "redacted-by-litellm"}
response = standard_logging_payload["response"]
assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert standard_logging_payload["messages"][0]["content"] == "redacted-by-litellm"

View File

@ -1,9 +1,11 @@
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.caching.dual_cache import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.caching.redis_cache import RedisCache
@ -56,3 +58,104 @@ async def test_dual_cache_async_batch_get_cache_rolls_back_redis_reservation_on_
assert mock_async_batch_get_cache.call_count == 2
assert "shared_a" not in dual_cache.last_redis_batch_access_time
assert "shared_b" not in dual_cache.last_redis_batch_access_time
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_injects_default_in_memory_ttl():
    """
    async_set_cache must fall back to DualCache.default_in_memory_ttl when the
    caller supplies no ttl, mirroring the synchronous set_cache path.

    Regression test: async_set_cache used to skip the TTL injection that the
    sync set_cache performs, so InMemoryCache applied its own default_ttl
    (600s) instead of DualCache's default_in_memory_ttl.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)

    started_at = time.time()
    await cache.async_set_cache(key="test_key", value="test_value")
    finished_at = time.time()

    # The stored expiry has to fall inside the 60s window bracketed by the two
    # timestamps; a 600s expiry (InMemoryCache's own default) would overshoot.
    stored_expiry = memory_layer.ttl_dict["test_key"]
    assert started_at + 60 <= stored_expiry <= finished_at + 60
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_respects_explicit_ttl():
    """
    An explicitly passed ttl must NOT be overridden by async_set_cache.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)

    started_at = time.time()
    await cache.async_set_cache(key="test_key", value="test_value", ttl=30)
    finished_at = time.time()

    # Expect the caller's ttl=30 to be applied rather than the
    # default_in_memory_ttl of 60.
    stored_expiry = memory_layer.ttl_dict["test_key"]
    assert started_at + 30 <= stored_expiry <= finished_at + 30
@pytest.mark.asyncio
async def test_dual_cache_async_set_cache_pipeline_injects_default_in_memory_ttl():
    """
    async_set_cache_pipeline must inject default_in_memory_ttl into kwargs
    when the caller supplies no ttl.
    """
    memory_layer = InMemoryCache(default_ttl=600)
    cache = DualCache(in_memory_cache=memory_layer, default_in_memory_ttl=60)
    entries = [("key_a", "value_a"), ("key_b", "value_b")]

    started_at = time.time()
    await cache.async_set_cache_pipeline(cache_list=entries)
    finished_at = time.time()

    # Every key written through the pipeline must carry the 60s expiry.
    for cached_key, _ in entries:
        stored_expiry = memory_layer.ttl_dict[cached_key]
        assert started_at + 60 <= stored_expiry <= finished_at + 60
@pytest.mark.asyncio
async def test_dual_cache_sync_and_async_set_cache_use_same_ttl():
    """
    With no explicit ttl, set_cache and async_set_cache must store the same
    TTL, i.e. the sync and async write paths stay in parity.
    """
    sync_memory = InMemoryCache(default_ttl=600)
    async_memory = InMemoryCache(default_ttl=600)
    sync_cache = DualCache(in_memory_cache=sync_memory, default_in_memory_ttl=60)
    async_cache = DualCache(in_memory_cache=async_memory, default_in_memory_ttl=60)

    sync_cache.set_cache(key="test_key", value="test_value")
    await async_cache.async_set_cache(key="test_key", value="test_value")

    # Both writes happen back to back with default_in_memory_ttl=60, so the
    # two expiry timestamps may differ only by a small scheduling delay.
    expiry_gap = sync_memory.ttl_dict["test_key"] - async_memory.ttl_dict["test_key"]
    assert abs(expiry_gap) < 1.0

View File

@ -738,7 +738,58 @@ def test_response_completed_with_message_only_emits_stop_finish_reason():
)
def test_function_call_done_does_not_emit_finish_reason():
def test_response_completed_preserves_usage_with_cached_tokens():
    """
    A response.completed chunk must have its Responses API usage
    (input_tokens_details) translated into chat completion usage
    (prompt_tokens_details).

    Regression test: when streaming through the Responses API bridge
    (e.g. gpt-5.2-codex), prompt_tokens_details was dropped and
    cached_tokens therefore always came back as None.
    """
    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        OpenAiResponsesToChatCompletionStreamIterator,
    )

    stream_iterator = OpenAiResponsesToChatCompletionStreamIterator(
        streaming_response=None, sync_stream=True
    )

    assistant_message = {
        "type": "message",
        "id": "msg_abc",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "Six"}],
        "status": "completed",
    }
    responses_usage = {
        "input_tokens": 1226,
        "output_tokens": 5,
        "total_tokens": 1231,
        "input_tokens_details": {"cached_tokens": 1024},
        "output_tokens_details": {"reasoning_tokens": 0},
    }
    completed_event = {
        "type": "response.completed",
        "response": {
            "id": "resp_789",
            "status": "completed",
            "output": [assistant_message],
            "usage": responses_usage,
        },
    }

    parsed = stream_iterator.chunk_parser(completed_event)
    parsed_usage = parsed.usage

    assert parsed_usage is not None, "usage should be set on response.completed chunk"
    assert parsed_usage.prompt_tokens == 1226, "prompt_tokens should map from input_tokens"
    assert parsed_usage.completion_tokens == 5, "completion_tokens should map from output_tokens"
    assert parsed_usage.prompt_tokens_details is not None, "prompt_tokens_details should be set"
    assert parsed_usage.prompt_tokens_details.cached_tokens == 1024, (
        "cached_tokens should be preserved from input_tokens_details"
    )
def test_function_call_done_emits_is_finished():
"""
Test that OUTPUT_ITEM_DONE for a function_call does NOT emit finish_reason.
The response.completed event handles the terminal finish_reason correctly.
@ -1327,6 +1378,138 @@ def test_transform_response_preserves_annotations():
print("✓ Annotations from Responses API are correctly preserved in Chat Completions format")
def test_apply_patch_tool_call_converted_to_chat_completion_tool_call():
    """
    The bridge must convert ResponseApplyPatchToolCall items coming from the
    Responses API into ChatCompletions-style tool calls.

    Regression test: litellm.completion() with a responses/ model prefix
    crashed when the model returned an apply_patch_call, because
    _convert_response_output_to_choices did not handle
    ResponseApplyPatchToolCall items. The model DID use the tool, but the
    bridge silently dropped it (or raised an error), while the native
    litellm.responses() path worked correctly.
    """
    import json
    from unittest.mock import Mock

    from openai.types.responses.response_apply_patch_tool_call import (
        OperationCreateFile,
    )
    from openai.types.responses.response_output_item import (
        ResponseApplyPatchToolCall,
    )

    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        LiteLLMResponsesTransformationHandler,
    )
    from litellm.types.llms.openai import (
        InputTokensDetails,
        OutputTokensDetails,
        ResponseAPIUsage,
        ResponsesAPIResponse,
    )
    from litellm.types.utils import ModelResponse, Usage

    bridge = LiteLLMResponsesTransformationHandler()

    # The apply_patch_call output item, shaped like what the model returns.
    patch_call = ResponseApplyPatchToolCall(
        id="apc_001",
        call_id="call_patch_hello",
        operation=OperationCreateFile(
            diff="--- /dev/null\n+++ b/hello.py\n@@ -0,0 +1 @@\n+print('hello world')\n",
            path="hello.py",
            type="create_file",
        ),
        status="completed",
        type="apply_patch_call",
    )

    # Smallest usage payload the response object accepts in this test.
    token_usage = ResponseAPIUsage(
        input_tokens=30,
        input_tokens_details=InputTokensDetails(cached_tokens=0),
        output_tokens=40,
        output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
        total_tokens=70,
    )

    responses_api_response = ResponsesAPIResponse(
        id="resp_apply_patch_test",
        created_at=1234567890,
        error=None,
        incomplete_details=None,
        instructions=None,
        metadata={},
        model="gpt-5.2-codex",
        object="response",
        output=[patch_call],
        parallel_tool_calls=True,
        temperature=1.0,
        tool_choice="auto",
        tools=[],
        top_p=1.0,
        max_output_tokens=None,
        previous_response_id=None,
        reasoning=None,
        status="completed",
        text=None,
        truncation="disabled",
        usage=token_usage,
        user=None,
        store=True,
        background=False,
    )

    empty_chat_response = ModelResponse(
        id="chatcmpl-apply-patch",
        created=1234567890,
        model=None,
        object="chat.completion",
        choices=[],
        usage=Usage(completion_tokens=0, prompt_tokens=0, total_tokens=0),
    )

    conversation = [
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Create hello.py"},
    ]

    transformed = bridge.transform_response(
        model="gpt-5.2-codex",
        raw_response=responses_api_response,
        model_response=empty_chat_response,
        logging_obj=Mock(),
        request_data={"model": "gpt-5.2-codex"},
        messages=conversation,
        optional_params={},
        litellm_params={},
        encoding=Mock(),
    )

    # Exactly one choice is expected, finished because of a tool call.
    assert len(transformed.choices) == 1, f"Expected 1 choice, got {len(transformed.choices)}"
    only_choice = transformed.choices[0]
    assert only_choice.finish_reason == "tool_calls"

    # That choice carries a single apply_patch tool call.
    emitted_calls = only_choice.message.tool_calls
    assert emitted_calls is not None, "tool_calls should not be None"
    assert len(emitted_calls) == 1, f"Expected 1 tool_call, got {len(emitted_calls)}"

    patch_tool_call = emitted_calls[0]
    assert patch_tool_call["id"] == "call_patch_hello"
    assert patch_tool_call["type"] == "function"
    assert patch_tool_call["function"]["name"] == "apply_patch"

    # The operation travels as a JSON string inside function.arguments.
    decoded_args = json.loads(patch_tool_call["function"]["arguments"])
    assert decoded_args["type"] == "create_file"
    assert decoded_args["path"] == "hello.py"
    assert "print('hello world')" in decoded_args["diff"]
def test_multi_tool_call_stream_no_premature_finish():
"""
Regression test for multi-tool-call streaming bug.
@ -1778,3 +1961,35 @@ def test_parallel_tool_calls_comprehensive_streaming_integration():
)
print("✓ Parallel tool calls with split argument deltas stream correctly end-to-end")
def test_map_optional_params_preserves_reasoning_summary():
    """A reasoning_effort dict carrying a summary field must survive mapping.

    Regression test: a user reported that the summary field was dropped when
    routing to the Responses API. The dict format should be fully preserved.
    """
    from litellm.completion_extras.litellm_responses_transformation.transformation import (
        LiteLLMResponsesTransformationHandler,
    )
    from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams

    bridge = LiteLLMResponsesTransformationHandler()
    mapped_request = ResponsesAPIOptionalRequestParams()
    chat_params = {
        "stream": False,
        "tools": [{"type": "function", "function": {"name": "test_tool"}}],
        "tool_choice": "auto",
        "reasoning_effort": {"effort": "high", "summary": "detailed"},
    }

    # The handler writes its output into the request object passed in.
    bridge._map_optional_params_to_responses_api_request(chat_params, mapped_request)

    # The whole dict, including summary, must land under the "reasoning" key.
    assert "reasoning" in mapped_request
    mapped_reasoning = mapped_request["reasoning"]
    assert mapped_reasoning == {"effort": "high", "summary": "detailed"}
    assert mapped_reasoning["effort"] == "high"
    assert mapped_reasoning["summary"] == "detailed"

View File

@ -192,6 +192,23 @@ def test_azure_gpt5_1_series_temperature_handling(config: AzureOpenAIGPT5Config)
assert params["temperature"] == 0.6
def test_azure_gpt5_4_drops_reasoning_effort_when_tools_present(config: AzureOpenAIGPT5Config):
    """Azure Chat Completions: gpt-5.4+ drops reasoning_effort when tools are present.

    OpenAI routes tools+reasoning to the Responses API; Azure does not, so
    reasoning_effort is dropped while the tools are kept.
    """
    tool_definitions = [
        {"type": "function", "function": {"name": "test", "description": "test"}}
    ]
    requested = {"reasoning_effort": "high", "tools": tool_definitions}

    mapped = config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt5_series/gpt-5.4",
        drop_params=False,
        api_version="2024-05-01-preview",
    )

    assert "reasoning_effort" not in mapped
    assert mapped["tools"] == tool_definitions
def test_azure_gpt5_reasoning_effort_none_error(config: AzureOpenAIGPT5Config):
"""Test that Azure GPT-5 (non-5.1) raises error for reasoning_effort='none' when drop_params=False."""
with pytest.raises(litellm.utils.UnsupportedParamsError):

View File

@ -43,6 +43,29 @@ def test_transform_usage():
)
assert openai_usage._cache_creation_input_tokens == usage["cacheWriteInputTokens"]
assert openai_usage._cache_read_input_tokens == usage["cacheReadInputTokens"]
# completion_tokens_details should always be populated
assert openai_usage.completion_tokens_details is not None
assert openai_usage.completion_tokens_details.reasoning_tokens == 0
assert openai_usage.completion_tokens_details.text_tokens == usage["outputTokens"]
def test_transform_usage_with_reasoning_content():
    """completion_tokens_details must split output tokens into reasoning vs text."""
    converse_usage = ConverseTokenUsageBlock(
        inputTokens=10,
        outputTokens=100,
        totalTokens=110,
    )

    transformed = AmazonConverseConfig()._transform_usage(
        converse_usage,
        reasoning_content="Let me think about this step by step.",
    )

    details = transformed.completion_tokens_details
    assert details is not None
    assert details.reasoning_tokens > 0
    # Whatever is not counted as reasoning must be reported as text tokens.
    assert details.text_tokens == (
        converse_usage["outputTokens"] - details.reasoning_tokens
    )
def test_transform_system_message():
@ -3170,6 +3193,33 @@ def test_transform_request_with_output_config():
assert result["outputConfig"]["textFormat"]["structure"]["jsonSchema"]["name"] == "TestSchema"
def test_output_config_snake_case_stripped_from_bedrock_converse_request():
    """output_config (snake_case) must be stripped from Bedrock Converse requests.

    The Bedrock Converse API doesn't support the output_config parameter
    (Anthropic-only). Nova and other Converse models reject requests that
    carry an extraneous output_config.
    """
    request_body = AmazonConverseConfig()._transform_request(
        model="us.amazon.nova-pro-v1:0",
        messages=[{"role": "user", "content": "test"}],
        optional_params={"output_config": {"effort": "high"}},
        litellm_params={},
        headers={},
    )

    # The key must not leak into additionalModelRequestFields; a missing
    # additionalModelRequestFields entry counts as "not present".
    extra_fields = request_body.get("additionalModelRequestFields", {})
    assert "output_config" not in extra_fields, (
        f"output_config should be stripped for Bedrock Converse, got: {list(extra_fields.keys())}"
    )
def test_transform_response_native_structured_output():
"""Test response handling when model returns JSON as text content (native structured output)."""
response_json = {

View File

@ -110,6 +110,60 @@ def test_get_supported_openai_params_reasoning_effort():
assert "reasoning_effort" not in unsupported_params
@pytest.mark.parametrize(
    "api_base, expected_url_prefix",
    [
        (
            "https://api.fireworks.ai/inference/v1",
            "https://api.fireworks.ai/inference/v1/accounts/",
        ),
        (
            "https://api.fireworks.ai/inference/v1/",
            "https://api.fireworks.ai/inference/v1/accounts/",
        ),
        (
            "https://custom-host.example.com/v1",
            "https://custom-host.example.com/v1/accounts/",
        ),
        (
            "https://custom-host.example.com/api",
            "https://custom-host.example.com/api/v1/accounts/",
        ),
    ],
    ids=["default", "trailing-slash", "custom-with-v1", "custom-without-v1"],
)
def test_get_models_url_no_double_v1(api_base, expected_url_prefix):
    """Ensure get_models never produces a /v1/v1/ URL segment (fixes #23106).

    Args:
        api_base: Fireworks API base handed to get_models and returned for
            the FIREWORKS_API_BASE secret lookup.
        expected_url_prefix: Prefix the requested models URL must start with.
    """
    config = FireworksAIConfig()
    account_id = "fireworks"

    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "models": [{"name": "accounts/fireworks/models/llama-v3-70b"}]
    }

    # Secrets served to the transformation module; unknown keys resolve to None.
    fake_secrets = {
        "FIREWORKS_API_KEY": "test-key",
        "FIREWORKS_API_BASE": api_base,
        "FIREWORKS_ACCOUNT_ID": account_id,
    }

    with (
        patch("litellm.module_level_client.get", return_value=mock_response) as mock_get,
        patch(
            "litellm.llms.fireworks_ai.chat.transformation.get_secret_str",
            side_effect=fake_secrets.get,
        ),
    ):
        result = config.get_models(api_key="test-key", api_base=api_base)

    request_call = mock_get.call_args
    assert request_call is not None, "get_models did not issue an HTTP GET"

    # Accept the URL whether it was passed as url=... or as the first
    # positional argument. (call_args[1] is the same dict as call_args.kwargs,
    # so it cannot serve as a fallback for the positional form.)
    called_url = request_call.kwargs.get("url")
    if called_url is None and request_call.args:
        called_url = request_call.args[0]
    called_url = called_url or ""

    assert "/v1/v1/" not in called_url, f"Double /v1/ detected in URL: {called_url}"
    assert called_url.startswith(expected_url_prefix), (
        f"URL {called_url} does not start with {expected_url_prefix}"
    )
    assert result == ["fireworks_ai/accounts/fireworks/models/llama-v3-70b"]
def test_transform_messages_helper_removes_provider_specific_fields():
"""
Test that _transform_messages_helper removes provider_specific_fields from messages.

View File

@ -13,6 +13,7 @@ from litellm.llms.openai.chat.gpt_transformation import (
OpenAIChatCompletionStreamingHandler,
OpenAIGPTConfig,
)
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
class TestOpenAIGPTConfig:
@ -324,3 +325,195 @@ class TestPromptCacheParams:
)
assert optional_params.get("prompt_cache_key") == "my-cache-key"
assert optional_params.get("prompt_cache_retention") == "24h"
class TestGPT5ReasoningEffortPreservation:
    """Tests for GPT-5 reasoning_effort dict preservation for Responses API."""

    def setup_method(self):
        # A fresh config per test so no state is shared between tests.
        self.config = OpenAIGPT5Config()

    def _run_mapping(self, requested, model, drop_params=False):
        """Call map_openai_params with an empty optional_params dict and return that dict."""
        mapped = {}
        self.config.map_openai_params(
            non_default_params=requested,
            optional_params=mapped,
            model=model,
            drop_params=drop_params,
        )
        return mapped

    def test_reasoning_effort_string_preserved(self):
        """Test that reasoning_effort as string is preserved."""
        requested = {"reasoning_effort": "high"}

        self._run_mapping(requested, model="gpt-5.4")

        # The plain string form must be left as-is.
        assert requested.get("reasoning_effort") == "high"

    def test_reasoning_effort_dict_with_only_effort_normalized(self):
        """Test that reasoning_effort dict with only 'effort' key is normalized to string."""
        requested = {"reasoning_effort": {"effort": "high"}}

        self._run_mapping(requested, model="gpt-5.4")

        # A dict holding nothing but 'effort' collapses to the string value.
        assert requested.get("reasoning_effort") == "high"

    def test_reasoning_effort_dict_with_summary_preserved(self):
        """Test that reasoning_effort dict with 'summary' field is preserved for Responses API.

        Regression test: a user reported that the summary field was dropped
        when routing to the Responses API. A dict with additional fields
        should be preserved so the Responses API transformation can handle it.
        """
        requested = {"reasoning_effort": {"effort": "high", "summary": "detailed"}}

        self._run_mapping(requested, model="gpt-5.4")

        # Extra fields mean the dict is kept whole.
        kept = requested.get("reasoning_effort")
        assert kept == {"effort": "high", "summary": "detailed"}
        assert isinstance(kept, dict)
        assert requested["reasoning_effort"]["effort"] == "high"
        assert requested["reasoning_effort"]["summary"] == "detailed"

    def test_reasoning_effort_dict_with_generate_summary_preserved(self):
        """Test that reasoning_effort dict with 'generate_summary' field is preserved."""
        requested = {"reasoning_effort": {"effort": "medium", "generate_summary": "auto"}}

        self._run_mapping(requested, model="gpt-5.4")

        # Extra fields mean the dict is kept whole.
        kept = requested.get("reasoning_effort")
        assert kept == {"effort": "medium", "generate_summary": "auto"}
        assert isinstance(kept, dict)

    def test_reasoning_effort_dict_with_all_fields_preserved(self):
        """Test that reasoning_effort dict with all fields is preserved."""
        requested = {
            "reasoning_effort": {
                "effort": "high",
                "summary": "detailed",
                "generate_summary": "concise",
            }
        }

        self._run_mapping(requested, model="gpt-5.4")

        # Every field of the dict must still be present afterwards.
        kept = requested.get("reasoning_effort")
        assert isinstance(kept, dict)
        assert kept["effort"] == "high"
        assert kept["summary"] == "detailed"
        assert kept["generate_summary"] == "concise"

    def test_reasoning_effort_dict_xhigh_triggers_validation(self):
        """xhigh-dict: effective effort is extracted for model-support validation.

        When reasoning_effort={"effort": "xhigh", "summary": "detailed"} is
        passed to a model that doesn't support xhigh (e.g. gpt-5.1), the xhigh
        guard must fire.
        """
        import litellm

        requested = {"reasoning_effort": {"effort": "xhigh", "summary": "detailed"}}

        with pytest.raises(litellm.utils.UnsupportedParamsError):
            self._run_mapping(requested, model="gpt-5.1")

    def test_reasoning_effort_dict_xhigh_dropped_when_requested(self):
        """xhigh-dict with drop_params=True: reasoning_effort is dropped."""
        requested = {"reasoning_effort": {"effort": "xhigh", "summary": "detailed"}}

        self._run_mapping(requested, model="gpt-5.1", drop_params=True)

        assert "reasoning_effort" not in requested

    def test_reasoning_effort_dict_none_treated_as_none_for_tools(self):
        """none-dict: {"effort": "none", "summary": "detailed"} is treated as effort=none.

        Tool-drop guard should NOT fire; reasoning_effort should be kept.
        """
        tool_definitions = [
            {"type": "function", "function": {"name": "test", "description": "test"}}
        ]
        requested = {
            "reasoning_effort": {"effort": "none", "summary": "detailed"},
            "tools": tool_definitions,
        }

        self._run_mapping(requested, model="gpt-5.4")

        assert requested.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}
        assert requested.get("tools") == tool_definitions

    def test_reasoning_effort_dict_none_treated_as_none_for_sampling(self):
        """none-dict: {"effort": "none", "summary": "detailed"} allows logprobs/top_p.

        Sampling-param guard should NOT fire; logprobs should be kept.
        """
        requested = {
            "reasoning_effort": {"effort": "none", "summary": "detailed"},
            "logprobs": True,
        }

        self._run_mapping(requested, model="gpt-5.1")

        assert requested.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}
        assert requested.get("logprobs") is True

    def test_reasoning_effort_dict_none_allows_temperature(self):
        """none-dict: {"effort": "none", "summary": "detailed"} allows non-default temperature."""
        requested = {
            "reasoning_effort": {"effort": "none", "summary": "detailed"},
            "temperature": 0.5,
        }

        mapped = self._run_mapping(requested, model="gpt-5.1")

        assert mapped.get("temperature") == 0.5
        assert requested.get("reasoning_effort") == {"effort": "none", "summary": "detailed"}

View File

@ -324,10 +324,11 @@ def test_gpt5_4_pro_allows_reasoning_effort_xhigh(config: OpenAIConfig):
assert params["reasoning_effort"] == "xhigh"
def test_gpt5_normalizes_reasoning_effort_dict_to_string(config: OpenAIConfig):
"""Chat completion API expects reasoning_effort as a string, not a dict.
def test_gpt5_preserves_reasoning_effort_dict_with_summary(config: OpenAIConfig):
"""Dict with summary/generate_summary is preserved for Responses API.
Config/deployments may pass Responses API format: {'effort': 'high', 'summary': 'detailed'}.
We preserve the full dict so it reaches the Responses API transformation.
"""
params = config.map_openai_params(
non_default_params={"reasoning_effort": {"effort": "high", "summary": "detailed"}},
@ -335,18 +336,82 @@ def test_gpt5_normalizes_reasoning_effort_dict_to_string(config: OpenAIConfig):
model="gpt-5.4",
drop_params=False,
)
assert params["reasoning_effort"] == "high"
assert params["reasoning_effort"] == {"effort": "high", "summary": "detailed"}
def test_gpt5_normalizes_reasoning_effort_dict_from_optional_params(config: OpenAIConfig):
"""reasoning_effort dict in optional_params (e.g. from model config) is normalized."""
def test_gpt5_xhigh_dict_triggers_validation(config: OpenAIConfig):
    """Dict with effort='xhigh' triggers xhigh model-support validation.

    Regression: when reasoning_effort is a dict, effective_effort must be
    used for the xhigh guard so validation is not silently skipped.
    """
    requested = {"reasoning_effort": {"effort": "xhigh", "summary": "detailed"}}

    with pytest.raises(litellm.utils.UnsupportedParamsError):
        config.map_openai_params(
            non_default_params=requested,
            optional_params={},
            model="gpt-5.1",
            drop_params=False,
        )
def test_gpt5_xhigh_dict_accepted_for_supported_model(config: OpenAIConfig):
    """Dict with effort='xhigh' passes through for gpt-5.4+."""
    xhigh_reasoning = {"effort": "xhigh", "summary": "detailed"}

    mapped = config.map_openai_params(
        non_default_params={"reasoning_effort": xhigh_reasoning},
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "xhigh", "summary": "detailed"}
def test_gpt5_none_dict_with_tools_no_tool_drop(config: OpenAIConfig):
    """Dict with effort='none' and tools: no tool-drop, reasoning_effort preserved.

    Regression: effective_effort='none' must be used for the tool-drop guard
    so {"effort": "none", "summary": "detailed"} is not incorrectly treated
    as non-none.
    """
    tool_definitions = [
        {"type": "function", "function": {"name": "test", "description": "test"}}
    ]
    requested = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "tools": tool_definitions,
    }

    mapped = config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.4",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "none", "summary": "detailed"}
    assert mapped["tools"] == tool_definitions
def test_gpt5_none_dict_with_sampling_params_allowed(config: OpenAIConfig):
    """Dict with effort='none' allows logprobs/top_p/top_logprobs.

    Regression: effective_effort='none' must be used for the sampling guard
    so {"effort": "none", "summary": "detailed"} does not incorrectly trigger
    sampling errors.
    """
    requested = {
        "reasoning_effort": {"effort": "none", "summary": "detailed"},
        "logprobs": True,
        "top_p": 0.9,
    }

    mapped = config.map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gpt-5.1",
        drop_params=False,
    )

    assert mapped["reasoning_effort"] == {"effort": "none", "summary": "detailed"}
    assert mapped["logprobs"] is True
    assert mapped["top_p"] == 0.9
def test_gpt5_preserves_reasoning_effort_dict_with_summary_from_optional_params(config: OpenAIConfig):
    """reasoning_effort dict with summary in optional_params is preserved.

    The dict arrives through optional_params (non_default_params is empty)
    and must come back unchanged rather than collapsed to its effort string.
    """
    params = config.map_openai_params(
        non_default_params={},
        optional_params={"reasoning_effort": {"effort": "medium", "summary": "detailed"}},
        model="gpt-5.4",
        drop_params=False,
    )

    # Only the dict form is asserted: also asserting the normalized string
    # "medium" on the same value would contradict this and could never pass.
    assert params["reasoning_effort"] == {"effort": "medium", "summary": "detailed"}
def test_gpt5_4_drops_reasoning_effort_when_tools_present(config: OpenAIConfig):

View File

@ -0,0 +1,243 @@
"""
Test cases for SageMaker embedding role assumption support
This module tests that the SageMaker embedding handler properly supports
AWS IAM role assumption via aws_role_name and aws_session_name parameters,
matching the behavior of the completion handler.
"""
import json
import os
import sys
from datetime import timezone
from unittest.mock import MagicMock, call, patch
sys.path.insert(0, os.path.abspath("../../../../.."))
from botocore.credentials import Credentials
from litellm.llms.sagemaker.completion.handler import SagemakerLLM
from litellm.types.utils import EmbeddingResponse
class TestSagemakerEmbeddingRoleAssumption:
"""Test that SageMaker embedding supports role assumption like completion does"""
def setup_method(self):
    """Create a fresh SagemakerLLM handler and store it on self.sagemaker_llm."""
    handler = SagemakerLLM()
    self.sagemaker_llm = handler
def test_embedding_uses_load_credentials(self):
    """
    embedding() must call _load_credentials() so that role assumption is
    supported, i.e. aws_role_name and aws_session_name are properly handled.
    """
    # Stand-in for the credentials a role assumption would produce.
    assumed_creds = Credentials(
        access_key="assumed-access-key",
        secret_key="assumed-secret-key",
        token="assumed-session-token",
    )

    # Fake SageMaker runtime client whose invoke_endpoint returns one embedding.
    response_bytes = json.dumps({"embedding": [[0.1, 0.2, 0.3]]}).encode()
    runtime_client = MagicMock()
    runtime_client.invoke_endpoint.return_value = {
        "Body": MagicMock(read=MagicMock(return_value=response_bytes))
    }

    # boto3.Session is replaced by a mock that hands out the fake client.
    fake_session = MagicMock()
    fake_session.client.return_value = runtime_client

    load_creds_patch = patch.object(
        self.sagemaker_llm, "_load_credentials", return_value=(assumed_creds, "us-east-1")
    )
    session_patch = patch("boto3.Session", return_value=fake_session)

    with load_creds_patch as mock_load_creds, session_patch:
        role_params = {
            "aws_role_name": "arn:aws:iam::123456789012:role/TestRole",
            "aws_session_name": "test-session",
        }

        self.sagemaker_llm.embedding(
            model="test-endpoint",
            input=["hello world"],
            model_response=EmbeddingResponse(),
            print_verbose=print,
            encoding=None,
            logging_obj=MagicMock(),
            optional_params=role_params,
        )

        # _load_credentials must have been consulted exactly once.
        mock_load_creds.assert_called_once()

        # Exactly one client was requested from the session, and it was the
        # sagemaker-runtime client.
        client_requests = fake_session.client.call_args_list
        assert len(client_requests) == 1
        assert client_requests[0] == call(service_name="sagemaker-runtime")
def test_embedding_role_assumption_with_sts(self):
    """
    Full role assumption flow for embeddings, similar to completion:
    STS assume_role must be called when aws_role_name is provided.
    """
    # Expiration stand-in: timezone-aware, and subtracting from it yields a
    # delta that reports 3600 seconds.
    remaining = MagicMock()
    remaining.total_seconds.return_value = 3600
    fake_expiration = MagicMock()
    fake_expiration.tzinfo = timezone.utc
    fake_expiration.__sub__ = MagicMock(return_value=remaining)

    # STS client whose assume_role returns temporary credentials.
    sts_client = MagicMock()
    sts_client.assume_role.return_value = {
        "Credentials": {
            "AccessKeyId": "assumed-access-key",
            "SecretAccessKey": "assumed-secret-key",
            "SessionToken": "assumed-session-token",
            "Expiration": fake_expiration,
        }
    }

    # Fake SageMaker runtime client whose invoke_endpoint returns one embedding.
    response_bytes = json.dumps({"embedding": [[0.1, 0.2, 0.3]]}).encode()
    runtime_client = MagicMock()
    runtime_client.invoke_endpoint.return_value = {
        "Body": MagicMock(read=MagicMock(return_value=response_bytes))
    }

    # boto3.Session mock used when the SageMaker client is created.
    fake_session = MagicMock()
    fake_session.client.return_value = runtime_client

    def route_boto3_client(service_name, **kwargs):
        # Only "sts" gets the STS mock; anything else gets the runtime client.
        return sts_client if service_name == "sts" else runtime_client

    with patch("boto3.client", side_effect=route_boto3_client), patch(
        "boto3.Session", return_value=fake_session
    ):
        role_params = {
            "aws_role_name": "arn:aws:iam::123456789012:role/CrossAccountRole",
            "aws_session_name": "litellm-embedding-session",
            "aws_region_name": "us-east-1",
        }

        self.sagemaker_llm.embedding(
            model="test-endpoint",
            input=["hello world"],
            model_response=EmbeddingResponse(),
            print_verbose=print,
            encoding=None,
            logging_obj=MagicMock(),
            optional_params=role_params,
        )

        # assume_role must have run once, with the role ARN and session name
        # taken from optional_params.
        sts_client.assume_role.assert_called_once()
        assume_role_kwargs = sts_client.assume_role.call_args[1]
        assert assume_role_kwargs["RoleArn"] == "arn:aws:iam::123456789012:role/CrossAccountRole"
        assert assume_role_kwargs["RoleSessionName"] == "litellm-embedding-session"
def test_embedding_without_role_assumption(self):
    """
    Embedding must succeed when optional_params carries no aws_role_name.

    ``_load_credentials`` is patched to return fixed credentials and
    ``boto3.Session`` is patched to return a fake session, so the call runs
    without touching AWS. The only expectation is a non-None result.
    """
    # Canned invoke_endpoint payload served by the fake SageMaker client
    embedding_payload = json.dumps({"embedding": [[0.1, 0.2, 0.3]]}).encode()
    fake_client = MagicMock()
    fake_client.invoke_endpoint.return_value = {
        "Body": MagicMock(read=MagicMock(return_value=embedding_payload))
    }

    fake_session = MagicMock()
    fake_session.client.return_value = fake_client

    # Credentials standing in for whatever _load_credentials would resolve
    env_credentials = Credentials(
        access_key="env-access-key",
        secret_key="env-secret-key",
        token=None,
    )

    credentials_patch = patch.object(
        self.sagemaker_llm,
        "_load_credentials",
        return_value=(env_credentials, "us-west-2"),
    )
    session_patch = patch("boto3.Session", return_value=fake_session)

    with credentials_patch, session_patch:
        # Deliberately no aws_role_name in the request parameters
        params_without_role = {"aws_region_name": "us-west-2"}
        result = self.sagemaker_llm.embedding(
            model="test-endpoint",
            input=["hello world"],
            model_response=EmbeddingResponse(),
            print_verbose=print,
            encoding=None,
            logging_obj=MagicMock(),
            optional_params=params_without_role,
        )

        # The call should still complete and hand back a response
        assert result is not None
def test_embedding_session_created_with_assumed_credentials(self):
    """
    boto3.Session must be built from the values _load_credentials returns.

    ``_load_credentials`` is patched to yield a fixed credential/region pair;
    the test then asserts that exactly those values are forwarded as keyword
    arguments in the single ``boto3.Session(...)`` call.
    """
    assumed_credentials = Credentials(
        access_key="assumed-key",
        secret_key="assumed-secret",
        token="assumed-token",
    )

    # Fake SageMaker client that returns one canned embedding
    raw_body = json.dumps({"embedding": [[0.1, 0.2, 0.3]]}).encode()
    fake_client = MagicMock()
    fake_client.invoke_endpoint.return_value = {
        "Body": MagicMock(read=MagicMock(return_value=raw_body))
    }

    credentials_patch = patch.object(
        self.sagemaker_llm,
        "_load_credentials",
        return_value=(assumed_credentials, "us-east-1"),
    )
    with credentials_patch, patch("boto3.Session") as session_factory:
        fake_session = MagicMock()
        fake_session.client.return_value = fake_client
        session_factory.return_value = fake_session

        self.sagemaker_llm.embedding(
            model="test-endpoint",
            input=["hello world"],
            model_response=EmbeddingResponse(),
            print_verbose=print,
            encoding=None,
            logging_obj=MagicMock(),
            optional_params={},
        )

        # The session constructor must have seen the patched credentials
        session_factory.assert_called_once_with(
            aws_access_key_id="assumed-key",
            aws_secret_access_key="assumed-secret",
            aws_session_token="assumed-token",
            region_name="us-east-1",
        )

View File

@ -128,6 +128,75 @@ def test_vertex_ai_includes_labels():
def test_extra_body_cache_not_forwarded_to_vertex_ai():
    """
    ``extra_body["cache"]`` is a LiteLLM proxy caching control and must be
    dropped before the Vertex AI request body is produced.

    Regression test for the provider error
    'Invalid JSON payload received. Unknown name "cache": Cannot find field.'
    Other extra_body keys are expected to survive the transformation.
    """
    extra_body = {
        "cache": {"use-cache": True, "ttl": 86400},  # LiteLLM-internal
        "some_vertex_param": "value",  # legitimate provider extra
    }

    request_body = _transform_request_body(
        messages=[{"role": "user", "content": "test"}],
        model="gemini-2.5-pro",
        optional_params={"extra_body": extra_body},
        custom_llm_provider="vertex_ai",
        litellm_params={},
        cached_content=None,
    )

    # The internal key must not leak into the provider payload
    assert "cache" not in request_body, (
        "extra_body.cache must not be forwarded to Vertex AI. "
        "Vertex AI rejects it with 400: Unknown name \"cache\": Cannot find field."
    )

    # Provider-facing extras pass through untouched
    assert "some_vertex_param" in request_body
    assert request_body["some_vertex_param"] == "value"

    # The core request content is still emitted
    assert "contents" in request_body
def test_extra_body_tags_not_forwarded_to_vertex_ai():
    """
    ``extra_body["tags"]`` is a LiteLLM logging/tracking parameter and must be
    dropped before the Vertex AI request body is produced.

    litellm_proxy.md documents sending tags through extra_body, so the key is
    expected to show up here; other extra_body keys must still pass through.
    """
    extra_body = {
        "tags": ["user:alice", "env:prod"],
        "custom_param": "allowed",
    }

    request_body = _transform_request_body(
        messages=[{"role": "user", "content": "test"}],
        model="gemini-2.5-pro",
        optional_params={"extra_body": extra_body},
        custom_llm_provider="vertex_ai",
        litellm_params={},
        cached_content=None,
    )

    assert "tags" not in request_body
    assert "custom_param" in request_body
    assert request_body["custom_param"] == "allowed"
def test_metadata_to_labels_vertex_only():
"""Test that metadata->labels conversion only happens for Vertex AI"""
messages = [{"role": "user", "content": "test"}]

View File

@ -2093,3 +2093,150 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab
assert spend_meta["tool_count_total"] == 1
assert spend_meta["allowed_server_count"] == 1
assert spend_meta["per_server_tool_counts"]["server_a"] == 1
def test_tool_name_matches_case_insensitive():
    """Check that _tool_name_matches ignores letter case.

    Motivation: for OpenAPI-based MCP servers, operationIds are commonly
    camelCase ('addPet'), tool names are lowercased at registration
    ('addpet'), and allowed_tools may be configured with the camelCase
    originals. A case-sensitive comparison would filter every tool out.
    """
    try:
        from litellm.proxy._experimental.mcp_server.server import _tool_name_matches
    except ImportError:
        pytest.skip("MCP server not available")

    pet_filters = ("addPet", "updatePet")
    other_filters = ("deletePet", "updatePet")

    # (tool name, allowed names, expected match)
    cases = [
        # Unprefixed tool name against a camelCase filter list
        ("addpet", pet_filters, True),
        ("updatepet", pet_filters, True),
        ("deletepet", pet_filters, False),
        # Server-prefixed tool name against a camelCase filter list
        ("per_store-addpet", pet_filters, True),
        ("per_store-updatepet", pet_filters, True),
        ("per_store-deletepet", pet_filters, False),
        # Mixed-case variations on either side
        ("findPetsByStatus", ("findpetsbystatus",), True),
        ("findpetsbystatus", ("findPetsByStatus",), True),
        ("FINDPETSBYSTATUS", ("findPetsByStatus",), True),
        # Full prefixed name present in the filter list
        ("server-addPet", ("server-addpet",), True),
        ("server-addpet", ("server-addPet",), True),
        # Names that are genuinely absent must still be rejected
        ("addpet", other_filters, False),
        ("server-addpet", other_filters, False),
    ]

    for tool_name, allowed, expected in cases:
        # A fresh list per call, matching how each original assertion was written
        assert _tool_name_matches(tool_name, list(allowed)) is expected
def test_filter_tools_by_allowed_tools_case_insensitive():
    """Check that filter_tools_by_allowed_tools matches names regardless of case.

    Tools are registered with lowercase, server-prefixed names while the
    server's allowed_tools uses camelCase; the three allowed tools must be
    kept and the one not listed must be dropped.
    """
    try:
        from litellm.proxy._experimental.mcp_server.server import (
            filter_tools_by_allowed_tools,
        )
        from litellm.types.mcp_server.tool_registry import MCPTool
    except ImportError:
        pytest.skip("MCP server not available")

    # Handler stub that simply echoes its keyword arguments
    def mock_handler(**kwargs):
        return kwargs

    # Lowercase names, as they would look after registration from OpenAPI
    tool_specs = [
        ("per_store-addpet", "Add a pet"),
        ("per_store-updatepet", "Update a pet"),
        ("per_store-deletepet", "Delete a pet"),
        ("per_store-findpetsbystatus", "Find pets by status"),
    ]
    tools = [
        MCPTool(
            name=tool_name,
            description=summary,
            input_schema={"type": "object"},
            handler=mock_handler,
        )
        for tool_name, summary in tool_specs
    ]

    # camelCase allow-list, as it would appear in the OpenAPI spec
    server = MCPServer(
        server_id="test-server",
        name="per_store",
        transport=MCPTransport.http,
        allowed_tools=["addPet", "updatePet", "findPetsByStatus"],
    )

    filtered_tools = filter_tools_by_allowed_tools(tools, server)

    # Exactly the three allowed tools survive
    assert len(filtered_tools) == 3
    kept_names = {tool.name for tool in filtered_tools}
    assert "per_store-addpet" in kept_names
    assert "per_store-updatepet" in kept_names
    assert "per_store-findpetsbystatus" in kept_names
    assert "per_store-deletepet" not in kept_names
def test_filter_tools_by_allowed_tools_no_filter():
    """Check that every tool is returned when the server has allowed_tools=None."""
    try:
        from litellm.proxy._experimental.mcp_server.server import (
            filter_tools_by_allowed_tools,
        )
        from litellm.types.mcp_server.tool_registry import MCPTool
    except ImportError:
        pytest.skip("MCP server not available")

    # Handler stub that simply echoes its keyword arguments
    def mock_handler(**kwargs):
        return kwargs

    tool_specs = [
        ("fusion_litellm_mcp-model_list", "List models"),
        ("fusion_litellm_mcp-chat_completion", "Chat completion"),
    ]
    tools = [
        MCPTool(
            name=tool_name,
            description=summary,
            input_schema={"type": "object"},
            handler=mock_handler,
        )
        for tool_name, summary in tool_specs
    ]

    # No allow-list configured on this server
    unfiltered_server = MCPServer(
        server_id="test-server",
        name="fusion_litellm_mcp",
        transport=MCPTransport.http,
        allowed_tools=None,
    )

    filtered_tools = filter_tools_by_allowed_tools(tools, unfiltered_server)

    # Nothing should have been removed
    assert len(filtered_tools) == 2

View File

@ -21,6 +21,140 @@ def test_get_team_models_for_all_models_and_team_only_models():
assert set(result) == set(combined_models)
def test_get_team_models_all_proxy_models_includes_access_groups():
    """
    A team holding 'all-proxy-models', queried with
    include_model_access_groups=True, must get back the access group names
    alongside the individual model names, with no duplicates.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    result = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "group-b", "model1", "model2"):
        assert expected_name in result
    assert len(result) == len(set(result)), "result should have no duplicates"
def test_get_team_models_all_proxy_models_without_include_flag():
    """
    With include_model_access_groups=False, access group names must be absent
    from the result even for a team holding 'all-proxy-models'; the individual
    models are still returned.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    result = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=False,
    )

    for group_name in ("group-a", "group-b"):
        assert group_name not in result
    for model_name in ("model1", "model2"):
        assert model_name in result
def test_get_key_models_all_proxy_models_includes_access_groups():
    """
    A key holding 'all-proxy-models', queried with
    include_model_access_groups=True, must get back the access group names
    as well as the proxy models, with no duplicates.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    key_auth = UserAPIKeyAuth(
        models=["all-proxy-models"],
        api_key="test-key",
    )

    result = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1"]},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in result
    assert len(result) == len(set(result)), "result should have no duplicates"
def test_get_key_models_passes_include_model_access_groups():
    """
    A key that lists an access group name directly in its models, queried with
    include_model_access_groups=True, must keep that group name in the result
    (rather than having it stripped by _get_models_from_access_groups) and
    also receive the group's member models.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    key_auth = UserAPIKeyAuth(
        models=["group-a"],
        api_key="test-key",
    )

    result = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1", "model2"]},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in result
def test_get_key_models_does_not_mutate_input():
    """
    get_key_models must leave user_api_key_dict.models untouched.

    _get_models_from_access_groups works with .pop()/.extend(); if it operated
    on the auth object's own list instead of a copy, cached UserAPIKeyAuth
    objects would be corrupted.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    original_models = ["group-a", "extra-model"]
    key_auth = UserAPIKeyAuth(
        models=list(original_models),  # separate list, so the comparison below is meaningful
        api_key="test-key",
    )

    # Return value is irrelevant here; only the side effect on key_auth matters
    get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1", "model2"]},
        include_model_access_groups=False,
    )

    # The models stored on the auth object must be exactly as supplied
    assert key_auth.models == original_models
@pytest.mark.parametrize(
"key_models,team_models,proxy_model_list,model_list,expected",
[

View File

@ -2410,12 +2410,13 @@ def test_mapped_pass_through_routes_with_server_root_path():
)
@pytest.mark.asyncio
async def test_multipart_passthrough_preserves_boundary():
"""
Test that multipart/form-data requests through passthrough preserve the boundary
and can be correctly parsed by the upstream server.
Regression test for multipart boundary stripping issue.
"""
from io import BytesIO
@ -2426,41 +2427,41 @@ async def test_multipart_passthrough_preserves_boundary():
mock_response.headers = httpx.Headers({"content-type": "application/json"})
mock_response.aread = AsyncMock(return_value=b'{"filename": "test.txt", "size": 17}')
mock_response.text = '{"filename": "test.txt", "size": 17}'
async def mock_httpx_request(method, url, **kwargs):
# Verify that files parameter is passed (not json)
assert "files" in kwargs, "Files should be passed for multipart requests"
assert "file" in kwargs["files"], "File field should be in files dict"
# Verify content-type is NOT in headers (httpx will set it with correct boundary)
headers = kwargs.get("headers", {})
assert "content-type" not in headers, "content-type should be removed for multipart"
filename, content, content_type = kwargs["files"]["file"]
assert filename == "test.txt"
assert content == b"test file content"
assert content_type == "text/plain"
return mock_response
async_client = MagicMock()
async_client.request = AsyncMock(side_effect=mock_httpx_request)
# Create mock request
request = MagicMock(spec=Request)
request.method = "POST"
request.headers = Headers({"content-type": "multipart/form-data; boundary=test123"})
# Mock form data
file_content = b"test file content"
file = BytesIO(file_content)
headers = Headers({"content-type": "text/plain"})
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
upload_file.read = AsyncMock(return_value=file_content)
form_data = {"file": upload_file}
request.form = AsyncMock(return_value=form_data)
# Test the multipart handler directly
response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
@ -2469,7 +2470,7 @@ async def test_multipart_passthrough_preserves_boundary():
headers={},
requested_query_params=None,
)
# Verify the response
assert response.status_code == 200
async_client.request.assert_called_once()

View File

@ -1071,9 +1071,10 @@ def test_spend_logs_redacts_request_and_response_when_turn_off_message_logging_e
response_result = _get_response_for_spend_logs_payload(payload=payload, kwargs=kwargs)
# When redaction is enabled and response is a dict (not ModelResponse),
# perform_redaction returns {"text": "redacted-by-litellm"}
# perform_redaction redacts content in-place within the choices structure
parsed_response = json.loads(response_result)
assert parsed_response == {"text": "redacted-by-litellm"}
assert parsed_response["choices"][0]["message"]["content"] == "redacted-by-litellm"
assert parsed_response["choices"][0]["message"]["role"] == "assistant"
@patch("litellm.secret_managers.main.get_secret_bool")

View File

@ -0,0 +1,142 @@
"""
Test that the OpenAPI schema generated by FastAPI is valid for specific endpoints.
Validates fixes for:
- /spend/calculate response schema (must use proper OpenAPI 3.x content wrapper)
- /credentials/by_model/{model_id} path parameter (must not leak credential_name)
Related issue: https://github.com/BerriAI/litellm/issues/21305
"""
import pytest
class TestSpendCalculateOpenAPISchema:
    """Checks the 200 response declared on the /spend/calculate route."""

    @staticmethod
    def _declared_200_response():
        """Return the 200-response mapping declared on /spend/calculate.

        Fails the test if the router has no route with that path. A route
        without declared responses yields an empty mapping.
        """
        from litellm.proxy.spend_tracking.spend_management_endpoints import router

        for candidate in router.routes:
            if getattr(candidate, "path", None) == "/spend/calculate":
                declared = candidate.responses or {}
                return declared.get(200, {})
        pytest.fail("/spend/calculate route not found in router")

    def test_response_schema_has_description(self):
        """The 200 response needs a 'description' field (required by OpenAPI 3.x)."""
        response_200 = self._declared_200_response()
        assert "description" in response_200, (
            "/spend/calculate 200 response must have a 'description' field"
        )

    def test_response_schema_has_content_wrapper(self):
        """The 200 response must nest its schema under 'content', not expose bare properties."""
        response_200 = self._declared_200_response()

        # A top-level 'cost' key would be an invalid OpenAPI response object
        assert "cost" not in response_200, (
            "/spend/calculate 200 response must not have 'cost' as a "
            "top-level property - use 'content' wrapper instead"
        )

        # The schema has to live under content -> application/json -> schema
        assert "content" in response_200, (
            "/spend/calculate 200 response must have a 'content' field"
        )
        media_types = response_200["content"]
        assert "application/json" in media_types
        assert "schema" in media_types["application/json"]
class TestCredentialEndpointsOpenAPISchema:
    """Checks that the /credentials lookup handlers declare only their own path parameters."""

    @staticmethod
    def _parameters_of(handler):
        """Return the signature parameter mapping of ``handler`` (keyed by parameter name)."""
        import inspect

        return inspect.signature(handler).parameters

    def test_by_name_and_by_model_are_separate_handlers(self):
        """
        /credentials/by_name/{credential_name} and /credentials/by_model/{model_id}
        must each map to exactly one route, and those routes must use different
        endpoint functions so neither inherits the other's path parameter.
        """
        from litellm.proxy.credential_endpoints.endpoints import router

        routes_with_path = [r for r in router.routes if hasattr(r, "path")]
        by_name_routes = [r for r in routes_with_path if "by_name" in r.path]
        # A path matching both patterns is counted as by_name only
        by_model_routes = [
            r
            for r in routes_with_path
            if "by_name" not in r.path and "by_model" in r.path
        ]

        assert len(by_name_routes) == 1, "Expected exactly one by_name route"
        assert len(by_model_routes) == 1, "Expected exactly one by_model route"

        # Distinct endpoint callables, not one shared handler
        by_name_endpoint = by_name_routes[0].endpoint
        by_model_endpoint = by_model_routes[0].endpoint
        assert by_name_endpoint is not by_model_endpoint, (
            "by_name and by_model must be separate handler functions "
            "to avoid path parameter conflicts in OpenAPI spec"
        )

    def test_by_model_route_does_not_require_credential_name(self):
        """The by_model handler's signature must not include credential_name."""
        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_model,
        )

        assert "credential_name" not in self._parameters_of(get_credential_by_model), (
            "get_credential_by_model must not have a credential_name parameter"
        )

    def test_by_name_route_does_not_require_model_id(self):
        """The by_name handler's signature must not include model_id."""
        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_name,
        )

        assert "model_id" not in self._parameters_of(get_credential_by_name), (
            "get_credential_by_name must not have a model_id parameter"
        )

    def test_by_model_has_model_id_path_param(self):
        """The by_model handler's signature must include model_id."""
        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_model,
        )

        assert "model_id" in self._parameters_of(get_credential_by_model), (
            "get_credential_by_model must have a model_id parameter"
        )

    def test_by_name_has_credential_name_path_param(self):
        """The by_name handler's signature must include credential_name."""
        from litellm.proxy.credential_endpoints.endpoints import (
            get_credential_by_name,
        )

        assert "credential_name" in self._parameters_of(get_credential_by_name), (
            "get_credential_by_name must have a credential_name parameter"
        )

View File

@ -1774,3 +1774,128 @@ class TestStreamingIDConsistency:
# Verify it matches the cached ID
assert iterator._cached_item_id is not None
assert iterator._cached_item_id == text_done_id
def test_parallel_tool_calls_merged_into_single_assistant_message(self):
    """
    Regression test: two consecutive function_call input items must become
    ONE assistant message carrying both tool_calls, followed by the two tool
    messages.

    The input mimics the follow-up Responses API request after a model issued
    two parallel get_weather calls (SF and NYC). If each function_call were
    turned into its own assistant message, the back-to-back assistant turns
    would be rejected by Anthropic/Vertex AI with
    "tool_use ids were found without tool_result blocks immediately after".
    """

    def _field(item, key):
        # Messages may be dicts or attribute-style objects; read either form
        return item.get(key) if isinstance(item, dict) else getattr(item, key, None)

    # (call id, JSON arguments, tool output) for the two parallel calls
    weather_calls = [
        ("toolu_01", '{"city": "SF"}', "72°F"),
        ("toolu_02", '{"city": "NYC"}', "55°F"),
    ]

    input_items = [
        {"type": "message", "role": "user", "content": "Weather in SF and NYC?"},
    ]
    # Parallel tool calls from the previous assistant response
    input_items.extend(
        {
            "type": "function_call",
            "call_id": call_id,
            "name": "get_weather",
            "arguments": arguments,
        }
        for call_id, arguments, _ in weather_calls
    )
    # Followed by their results, in the same order
    input_items.extend(
        {"type": "function_call_output", "call_id": call_id, "output": output}
        for call_id, _, output in weather_calls
    )

    messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
        input=input_items
    )

    roles = [_field(m, "role") for m in messages]

    # No two neighbouring messages may both be assistant turns
    for i, (current, following) in enumerate(zip(roles, roles[1:])):
        assert (
            current != "assistant" or following != "assistant"
        ), f"Consecutive assistant messages at indices {i} and {i+1}: {roles}"

    # Exactly one assistant message, holding BOTH tool calls
    assistant_messages = [m for m in messages if _field(m, "role") == "assistant"]
    assert len(assistant_messages) == 1, (
        f"Expected 1 assistant message, got {len(assistant_messages)}"
    )

    tool_calls = _field(assistant_messages[0], "tool_calls")
    assert tool_calls is not None and len(tool_calls) == 2, (
        f"Expected 2 tool_calls in the merged assistant message, got: {tool_calls}"
    )

    call_ids = [_field(tc, "id") for tc in tool_calls]
    assert "toolu_01" in call_ids, f"toolu_01 missing from tool_calls: {call_ids}"
    assert "toolu_02" in call_ids, f"toolu_02 missing from tool_calls: {call_ids}"

    # Both tool results must have been carried over
    tool_messages = [m for m in messages if _field(m, "role") == "tool"]
    assert len(tool_messages) == 2, (
        f"Expected 2 tool messages, got {len(tool_messages)}"
    )
def test_single_tool_call_still_works_after_merge_fix(self):
    """
    Guard the single-tool-call path: one function_call plus its output must
    still yield user, assistant and tool messages, with one assistant message
    carrying exactly one tool call.
    """

    def _field(item, key):
        # Messages may be dicts or attribute-style objects; read either form
        return item.get(key) if isinstance(item, dict) else getattr(item, key, None)

    user_turn = {"type": "message", "role": "user", "content": "Weather in SF?"}
    tool_call = {
        "type": "function_call",
        "call_id": "toolu_01",
        "name": "get_weather",
        "arguments": '{"city": "SF"}',
    }
    tool_result = {
        "type": "function_call_output",
        "call_id": "toolu_01",
        "output": "72°F",
    }

    messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
        input=[user_turn, tool_call, tool_result]
    )

    roles = [_field(m, "role") for m in messages]
    for expected_role in ("user", "assistant", "tool"):
        assert expected_role in roles

    assistant_messages = [m for m in messages if _field(m, "role") == "assistant"]
    assert len(assistant_messages) == 1

    tool_calls = _field(assistant_messages[0], "tool_calls")
    assert tool_calls is not None and len(tool_calls) == 1

View File

@ -0,0 +1,251 @@
"""
Test that the Router retry loop correctly handles non-retryable errors.
Verifies that:
1. Non-retryable errors (e.g., 400 ContextWindowExceeded) inside the retry loop
break out immediately instead of being swallowed.
2. original_exception is updated to the latest error, not stuck on the first.
3. Retryable errors (e.g., 429 RateLimitError) still retry normally.
Regression tests for https://github.com/BerriAI/litellm/issues/21343
"""
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import Router
def _make_rate_limit_error(message="Rate limited"):
    """Build a litellm.RateLimitError tagged with the bedrock provider for use in tests."""
    error_fields = {
        "message": message,
        "llm_provider": "bedrock",
        "model": "anthropic.claude-v2",
    }
    return litellm.RateLimitError(**error_fields)
def _make_context_window_error(message="prompt is too long: 1205821 tokens > 200000"):
    """Build a litellm.ContextWindowExceededError tagged with the vertex_ai provider for use in tests."""
    error_fields = {
        "message": message,
        "llm_provider": "vertex_ai",
        "model": "claude-3-opus",
    }
    return litellm.ContextWindowExceededError(**error_fields)
def _make_bad_request_error(message="Invalid request"):
    """Build a litellm.BadRequestError tagged with the openai provider for use in tests."""
    error_fields = {
        "message": message,
        "llm_provider": "openai",
        "model": "gpt-4",
    }
    return litellm.BadRequestError(**error_fields)
def _make_not_found_error(message="Model not found"):
    """Build a litellm.NotFoundError tagged with the openai provider for use in tests."""
    error_fields = {
        "message": message,
        "llm_provider": "openai",
        "model": "gpt-99",
    }
    return litellm.NotFoundError(**error_fields)
def _create_router(num_retries=2):
    """Build a Router with two "test-model" deployments that differ only in their API key."""
    deployments = [
        {
            "model_name": "test-model",
            "litellm_params": {
                "model": "openai/gpt-4",
                "api_key": api_key,
            },
        }
        for api_key in ("fake-key-1", "fake-key-2")
    ]
    return Router(model_list=deployments, num_retries=num_retries)
def _base_kwargs():
    """Return a fresh dict of the keyword arguments the tests pass to async_function_with_retries."""
    kwargs = {"model": "test-model"}
    kwargs["messages"] = [{"role": "user", "content": "test"}]
    # Placeholder callable; the tests patch the router's make_call themselves
    kwargs["original_function"] = AsyncMock()
    kwargs["metadata"] = {}
    return kwargs
@pytest.mark.asyncio
async def test_non_retryable_error_in_retry_loop_raises_immediately():
    """
    A ContextWindowExceededError raised during a retry must surface to the
    caller instead of being swallowed in favour of the first error.

    Scenario: first call -> RateLimitError, every later call ->
    ContextWindowExceededError.
    Expected: ContextWindowExceededError propagates, NOT RateLimitError.
    """
    router = _create_router(num_retries=2)
    rate_limit_error = _make_rate_limit_error()
    context_window_error = _make_context_window_error()

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        # Only the very first attempt is rate limited
        raise rate_limit_error if call_count == 1 else context_window_error

    make_call_patch = patch.object(router, "make_call", side_effect=mock_make_call)
    deployments_patch = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    # Zero back-off so the test does not sleep between attempts
    sleep_patch = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    log_patch = patch.object(
        router, "log_retry", side_effect=lambda kwargs, e: kwargs
    )

    with make_call_patch, deployments_patch, sleep_patch, log_patch:
        with pytest.raises(litellm.ContextWindowExceededError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )
@pytest.mark.asyncio
async def test_bad_request_error_in_retry_loop_raises_immediately():
    """
    A generic BadRequestError raised during a retry must surface to the
    caller: the first call is rate limited, every later call raises
    BadRequestError, and BadRequestError is what must propagate.
    """
    router = _create_router(num_retries=2)
    rate_limit_error = _make_rate_limit_error()
    bad_request_error = _make_bad_request_error()

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        # Only the very first attempt is rate limited
        raise rate_limit_error if call_count == 1 else bad_request_error

    make_call_patch = patch.object(router, "make_call", side_effect=mock_make_call)
    deployments_patch = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    # Zero back-off so the test does not sleep between attempts
    sleep_patch = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    log_patch = patch.object(
        router, "log_retry", side_effect=lambda kwargs, e: kwargs
    )

    with make_call_patch, deployments_patch, sleep_patch, log_patch:
        with pytest.raises(litellm.BadRequestError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )
@pytest.mark.asyncio
async def test_original_exception_updated_to_latest_error():
    """
    When every attempt fails with a retryable error, the exception that
    finally propagates must be the LAST one raised, not the first.

    Each attempt's error message embeds its attempt number, so with
    num_retries=2 the surfaced message must mention attempt 3.
    """
    router = _create_router(num_retries=2)

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        raise _make_rate_limit_error(f"Rate limit attempt {call_count}")

    make_call_patch = patch.object(router, "make_call", side_effect=mock_make_call)
    deployments_patch = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    # Zero back-off so the test does not sleep between attempts
    sleep_patch = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    log_patch = patch.object(
        router, "log_retry", side_effect=lambda kwargs, e: kwargs
    )

    with make_call_patch, deployments_patch, sleep_patch, log_patch:
        with pytest.raises(litellm.RateLimitError) as exc_info:
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # The message must come from the final attempt, not the initial one
    assert "Rate limit attempt 3" in str(exc_info.value)
@pytest.mark.asyncio
async def test_retryable_errors_still_retry_normally():
    """
    A RateLimitError on every attempt must be retried the configured number
    of times before it is finally raised: with num_retries=3 that is four
    calls in total.
    """
    router = _create_router(num_retries=3)

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        raise _make_rate_limit_error(f"Rate limit attempt {call_count}")

    make_call_patch = patch.object(router, "make_call", side_effect=mock_make_call)
    deployments_patch = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    # Zero back-off so the test does not sleep between attempts
    sleep_patch = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    log_patch = patch.object(
        router, "log_retry", side_effect=lambda kwargs, e: kwargs
    )

    with make_call_patch, deployments_patch, sleep_patch, log_patch:
        with pytest.raises(litellm.RateLimitError):
            await router.async_function_with_retries(
                num_retries=3,
                **_base_kwargs(),
            )

    # One initial call plus three retries
    assert call_count == 4
@pytest.mark.asyncio
async def test_not_found_error_in_retry_loop_raises_immediately():
    """
    A NotFoundError raised during a retry must stop the retry loop: the first
    call is rate limited, the second raises NotFoundError, and no further
    calls may be made.
    """
    router = _create_router(num_retries=2)
    rate_limit_error = _make_rate_limit_error()
    not_found_error = _make_not_found_error()

    call_count = 0

    async def mock_make_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        # Only the very first attempt is rate limited
        raise rate_limit_error if call_count == 1 else not_found_error

    make_call_patch = patch.object(router, "make_call", side_effect=mock_make_call)
    deployments_patch = patch.object(
        router,
        "_async_get_healthy_deployments",
        return_value=(["d1", "d2"], ["d1", "d2"]),
    )
    # Zero back-off so the test does not sleep between attempts
    sleep_patch = patch.object(router, "_time_to_sleep_before_retry", return_value=0)
    log_patch = patch.object(
        router, "log_retry", side_effect=lambda kwargs, e: kwargs
    )

    with make_call_patch, deployments_patch, sleep_patch, log_patch:
        with pytest.raises(litellm.NotFoundError):
            await router.async_function_with_retries(
                num_retries=2,
                **_base_kwargs(),
            )

    # Initial call plus the single retry that hit the non-retryable error
    assert call_count == 2

View File

@ -275,8 +275,8 @@ it("should display user email correctly", async () => {
});
});
it("should show skeleton loaders when isLoading is true", () => {
// Mock loading state
it("should show loading message only on initial load (isPending)", () => {
// Mock initial loading state
mockUseKeys.mockReturnValue({
data: null,
isPending: true,
@ -296,7 +296,7 @@ it("should show skeleton loaders when isLoading is true", () => {
renderWithProviders(<VirtualKeysTable {...mockProps} />);
// Check that loading message is shown
// Check that loading message is shown on initial load
expect(screen.getByText("🚅 Loading keys...")).toBeInTheDocument();
// Check that actual key data is not shown
@ -810,3 +810,79 @@ describe("pagination display total count and page count", () => {
});
});
});
describe("refetch button", () => {
  // Builds the value returned by the mocked useKeys hook: a settled query
  // holding a single key, with any field replaceable through `overrides`.
  const buildKeysQuery = (overrides: Record<string, unknown> = {}) =>
    ({
      data: {
        keys: [mockKey],
        total_count: 1,
        current_page: 1,
        total_pages: 1,
      } as KeysResponse,
      isPending: false,
      isFetching: false,
      refetch: vi.fn(),
      ...overrides,
    }) as any;

  it("should show Fetch button in normal state", () => {
    renderWithProviders(<VirtualKeysTable {...defaultMockProps} />);

    const button = screen.getByTitle("Fetch data");
    expect(button).toBeInTheDocument();
    expect(button).not.toBeDisabled();
    expect(screen.getByText("Fetch")).toBeInTheDocument();
  });

  it("should show Fetching state and keep table data visible during refetch", () => {
    mockUseKeys.mockReturnValue(buildKeysQuery({ isFetching: true }));

    renderWithProviders(<VirtualKeysTable {...defaultMockProps} />);

    // While a refetch is in flight the button reads "Fetching" and is disabled
    expect(screen.getByText("Fetching")).toBeInTheDocument();
    const button = screen.getByTitle("Fetch data");
    expect(button).toBeDisabled();

    // The previously loaded (stale) row stays on screen
    expect(screen.getByText("Test Key Alias")).toBeInTheDocument();

    // The initial-load message must NOT come back during a refetch
    expect(screen.queryByText("🚅 Loading keys...")).not.toBeInTheDocument();
  });

  it("should call refetch when Fetch button is clicked", () => {
    const refetchSpy = vi.fn();
    mockUseKeys.mockReturnValue(buildKeysQuery({ refetch: refetchSpy }));

    renderWithProviders(<VirtualKeysTable {...defaultMockProps} />);

    fireEvent.click(screen.getByTitle("Fetch data"));

    expect(refetchSpy).toHaveBeenCalledTimes(1);
  });

  it("should show Fetch button enabled on error so user can retry", () => {
    mockUseKeys.mockReturnValue(buildKeysQuery({ data: null, isError: true }));

    renderWithProviders(<VirtualKeysTable {...defaultMockProps} />);

    // An errored query must leave the button clickable so the user can retry
    const button = screen.getByTitle("Fetch data");
    expect(button).not.toBeDisabled();
    expect(screen.getByText("Fetch")).toBeInTheDocument();
  });
});

View File

@ -85,6 +85,7 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
data: keys,
isPending: isLoading,
isFetching,
isError,
refetch,
} = useKeys(tablePagination.pageIndex + 1, tablePagination.pageSize, {
sortBy: sortBy || undefined,
@ -102,6 +103,15 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
organizations,
});
// Defer the transition so the button stays in loading state until the table
// has rendered with the new data (mirrors the spend-logs pattern)
const isFetchingDeferred = useDeferredValue(isFetching);
const isButtonLoading = (isFetching || isFetchingDeferred) && !isError;
const handleRefresh = () => {
refetch();
};
const totalCount = filteredTotalCount ?? keys?.total_count ?? 0;
// Add a useEffect to call refresh when a key is created
@ -669,16 +679,28 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
</div>
<div className="flex items-center justify-between w-full mb-4">
{isLoading || isFetching ? (
<Skeleton.Node active style={{ width: 200, height: 20 }} />
) : (
<span className="inline-flex text-sm text-gray-700">
Showing {rangeLabel} of {totalCount} results
</span>
)}
<div className="inline-flex items-center gap-2">
{isLoading ? (
<Skeleton.Node active style={{ width: 200, height: 20 }} />
) : (
<span className="inline-flex text-sm text-gray-700">
Showing {rangeLabel} of {totalCount} results
</span>
)}
<AntButton
type="default"
icon={<SyncOutlined spin={isButtonLoading} />}
onClick={handleRefresh}
disabled={isButtonLoading}
title="Fetch data"
>
{isButtonLoading ? "Fetching" : "Fetch"}
</AntButton>
</div>
<div className="inline-flex items-center gap-2">
{isLoading || isFetching ? (
{isLoading ? (
<Skeleton.Node active style={{ width: 74, height: 20 }} />
) : (
<span className="text-sm text-gray-700">
@ -686,24 +708,24 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
</span>
)}
{isLoading || isFetching ? (
{isLoading ? (
<Skeleton.Button active size="small" style={{ width: 84, height: 30 }} />
) : (
<button
onClick={() => table.previousPage()}
disabled={isLoading || isFetching || !table.getCanPreviousPage()}
disabled={isLoading || !table.getCanPreviousPage()}
className="px-3 py-1 text-sm border rounded-md hover:bg-gray-50 disabled:opacity-50 disabled:cursor-not-allowed"
>
Previous
</button>
)}
{isLoading || isFetching ? (
{isLoading ? (
<Skeleton.Button active size="small" style={{ width: 58, height: 30 }} />
) : (
<button
onClick={() => table.nextPage()}
disabled={isLoading || isFetching || !table.getCanNextPage()}
disabled={isLoading || !table.getCanNextPage()}
className="px-3 py-1 text-sm border rounded-md hover:bg-gray-50 disabled:opacity-50 disabled:cursor-not-allowed"
>
Next
@ -788,7 +810,7 @@ export function VirtualKeysTable({ teams, organizations, onSortChange, currentSo
))}
</TableHead>
<TableBody>
{isLoading || isFetching ? (
{isLoading ? (
<TableRow>
<TableCell colSpan={columns.length} className="h-8 text-center">
<div className="text-center text-gray-500">